Skip to content

Commit 872df12

Browse files
authored
Add deferrable capability to existing DataprocDeleteClusterOperator (#29349)
* Add deferrable capability to existing DataprocDeleteClusterOperator. Using param deferrable=True, add support for deleting a Google Dataproc cluster asynchronously.
1 parent d37ef06 commit 872df12

File tree

5 files changed

+190
-4
lines changed

5 files changed

+190
-4
lines changed

airflow/providers/google/cloud/operators/dataproc.py

Lines changed: 46 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
from airflow.providers.google.cloud.triggers.dataproc import (
5454
DataprocBatchTrigger,
5555
DataprocClusterTrigger,
56+
DataprocDeleteClusterTrigger,
5657
DataprocSubmitTrigger,
5758
)
5859
from airflow.utils import timezone
@@ -822,6 +823,8 @@ class DataprocDeleteClusterOperator(BaseOperator):
822823
If set as a sequence, the identities from the list must grant
823824
Service Account Token Creator IAM role to the directly preceding identity, with first
824825
account from the list granting this role to the originating account (templated).
826+
:param deferrable: Run operator in the deferrable mode.
827+
:param polling_interval_seconds: Time (seconds) to wait between calls to check the cluster status.
825828
"""
826829

827830
template_fields: Sequence[str] = ("project_id", "region", "cluster_name", "impersonation_chain")
@@ -835,13 +838,17 @@ def __init__(
835838
cluster_uuid: str | None = None,
836839
request_id: str | None = None,
837840
retry: Retry | _MethodDefault = DEFAULT,
838-
timeout: float | None = None,
841+
timeout: float = 1 * 60 * 60,
839842
metadata: Sequence[tuple[str, str]] = (),
840843
gcp_conn_id: str = "google_cloud_default",
841844
impersonation_chain: str | Sequence[str] | None = None,
845+
deferrable: bool = False,
846+
polling_interval_seconds: int = 10,
842847
**kwargs,
843848
):
844849
super().__init__(**kwargs)
850+
if deferrable and polling_interval_seconds <= 0:
851+
raise ValueError("Invalid value for polling_interval_seconds. Expected value greater than 0")
845852
self.project_id = project_id
846853
self.region = region
847854
self.cluster_name = cluster_name
@@ -852,11 +859,48 @@ def __init__(
852859
self.metadata = metadata
853860
self.gcp_conn_id = gcp_conn_id
854861
self.impersonation_chain = impersonation_chain
862+
self.deferrable = deferrable
863+
self.polling_interval_seconds = polling_interval_seconds
855864

856865
def execute(self, context: Context) -> None:
857866
hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
867+
operation = self._delete_cluster(hook)
868+
if not self.deferrable:
869+
hook.wait_for_operation(timeout=self.timeout, result_retry=self.retry, operation=operation)
870+
self.log.info("Cluster deleted.")
871+
else:
872+
end_time: float = time.time() + self.timeout
873+
self.defer(
874+
trigger=DataprocDeleteClusterTrigger(
875+
gcp_conn_id=self.gcp_conn_id,
876+
project_id=self.project_id,
877+
region=self.region,
878+
cluster_name=self.cluster_name,
879+
request_id=self.request_id,
880+
retry=self.retry,
881+
end_time=end_time,
882+
metadata=self.metadata,
883+
impersonation_chain=self.impersonation_chain,
884+
polling_interval=self.polling_interval_seconds,
885+
),
886+
method_name="execute_complete",
887+
)
888+
889+
def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> Any:
890+
"""
891+
Callback for when the trigger fires - returns immediately.
892+
Relies on trigger to throw an exception, otherwise it assumes execution was
893+
successful.
894+
"""
895+
if event and event["status"] == "error":
896+
raise AirflowException(event["message"])
897+
elif event is None:
898+
raise AirflowException("No event received in trigger callback")
899+
self.log.info("Cluster deleted.")
900+
901+
def _delete_cluster(self, hook: DataprocHook):
858902
self.log.info("Deleting cluster: %s", self.cluster_name)
859-
operation = hook.delete_cluster(
903+
return hook.delete_cluster(
860904
project_id=self.project_id,
861905
region=self.region,
862906
cluster_name=self.cluster_name,
@@ -866,8 +910,6 @@ def execute(self, context: Context) -> None:
866910
timeout=self.timeout,
867911
metadata=self.metadata,
868912
)
869-
operation.result()
870-
self.log.info("Cluster deleted.")
871913

872914

873915
class DataprocJobBaseOperator(BaseOperator):

airflow/providers/google/cloud/triggers/dataproc.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,11 @@
1919
from __future__ import annotations
2020

2121
import asyncio
22+
import time
2223
import warnings
2324
from typing import Any, AsyncIterator, Sequence
2425

26+
from google.api_core.exceptions import NotFound
2527
from google.cloud.dataproc_v1 import Batch, ClusterStatus, JobStatus
2628

2729
from airflow import AirflowException
@@ -219,3 +221,92 @@ async def run(self):
219221
self.log.info("Sleeping for %s seconds.", self.polling_interval_seconds)
220222
await asyncio.sleep(self.polling_interval_seconds)
221223
yield TriggerEvent({"batch_id": self.batch_id, "batch_state": state})
224+
225+
226+
class DataprocDeleteClusterTrigger(BaseTrigger):
227+
"""
228+
Asynchronously checks the status of a cluster.
229+
230+
:param cluster_name: The name of the cluster
231+
:param end_time: Time in seconds left to check the cluster status
232+
:param project_id: The ID of the Google Cloud project the cluster belongs to
233+
:param region: The Cloud Dataproc region in which to handle the request
234+
:param metadata: Additional metadata that is provided to the method
235+
:param gcp_conn_id: The connection ID to use when fetching connection info.
236+
:param impersonation_chain: Optional service account to impersonate using short-term
237+
credentials, or chained list of accounts required to get the access_token
238+
of the last account in the list, which will be impersonated in the request.
239+
If set as a string, the account must grant the originating account
240+
the Service Account Token Creator IAM role.
241+
If set as a sequence, the identities from the list must grant
242+
Service Account Token Creator IAM role to the directly preceding identity, with first
243+
account from the list granting this role to the originating account.
244+
:param polling_interval: Time in seconds to sleep between checks of cluster status
245+
"""
246+
247+
def __init__(
248+
self,
249+
cluster_name: str,
250+
end_time: float,
251+
project_id: str | None = None,
252+
region: str | None = None,
253+
metadata: Sequence[tuple[str, str]] = (),
254+
gcp_conn_id: str = "google_cloud_default",
255+
impersonation_chain: str | Sequence[str] | None = None,
256+
polling_interval: float = 5.0,
257+
**kwargs: Any,
258+
):
259+
super().__init__(**kwargs)
260+
self.cluster_name = cluster_name
261+
self.end_time = end_time
262+
self.project_id = project_id
263+
self.region = region
264+
self.metadata = metadata
265+
self.gcp_conn_id = gcp_conn_id
266+
self.impersonation_chain = impersonation_chain
267+
self.polling_interval = polling_interval
268+
269+
def serialize(self) -> tuple[str, dict[str, Any]]:
270+
"""Serializes DataprocDeleteClusterTrigger arguments and classpath."""
271+
return (
272+
"airflow.providers.google.cloud.triggers.dataproc.DataprocDeleteClusterTrigger",
273+
{
274+
"cluster_name": self.cluster_name,
275+
"end_time": self.end_time,
276+
"project_id": self.project_id,
277+
"region": self.region,
278+
"metadata": self.metadata,
279+
"gcp_conn_id": self.gcp_conn_id,
280+
"impersonation_chain": self.impersonation_chain,
281+
"polling_interval": self.polling_interval,
282+
},
283+
)
284+
285+
async def run(self) -> AsyncIterator["TriggerEvent"]:
286+
"""Wait until cluster is deleted completely"""
287+
hook = self._get_hook()
288+
while self.end_time > time.time():
289+
try:
290+
cluster = await hook.get_cluster(
291+
region=self.region, # type: ignore[arg-type]
292+
cluster_name=self.cluster_name,
293+
project_id=self.project_id, # type: ignore[arg-type]
294+
metadata=self.metadata,
295+
)
296+
self.log.info(
297+
"Cluster status is %s. Sleeping for %s seconds.",
298+
cluster.status.state,
299+
self.polling_interval,
300+
)
301+
await asyncio.sleep(self.polling_interval)
302+
except NotFound:
303+
yield TriggerEvent({"status": "success", "message": ""})
304+
except Exception as e:
305+
yield TriggerEvent({"status": "error", "message": str(e)})
306+
yield TriggerEvent({"status": "error", "message": "Timeout"})
307+
308+
def _get_hook(self) -> DataprocAsyncHook:
309+
return DataprocAsyncHook(
310+
gcp_conn_id=self.gcp_conn_id,
311+
impersonation_chain=self.impersonation_chain,
312+
)

docs/apache-airflow-providers-google/operators/cloud/dataproc.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,14 @@ To delete a cluster you can use:
140140
:start-after: [START how_to_cloud_dataproc_delete_cluster_operator]
141141
:end-before: [END how_to_cloud_dataproc_delete_cluster_operator]
142142

143+
You can use deferrable mode for this action in order to run the operator asynchronously:
144+
145+
.. exampleinclude:: /../../tests/system/providers/google/cloud/dataproc/example_dataproc_cluster_deferrable.py
146+
:language: python
147+
:dedent: 4
148+
:start-after: [START how_to_cloud_dataproc_delete_cluster_operator_async]
149+
:end-before: [END how_to_cloud_dataproc_delete_cluster_operator_async]
150+
143151
Submit a job to a cluster
144152
-------------------------
145153

tests/providers/google/cloud/operators/test_dataproc.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
from airflow.providers.google.cloud.triggers.dataproc import (
5858
DataprocBatchTrigger,
5959
DataprocClusterTrigger,
60+
DataprocDeleteClusterTrigger,
6061
DataprocSubmitTrigger,
6162
)
6263
from airflow.providers.google.common.consts import GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME
@@ -875,6 +876,47 @@ def test_execute(self, mock_hook):
875876
metadata=METADATA,
876877
)
877878

879+
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
880+
@mock.patch(DATAPROC_TRIGGERS_PATH.format("DataprocAsyncHook"))
881+
def test_create_execute_call_defer_method(self, mock_trigger_hook, mock_hook):
882+
mock_hook.return_value.create_cluster.return_value = None
883+
operator = DataprocDeleteClusterOperator(
884+
task_id=TASK_ID,
885+
region=GCP_REGION,
886+
project_id=GCP_PROJECT,
887+
cluster_name=CLUSTER_NAME,
888+
request_id=REQUEST_ID,
889+
gcp_conn_id=GCP_CONN_ID,
890+
retry=RETRY,
891+
timeout=TIMEOUT,
892+
metadata=METADATA,
893+
impersonation_chain=IMPERSONATION_CHAIN,
894+
deferrable=True,
895+
)
896+
897+
with pytest.raises(TaskDeferred) as exc:
898+
operator.execute(mock.MagicMock())
899+
900+
mock_hook.assert_called_once_with(
901+
gcp_conn_id=GCP_CONN_ID,
902+
impersonation_chain=IMPERSONATION_CHAIN,
903+
)
904+
905+
mock_hook.return_value.delete_cluster.assert_called_once_with(
906+
project_id=GCP_PROJECT,
907+
region=GCP_REGION,
908+
cluster_name=CLUSTER_NAME,
909+
cluster_uuid=None,
910+
request_id=REQUEST_ID,
911+
retry=RETRY,
912+
timeout=TIMEOUT,
913+
metadata=METADATA,
914+
)
915+
916+
mock_hook.return_value.wait_for_operation.assert_not_called()
917+
assert isinstance(exc.value.trigger, DataprocDeleteClusterTrigger)
918+
assert exc.value.method_name == GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME
919+
878920

879921
class TestDataprocSubmitJobOperator(DataprocJobTestBase):
880922
@mock.patch(DATAPROC_PATH.format("DataprocHook"))

tests/system/providers/google/cloud/dataproc/example_dataproc_cluster_deferrable.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,13 +98,16 @@
9898
)
9999
# [END how_to_cloud_dataproc_update_cluster_operator_async]
100100

101+
# [START how_to_cloud_dataproc_delete_cluster_operator_async]
101102
delete_cluster = DataprocDeleteClusterOperator(
102103
task_id="delete_cluster",
103104
project_id=PROJECT_ID,
104105
cluster_name=CLUSTER_NAME,
105106
region=REGION,
106107
trigger_rule=TriggerRule.ALL_DONE,
108+
deferrable=True,
107109
)
110+
# [END how_to_cloud_dataproc_delete_cluster_operator_async]
108111

109112
create_cluster >> update_cluster >> delete_cluster
110113

0 commit comments

Comments
 (0)