Skip to content

Commit 9d93517

Browse files
authored
Add deferrable mode to DataprocCreateBatchOperator (#28457)
1 parent 5503587 commit 9d93517

File tree

6 files changed

+385
-12
lines changed

6 files changed

+385
-12
lines changed

β€Žairflow/providers/google/cloud/operators/dataproc.py

Lines changed: 54 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,11 @@
5050
DataprocLink,
5151
DataprocListLink,
5252
)
53-
from airflow.providers.google.cloud.triggers.dataproc import DataprocClusterTrigger, DataprocSubmitTrigger
53+
from airflow.providers.google.cloud.triggers.dataproc import (
54+
DataprocBatchTrigger,
55+
DataprocClusterTrigger,
56+
DataprocSubmitTrigger,
57+
)
5458
from airflow.utils import timezone
5559

5660
if TYPE_CHECKING:
@@ -2134,6 +2138,8 @@ class DataprocCreateBatchOperator(BaseOperator):
21342138
:param asynchronous: Flag to return after creating batch to the Dataproc API.
21352139
This is useful for creating long-running batch and
21362140
waiting on them asynchronously using the DataprocBatchSensor
2141+
:param deferrable: Run operator in the deferrable mode.
2142+
:param polling_interval_seconds: Time (seconds) to wait between calls to check the run status.
21372143
"""
21382144

21392145
template_fields: Sequence[str] = (
@@ -2151,7 +2157,7 @@ def __init__(
21512157
region: str | None = None,
21522158
project_id: str | None = None,
21532159
batch: dict | Batch,
2154-
batch_id: str | None = None,
2160+
batch_id: str,
21552161
request_id: str | None = None,
21562162
retry: Retry | _MethodDefault = DEFAULT,
21572163
timeout: float | None = None,
@@ -2160,9 +2166,13 @@ def __init__(
21602166
impersonation_chain: str | Sequence[str] | None = None,
21612167
result_retry: Retry | _MethodDefault = DEFAULT,
21622168
asynchronous: bool = False,
2169+
deferrable: bool = False,
2170+
polling_interval_seconds: int = 5,
21632171
**kwargs,
21642172
):
21652173
super().__init__(**kwargs)
2174+
if deferrable and polling_interval_seconds <= 0:
2175+
raise ValueError("Invalid value for polling_interval_seconds. Expected value greater than 0")
21662176
self.region = region
21672177
self.project_id = project_id
21682178
self.batch = batch
@@ -2176,6 +2186,8 @@ def __init__(
21762186
self.impersonation_chain = impersonation_chain
21772187
self.operation: operation.Operation | None = None
21782188
self.asynchronous = asynchronous
2189+
self.deferrable = deferrable
2190+
self.polling_interval_seconds = polling_interval_seconds
21792191

21802192
def execute(self, context: Context):
21812193
hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
@@ -2195,13 +2207,30 @@ def execute(self, context: Context):
21952207
)
21962208
if self.operation is None:
21972209
raise RuntimeError("The operation should be set here!")
2198-
if not self.asynchronous:
2199-
result = hook.wait_for_operation(
2200-
timeout=self.timeout, result_retry=self.result_retry, operation=self.operation
2201-
)
2202-
self.log.info("Batch %s created", self.batch_id)
2210+
2211+
if not self.deferrable:
2212+
if not self.asynchronous:
2213+
result = hook.wait_for_operation(
2214+
timeout=self.timeout, result_retry=self.result_retry, operation=self.operation
2215+
)
2216+
self.log.info("Batch %s created", self.batch_id)
2217+
2218+
else:
2219+
return self.operation.operation.name
2220+
22032221
else:
2204-
return self.operation.operation.name
2222+
self.defer(
2223+
trigger=DataprocBatchTrigger(
2224+
batch_id=self.batch_id,
2225+
project_id=self.project_id,
2226+
region=self.region,
2227+
gcp_conn_id=self.gcp_conn_id,
2228+
impersonation_chain=self.impersonation_chain,
2229+
polling_interval_seconds=self.polling_interval_seconds,
2230+
),
2231+
method_name="execute_complete",
2232+
)
2233+
22052234
except AlreadyExists:
22062235
self.log.info("Batch with given id already exists")
22072236
if self.batch_id is None:
@@ -2233,6 +2262,23 @@ def execute(self, context: Context):
22332262
DataprocLink.persist(context=context, task_instance=self, url=DATAPROC_BATCH_LINK, resource=batch_id)
22342263
return Batch.to_dict(result)
22352264

2265+
def execute_complete(self, context, event=None) -> None:
2266+
"""
2267+
Callback for when the trigger fires - returns immediately.
2268+
Relies on trigger to throw an exception, otherwise it assumes execution was
2269+
successful.
2270+
"""
2271+
if event is None:
2272+
raise AirflowException("Batch failed.")
2273+
batch_state = event["batch_state"]
2274+
batch_id = event["batch_id"]
2275+
2276+
if batch_state == Batch.State.FAILED:
2277+
raise AirflowException(f"Batch failed:\n{batch_id}")
2278+
if batch_state == Batch.State.CANCELLED:
2279+
raise AirflowException(f"Batch was cancelled:\n{batch_id}")
2280+
self.log.info("%s completed successfully.", self.task_id)
2281+
22362282
def on_kill(self):
22372283
if self.operation:
22382284
self.operation.cancel()

β€Žairflow/providers/google/cloud/triggers/dataproc.py

Lines changed: 71 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
import warnings
2323
from typing import Any, AsyncIterator, Sequence
2424

25-
from google.cloud.dataproc_v1 import ClusterStatus, JobStatus
25+
from google.cloud.dataproc_v1 import Batch, ClusterStatus, JobStatus
2626

2727
from airflow import AirflowException
2828
from airflow.providers.google.cloud.hooks.dataproc import DataprocAsyncHook
@@ -149,3 +149,73 @@ def _get_hook(self) -> DataprocAsyncHook:
149149
gcp_conn_id=self.gcp_conn_id,
150150
impersonation_chain=self.impersonation_chain,
151151
)
152+
153+
154+
class DataprocBatchTrigger(BaseTrigger):
155+
"""
156+
DataprocCreateBatchTrigger run on the trigger worker to perform create Build operation
157+
158+
:param batch_id: The ID of the build.
159+
:param project_id: Google Cloud Project where the job is running
160+
:param region: The Cloud Dataproc region in which to handle the request.
161+
:param gcp_conn_id: Optional, the connection ID used to connect to Google Cloud Platform.
162+
:param impersonation_chain: Optional service account to impersonate using short-term
163+
credentials, or chained list of accounts required to get the access_token
164+
of the last account in the list, which will be impersonated in the request.
165+
If set as a string, the account must grant the originating account
166+
the Service Account Token Creator IAM role.
167+
If set as a sequence, the identities from the list must grant
168+
Service Account Token Creator IAM role to the directly preceding identity, with first
169+
account from the list granting this role to the originating account (templated).
170+
:param polling_interval_seconds: polling period in seconds to check for the status
171+
"""
172+
173+
def __init__(
174+
self,
175+
batch_id: str,
176+
region: str,
177+
project_id: str | None,
178+
gcp_conn_id: str = "google_cloud_default",
179+
impersonation_chain: str | Sequence[str] | None = None,
180+
polling_interval_seconds: float = 5.0,
181+
):
182+
super().__init__()
183+
self.batch_id = batch_id
184+
self.project_id = project_id
185+
self.region = region
186+
self.gcp_conn_id = gcp_conn_id
187+
self.impersonation_chain = impersonation_chain
188+
self.polling_interval_seconds = polling_interval_seconds
189+
190+
def serialize(self) -> tuple[str, dict[str, Any]]:
191+
"""Serializes DataprocBatchTrigger arguments and classpath."""
192+
return (
193+
"airflow.providers.google.cloud.triggers.dataproc.DataprocBatchTrigger",
194+
{
195+
"batch_id": self.batch_id,
196+
"project_id": self.project_id,
197+
"region": self.region,
198+
"gcp_conn_id": self.gcp_conn_id,
199+
"impersonation_chain": self.impersonation_chain,
200+
"polling_interval_seconds": self.polling_interval_seconds,
201+
},
202+
)
203+
204+
async def run(self):
205+
hook = DataprocAsyncHook(
206+
gcp_conn_id=self.gcp_conn_id,
207+
impersonation_chain=self.impersonation_chain,
208+
)
209+
210+
while True:
211+
batch = await hook.get_batch(
212+
project_id=self.project_id, region=self.region, batch_id=self.batch_id
213+
)
214+
state = batch.state
215+
216+
if state in (Batch.State.FAILED, Batch.State.SUCCEEDED, Batch.State.CANCELLED):
217+
break
218+
self.log.info("Current state is %s", state)
219+
self.log.info("Sleeping for %s seconds.", self.polling_interval_seconds)
220+
await asyncio.sleep(self.polling_interval_seconds)
221+
yield TriggerEvent({"batch_id": self.batch_id, "batch_state": state})

β€Ždocs/apache-airflow-providers-google/operators/cloud/dataproc.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,14 @@ To check if operation succeeded you can use
305305
:start-after: [START how_to_cloud_dataproc_batch_async_sensor]
306306
:end-before: [END how_to_cloud_dataproc_batch_async_sensor]
307307

308+
Also for all this action you can use operator in the deferrable mode:
309+
310+
.. exampleinclude:: /../../tests/system/providers/google/cloud/dataproc/example_dataproc_batch_deferrable.py
311+
:language: python
312+
:dedent: 4
313+
:start-after: [START how_to_cloud_dataproc_create_batch_operator_async]
314+
:end-before: [END how_to_cloud_dataproc_create_batch_operator_async]
315+
308316
Get a Batch
309317
-----------
310318

β€Žtests/providers/google/cloud/operators/test_dataproc.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,11 @@
5454
DataprocSubmitSparkSqlJobOperator,
5555
DataprocUpdateClusterOperator,
5656
)
57-
from airflow.providers.google.cloud.triggers.dataproc import DataprocClusterTrigger, DataprocSubmitTrigger
57+
from airflow.providers.google.cloud.triggers.dataproc import (
58+
DataprocBatchTrigger,
59+
DataprocClusterTrigger,
60+
DataprocSubmitTrigger,
61+
)
5862
from airflow.providers.google.common.consts import GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME
5963
from airflow.serialization.serialized_objects import SerializedDAG
6064
from airflow.utils.timezone import datetime
@@ -2032,3 +2036,44 @@ def test_execute(self, mock_hook):
20322036
timeout=TIMEOUT,
20332037
metadata=METADATA,
20342038
)
2039+
2040+
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
2041+
@mock.patch(DATAPROC_TRIGGERS_PATH.format("DataprocAsyncHook"))
2042+
def test_execute_deferrable(self, mock_trigger_hook, mock_hook):
2043+
mock_hook.return_value.submit_job.return_value.reference.job_id = TEST_JOB_ID
2044+
2045+
op = DataprocCreateBatchOperator(
2046+
task_id=TASK_ID,
2047+
region=GCP_REGION,
2048+
project_id=GCP_PROJECT,
2049+
batch=BATCH,
2050+
batch_id="batch_id",
2051+
gcp_conn_id=GCP_CONN_ID,
2052+
retry=RETRY,
2053+
timeout=TIMEOUT,
2054+
metadata=METADATA,
2055+
request_id=REQUEST_ID,
2056+
impersonation_chain=IMPERSONATION_CHAIN,
2057+
deferrable=True,
2058+
)
2059+
with pytest.raises(TaskDeferred) as exc:
2060+
op.execute(mock.MagicMock())
2061+
2062+
mock_hook.assert_called_once_with(
2063+
gcp_conn_id=GCP_CONN_ID,
2064+
impersonation_chain=IMPERSONATION_CHAIN,
2065+
)
2066+
mock_hook.return_value.create_batch.assert_called_once_with(
2067+
region=GCP_REGION,
2068+
project_id=GCP_PROJECT,
2069+
batch_id="batch_id",
2070+
batch=BATCH,
2071+
request_id=REQUEST_ID,
2072+
retry=RETRY,
2073+
timeout=TIMEOUT,
2074+
metadata=METADATA,
2075+
)
2076+
mock_hook.return_value.wait_for_job.assert_not_called()
2077+
2078+
assert isinstance(exc.value.trigger, DataprocBatchTrigger)
2079+
assert exc.value.method_name == GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME

0 commit comments

Comments
 (0)