Commit 9fd8013

bkossakowska (Beata Kossakowska) authored
Add deferrable mode to DataprocCreateClusterOperator and DataprocUpdateClusterOperator (#28529)
Co-authored-by: Beata Kossakowska <bkossakowska@google.com>
1 parent b8f15a9 commit 9fd8013
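
In deferrable mode the operator no longer blocks a worker slot while the cluster operation runs: it submits the request, defers, and hands polling off to the triggerer. A minimal usage sketch of the new flag follows; the project, region, cluster name, and config values are placeholders for illustration, not taken from this commit.

    from airflow.providers.google.cloud.operators.dataproc import DataprocCreateClusterOperator

    # Illustrative cluster config only; a real DAG supplies its own.
    CLUSTER_CONFIG = {
        "master_config": {"num_instances": 1, "machine_type_uri": "n1-standard-4"},
        "worker_config": {"num_instances": 2, "machine_type_uri": "n1-standard-4"},
    }

    create_cluster = DataprocCreateClusterOperator(
        task_id="create_cluster",
        project_id="my-project",      # placeholder
        region="us-central1",         # placeholder
        cluster_name="my-cluster",    # placeholder
        cluster_config=CLUSTER_CONFIG,
        deferrable=True,              # new in this commit
        polling_interval_seconds=10,  # must be > 0 when deferrable=True
    )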

File tree: 6 files changed, +515 −22 lines changed

airflow/providers/google/cloud/operators/dataproc.py

Lines changed: 84 additions & 12 deletions
@@ -26,13 +26,13 @@
 import uuid
 import warnings
 from datetime import datetime, timedelta
-from typing import TYPE_CHECKING, Sequence
+from typing import TYPE_CHECKING, Any, Sequence
 
 from google.api_core import operation  # type: ignore
 from google.api_core.exceptions import AlreadyExists, NotFound
 from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
 from google.api_core.retry import Retry, exponential_sleep_generator
-from google.cloud.dataproc_v1 import Batch, Cluster, JobStatus
+from google.cloud.dataproc_v1 import Batch, Cluster, ClusterStatus, JobStatus
 from google.protobuf.duration_pb2 import Duration
 from google.protobuf.field_mask_pb2 import FieldMask
 
@@ -50,7 +50,7 @@
     DataprocLink,
     DataprocListLink,
 )
-from airflow.providers.google.cloud.triggers.dataproc import DataprocBaseTrigger
+from airflow.providers.google.cloud.triggers.dataproc import DataprocClusterTrigger, DataprocSubmitTrigger
 from airflow.utils import timezone
 
 if TYPE_CHECKING:
@@ -438,6 +438,8 @@ class DataprocCreateClusterOperator(BaseOperator):
         If set as a sequence, the identities from the list must grant
         Service Account Token Creator IAM role to the directly preceding identity, with first
         account from the list granting this role to the originating account (templated).
+    :param deferrable: Run operator in the deferrable mode.
+    :param polling_interval_seconds: Time (seconds) to wait between calls to check the run status.
     """
 
     template_fields: Sequence[str] = (
@@ -470,6 +472,8 @@ def __init__(
         metadata: Sequence[tuple[str, str]] = (),
         gcp_conn_id: str = "google_cloud_default",
         impersonation_chain: str | Sequence[str] | None = None,
+        deferrable: bool = False,
+        polling_interval_seconds: int = 10,
         **kwargs,
     ) -> None:
 
@@ -502,7 +506,8 @@ def __init__(
             del kwargs[arg]
 
         super().__init__(**kwargs)
-
+        if deferrable and polling_interval_seconds <= 0:
+            raise ValueError("Invalid value for polling_interval_seconds. Expected value greater than 0")
         self.cluster_config = cluster_config
         self.cluster_name = cluster_name
         self.labels = labels
@@ -517,9 +522,11 @@ def __init__(
         self.use_if_exists = use_if_exists
         self.impersonation_chain = impersonation_chain
        self.virtual_cluster_config = virtual_cluster_config
+        self.deferrable = deferrable
+        self.polling_interval_seconds = polling_interval_seconds
 
     def _create_cluster(self, hook: DataprocHook):
-        operation = hook.create_cluster(
+        return hook.create_cluster(
             project_id=self.project_id,
             region=self.region,
             cluster_name=self.cluster_name,
@@ -531,9 +538,6 @@ def _create_cluster(self, hook: DataprocHook):
             timeout=self.timeout,
             metadata=self.metadata,
         )
-        cluster = operation.result()
-        self.log.info("Cluster created.")
-        return cluster
 
     def _delete_cluster(self, hook):
         self.log.info("Deleting the cluster")
@@ -596,7 +600,25 @@ def execute(self, context: Context) -> dict:
         )
         try:
             # First try to create a new cluster
-            cluster = self._create_cluster(hook)
+            operation = self._create_cluster(hook)
+            if not self.deferrable:
+                cluster = hook.wait_for_operation(
+                    timeout=self.timeout, result_retry=self.retry, operation=operation
+                )
+                self.log.info("Cluster created.")
+                return Cluster.to_dict(cluster)
+            else:
+                self.defer(
+                    trigger=DataprocClusterTrigger(
+                        cluster_name=self.cluster_name,
+                        project_id=self.project_id,
+                        region=self.region,
+                        gcp_conn_id=self.gcp_conn_id,
+                        impersonation_chain=self.impersonation_chain,
+                        polling_interval_seconds=self.polling_interval_seconds,
+                    ),
+                    method_name="execute_complete",
+                )
         except AlreadyExists:
             if not self.use_if_exists:
                 raise
@@ -618,6 +640,21 @@ def execute(self, context: Context) -> dict:
 
         return Cluster.to_dict(cluster)
 
+    def execute_complete(self, context: Context, event: dict[str, Any]) -> Any:
+        """
+        Callback for when the trigger fires - returns immediately.
+        Relies on trigger to throw an exception, otherwise it assumes execution was
+        successful.
+        """
+        cluster_state = event["cluster_state"]
+        cluster_name = event["cluster_name"]
+
+        if cluster_state == ClusterStatus.State.ERROR:
+            raise AirflowException(f"Cluster is in ERROR state:\n{cluster_name}")
+
+        self.log.info("%s completed successfully.", self.task_id)
+        return event["cluster"]
 
 
 class DataprocScaleClusterOperator(BaseOperator):
     """
@@ -974,7 +1011,7 @@ def execute(self, context: Context):
 
         if self.deferrable:
             self.defer(
-                trigger=DataprocBaseTrigger(
+                trigger=DataprocSubmitTrigger(
                     job_id=job_id,
                     project_id=self.project_id,
                     region=self.region,
@@ -1888,7 +1925,7 @@ def execute(self, context: Context):
         self.job_id = new_job_id
         if self.deferrable:
             self.defer(
-                trigger=DataprocBaseTrigger(
+                trigger=DataprocSubmitTrigger(
                     job_id=self.job_id,
                     project_id=self.project_id,
                     region=self.region,
@@ -1964,6 +2001,8 @@ class DataprocUpdateClusterOperator(BaseOperator):
         If set as a sequence, the identities from the list must grant
         Service Account Token Creator IAM role to the directly preceding identity, with first
         account from the list granting this role to the originating account (templated).
+    :param deferrable: Run operator in the deferrable mode.
+    :param polling_interval_seconds: Time (seconds) to wait between calls to check the run status.
     """
 
     template_fields: Sequence[str] = (
@@ -1991,9 +2030,13 @@ def __init__(
         metadata: Sequence[tuple[str, str]] = (),
         gcp_conn_id: str = "google_cloud_default",
         impersonation_chain: str | Sequence[str] | None = None,
+        deferrable: bool = False,
+        polling_interval_seconds: int = 10,
         **kwargs,
     ):
         super().__init__(**kwargs)
+        if deferrable and polling_interval_seconds <= 0:
+            raise ValueError("Invalid value for polling_interval_seconds. Expected value greater than 0")
         self.project_id = project_id
         self.region = region
         self.cluster_name = cluster_name
@@ -2006,6 +2049,8 @@ def __init__(
         self.metadata = metadata
         self.gcp_conn_id = gcp_conn_id
         self.impersonation_chain = impersonation_chain
+        self.deferrable = deferrable
+        self.polling_interval_seconds = polling_interval_seconds
 
     def execute(self, context: Context):
         hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
@@ -2026,9 +2071,36 @@ def execute(self, context: Context):
             timeout=self.timeout,
             metadata=self.metadata,
         )
-        operation.result()
+
+        if not self.deferrable:
+            hook.wait_for_operation(timeout=self.timeout, result_retry=self.retry, operation=operation)
+        else:
+            self.defer(
+                trigger=DataprocClusterTrigger(
+                    cluster_name=self.cluster_name,
+                    project_id=self.project_id,
+                    region=self.region,
+                    gcp_conn_id=self.gcp_conn_id,
+                    impersonation_chain=self.impersonation_chain,
+                    polling_interval_seconds=self.polling_interval_seconds,
+                ),
+                method_name="execute_complete",
+            )
         self.log.info("Updated %s cluster.", self.cluster_name)
 
+    def execute_complete(self, context: Context, event: dict[str, Any]) -> Any:
+        """
+        Callback for when the trigger fires - returns immediately.
+        Relies on trigger to throw an exception, otherwise it assumes execution was
+        successful.
+        """
+        cluster_state = event["cluster_state"]
+        cluster_name = event["cluster_name"]
+
+        if cluster_state == ClusterStatus.State.ERROR:
+            raise AirflowException(f"Cluster is in ERROR state:\n{cluster_name}")
+        self.log.info("%s completed successfully.", self.task_id)
 
 
 class DataprocCreateBatchOperator(BaseOperator):
     """

airflow/providers/google/cloud/triggers/dataproc.py

Lines changed: 64 additions & 4 deletions
@@ -20,16 +20,16 @@
 
 import asyncio
 import warnings
-from typing import Sequence
+from typing import Any, AsyncIterator, Sequence
 
-from google.cloud.dataproc_v1 import JobStatus
+from google.cloud.dataproc_v1 import ClusterStatus, JobStatus
 
 from airflow import AirflowException
 from airflow.providers.google.cloud.hooks.dataproc import DataprocAsyncHook
 from airflow.triggers.base import BaseTrigger, TriggerEvent
 
 
-class DataprocBaseTrigger(BaseTrigger):
+class DataprocSubmitTrigger(BaseTrigger):
     """
     Trigger that periodically polls information from Dataproc API to verify job status.
     Implementation leverages asynchronous transport.
@@ -65,7 +65,7 @@ def __init__(
 
     def serialize(self):
         return (
-            "airflow.providers.google.cloud.triggers.dataproc.DataprocBaseTrigger",
+            "airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitTrigger",
             {
                 "job_id": self.job_id,
                 "project_id": self.project_id,
@@ -89,3 +89,63 @@ async def run(self):
                 raise AirflowException(f"Dataproc job execution failed {self.job_id}")
             await asyncio.sleep(self.polling_interval_seconds)
         yield TriggerEvent({"job_id": self.job_id, "job_state": state})
+
+
+class DataprocClusterTrigger(BaseTrigger):
+    """
+    Trigger that periodically polls information from Dataproc API to verify status.
+    Implementation leverages asynchronous transport.
+    """
+
+    def __init__(
+        self,
+        cluster_name: str,
+        region: str,
+        project_id: str | None = None,
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: str | Sequence[str] | None = None,
+        polling_interval_seconds: int = 10,
+    ):
+        super().__init__()
+        self.gcp_conn_id = gcp_conn_id
+        self.impersonation_chain = impersonation_chain
+        self.cluster_name = cluster_name
+        self.project_id = project_id
+        self.region = region
+        self.polling_interval_seconds = polling_interval_seconds
+
+    def serialize(self) -> tuple[str, dict[str, Any]]:
+        return (
+            "airflow.providers.google.cloud.triggers.dataproc.DataprocClusterTrigger",
+            {
+                "cluster_name": self.cluster_name,
+                "project_id": self.project_id,
+                "region": self.region,
+                "gcp_conn_id": self.gcp_conn_id,
+                "impersonation_chain": self.impersonation_chain,
+                "polling_interval_seconds": self.polling_interval_seconds,
+            },
+        )
+
+    async def run(self) -> AsyncIterator["TriggerEvent"]:
+        hook = self._get_hook()
+        while True:
+            cluster = await hook.get_cluster(
+                project_id=self.project_id, region=self.region, cluster_name=self.cluster_name
+            )
+            state = cluster.status.state
+            self.log.info("Dataproc cluster: %s is in state: %s", self.cluster_name, state)
+            if state in (
+                ClusterStatus.State.ERROR,
+                ClusterStatus.State.RUNNING,
+            ):
+                break
+            self.log.info("Sleeping for %s seconds.", self.polling_interval_seconds)
+            await asyncio.sleep(self.polling_interval_seconds)
+        yield TriggerEvent({"cluster_name": self.cluster_name, "cluster_state": state, "cluster": cluster})
+
+    def _get_hook(self) -> DataprocAsyncHook:
+        return DataprocAsyncHook(
+            gcp_conn_id=self.gcp_conn_id,
+            impersonation_chain=self.impersonation_chain,
+        )
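
serialize() matters here because the trigger instance itself never travels between processes: the worker persists the dotted classpath plus kwargs, and the triggerer rebuilds the object from them before calling run(). A minimal round-trip sketch, assuming placeholder cluster details:

    from importlib import import_module

    from airflow.providers.google.cloud.triggers.dataproc import DataprocClusterTrigger

    trigger = DataprocClusterTrigger(
        cluster_name="my-cluster",  # placeholder
        region="us-central1",       # placeholder
        project_id="my-project",    # placeholder
    )

    # The triggerer re-imports the class from the dotted path and re-applies the kwargs.
    classpath, kwargs = trigger.serialize()
    module_path, class_name = classpath.rsplit(".", 1)
    restored = getattr(import_module(module_path), class_name)(**kwargs)
    assert restored.cluster_name == trigger.cluster_name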

docs/apache-airflow-providers-google/operators/cloud/dataproc.rst

Lines changed: 16 additions & 0 deletions
@@ -75,6 +75,14 @@ With this configuration we can create the cluster:
     :start-after: [START how_to_cloud_dataproc_create_cluster_operator_in_gke]
     :end-before: [END how_to_cloud_dataproc_create_cluster_operator_in_gke]
 
+You can use deferrable mode for this action in order to run the operator asynchronously:
+
+.. exampleinclude:: /../../tests/system/providers/google/cloud/dataproc/example_dataproc_cluster_deferrable.py
+    :language: python
+    :dedent: 4
+    :start-after: [START how_to_cloud_dataproc_create_cluster_operator_async]
+    :end-before: [END how_to_cloud_dataproc_create_cluster_operator_async]
+
 Generating Cluster Config
 ^^^^^^^^^^^^^^^^^^^^^^^^^
 You can also generate **CLUSTER_CONFIG** using functional API,
@@ -111,6 +119,14 @@ To update a cluster you can use:
     :start-after: [START how_to_cloud_dataproc_update_cluster_operator]
     :end-before: [END how_to_cloud_dataproc_update_cluster_operator]
 
+You can use deferrable mode for this action in order to run the operator asynchronously:
+
+.. exampleinclude:: /../../tests/system/providers/google/cloud/dataproc/example_dataproc_cluster_deferrable.py
+    :language: python
+    :dedent: 4
+    :start-after: [START how_to_cloud_dataproc_update_cluster_operator_async]
+    :end-before: [END how_to_cloud_dataproc_update_cluster_operator_async]
+
 Deleting a cluster
 ------------------
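
Both exampleinclude directives pull the snippet between their START/END markers out of the system-test DAG. That file is not part of this excerpt, so the following is only a plausible sketch of what the marked update-cluster region could look like; all uppercase names are placeholders assumed to be defined elsewhere in the DAG.

    # [START how_to_cloud_dataproc_update_cluster_operator_async]
    update_cluster = DataprocUpdateClusterOperator(
        task_id="update_cluster",
        cluster_name=CLUSTER_NAME,              # placeholder
        cluster=CLUSTER_UPDATE,                 # e.g. new worker counts (placeholder)
        update_mask=UPDATE_MASK,                # FieldMask naming the changed paths (placeholder)
        graceful_decommission_timeout=TIMEOUT,  # placeholder
        project_id=PROJECT_ID,
        region=REGION,
        deferrable=True,
    )
    # [END how_to_cloud_dataproc_update_cluster_operator_async]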
