
Commit dc3a3c7

Add DataprocCancelOperationOperator (#28456)

Authored by Beata Kossakowska (bkossakowska)
Co-authored-by: Beata Kossakowska <bkossakowska@google.com>

1 parent d24527b commit dc3a3c7

6 files changed: +298 -10 lines changed

airflow/providers/google/cloud/hooks/dataproc.py
Lines changed: 4 additions & 0 deletions

@@ -252,6 +252,10 @@ def get_batch_client(self, region: str | None = None) -> BatchControllerClient:
             credentials=self.get_credentials(), client_info=CLIENT_INFO, client_options=client_options
         )
 
+    def get_operations_client(self, region):
+        """Returns OperationsClient"""
+        return self.get_batch_client(region=region).transport.operations_client
+
     def wait_for_operation(
         self,
         operation: Operation,
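
For orientation, here is a minimal sketch of how this new hook method could be used directly, mirroring what the operator below does; the connection ID, region, and operation name are placeholder values:

    from airflow.providers.google.cloud.hooks.dataproc import DataprocHook

    # Fetch the region-scoped OperationsClient and cancel a long-running
    # operation by its fully qualified resource name.
    hook = DataprocHook(gcp_conn_id="google_cloud_default")
    operations_client = hook.get_operations_client(region="europe-west1")
    operations_client.cancel_operation(
        name="projects/my-project/regions/europe-west1/operations/some-operation-id"
    )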

airflow/providers/google/cloud/operators/dataproc.py
Lines changed: 68 additions & 5 deletions

@@ -2055,6 +2055,9 @@ class DataprocCreateBatchOperator(BaseOperator):
         If set as a sequence, the identities from the list must grant
         Service Account Token Creator IAM role to the directly preceding identity, with first
         account from the list granting this role to the originating account (templated).
+    :param asynchronous: Flag to return immediately after submitting the batch to the Dataproc API.
+        This is useful for creating long-running batches and waiting on them
+        asynchronously using the DataprocBatchSensor.
     """
 
     template_fields: Sequence[str] = (

@@ -2080,6 +2083,7 @@ def __init__(
         gcp_conn_id: str = "google_cloud_default",
         impersonation_chain: str | Sequence[str] | None = None,
         result_retry: Retry | _MethodDefault = DEFAULT,
+        asynchronous: bool = False,
         **kwargs,
     ):
         super().__init__(**kwargs)

@@ -2095,6 +2099,7 @@ def __init__(
         self.gcp_conn_id = gcp_conn_id
         self.impersonation_chain = impersonation_chain
         self.operation: operation.Operation | None = None
+        self.asynchronous = asynchronous
 
     def execute(self, context: Context):
         hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)

@@ -2114,10 +2119,13 @@ def execute(self, context: Context):
             )
             if self.operation is None:
                 raise RuntimeError("The operation should be set here!")
-            result = hook.wait_for_operation(
-                timeout=self.timeout, result_retry=self.result_retry, operation=self.operation
-            )
-            self.log.info("Batch %s created", self.batch_id)
+            if not self.asynchronous:
+                result = hook.wait_for_operation(
+                    timeout=self.timeout, result_retry=self.result_retry, operation=self.operation
+                )
+                self.log.info("Batch %s created", self.batch_id)
+            else:
+                return self.operation.operation.name
         except AlreadyExists:
             self.log.info("Batch with given id already exists")
             if self.batch_id is None:

@@ -2130,7 +2138,6 @@
                 timeout=self.timeout,
                 metadata=self.metadata,
             )
-
             # The existing batch may be a number of states other than 'SUCCEEDED'
             if result.state != Batch.State.SUCCEEDED:
                 if result.state == Batch.State.FAILED or result.state == Batch.State.CANCELLED:
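
With the new flag, a create task can return the operation name instead of blocking. A minimal sketch, assuming placeholder project, region, and batch payload values:

    from airflow.providers.google.cloud.operators.dataproc import DataprocCreateBatchOperator

    # Submit the batch and return immediately; execute() returns the
    # long-running operation name, which lands in XCom as the task's
    # return value for downstream sensor or cancel tasks.
    create_batch = DataprocCreateBatchOperator(
        task_id="create_batch",
        project_id="my-project",  # placeholder
        region="europe-west1",  # placeholder
        batch={"pyspark_batch": {"main_python_file_uri": "gs://my-bucket/job.py"}},
        batch_id="example-batch",
        asynchronous=True,
    )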
@@ -2355,3 +2362,59 @@ def execute(self, context: Context):
         )
         DataprocListLink.persist(context=context, task_instance=self, url=DATAPROC_BATCHES_LINK)
         return [Batch.to_dict(result) for result in results]
+
+
+class DataprocCancelOperationOperator(BaseOperator):
+    """
+    Cancel the batch workload resource.
+
+    :param operation_name: Required. The name of the operation resource to be cancelled.
+    :param region: Required. The Cloud Dataproc region in which to handle the request.
+    :param project_id: Optional. The ID of the Google Cloud project that the cluster belongs to.
+    :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be
+        retried.
+    :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+        ``retry`` is specified, the timeout applies to each individual attempt.
+    :param metadata: Additional metadata that is provided to the method.
+    :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
+    """
+
+    template_fields: Sequence[str] = ("operation_name", "region", "project_id", "impersonation_chain")
+
+    def __init__(
+        self,
+        *,
+        operation_name: str,
+        region: str,
+        project_id: str | None = None,
+        retry: Retry | _MethodDefault = DEFAULT,
+        timeout: float | None = None,
+        metadata: Sequence[tuple[str, str]] = (),
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: str | Sequence[str] | None = None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.operation_name = operation_name
+        self.region = region
+        self.project_id = project_id
+        self.retry = retry
+        self.timeout = timeout
+        self.metadata = metadata
+        self.gcp_conn_id = gcp_conn_id
+        self.impersonation_chain = impersonation_chain
+
+    def execute(self, context: Context):
+        hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
+
+        self.log.info("Canceling operation: %s", self.operation_name)
+        hook.get_operations_client(region=self.region).cancel_operation(name=self.operation_name)
+        self.log.info("Operation canceled.")

airflow/providers/google/cloud/sensors/dataproc.py
Lines changed: 75 additions & 1 deletion

@@ -22,7 +22,7 @@
 from typing import TYPE_CHECKING, Sequence
 
 from google.api_core.exceptions import ServerError
-from google.cloud.dataproc_v1.types import JobStatus
+from google.cloud.dataproc_v1.types import Batch, JobStatus
 
 from airflow.exceptions import AirflowException
 from airflow.providers.google.cloud.hooks.dataproc import DataprocHook

@@ -109,3 +109,77 @@ def poke(self, context: Context) -> bool:
 
         self.log.info("Waiting for job %s to complete.", self.dataproc_job_id)
         return False
+
+
+class DataprocBatchSensor(BaseSensorOperator):
+    """
+    Check for the state of a batch.
+
+    :param batch_id: The Dataproc batch ID to poll. (templated)
+    :param region: Required. The Cloud Dataproc region in which to handle the request. (templated)
+    :param project_id: The ID of the Google Cloud project in which
+        the batch was created. (templated)
+    :param gcp_conn_id: The connection ID to use connecting to Google Cloud Platform.
+    :param wait_timeout: How many seconds to wait for the batch to be ready.
+    """
+
+    template_fields: Sequence[str] = ("project_id", "region", "batch_id")
+    ui_color = "#f0eee4"
+
+    def __init__(
+        self,
+        *,
+        batch_id: str,
+        region: str,
+        project_id: str | None = None,
+        gcp_conn_id: str = "google_cloud_default",
+        wait_timeout: int | None = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.batch_id = batch_id
+        self.project_id = project_id
+        self.gcp_conn_id = gcp_conn_id
+        self.region = region
+        self.wait_timeout = wait_timeout
+        self.start_sensor_time: float | None = None
+
+    def execute(self, context: Context) -> None:
+        self.start_sensor_time = time.monotonic()
+        super().execute(context)
+
+    def _duration(self):
+        return time.monotonic() - self.start_sensor_time
+
+    def poke(self, context: Context) -> bool:
+        hook = DataprocHook(gcp_conn_id=self.gcp_conn_id)
+        if self.wait_timeout:
+            try:
+                batch = hook.get_batch(batch_id=self.batch_id, region=self.region, project_id=self.project_id)
+            except ServerError as err:
+                duration = self._duration()
+                self.log.info("DURATION RUN: %f", duration)
+
+                if duration > self.wait_timeout:
+                    raise AirflowException(
+                        f"Timeout: dataproc batch {self.batch_id} is not ready after {self.wait_timeout}s"
+                    )
+                self.log.info("Retrying. Dataproc API returned server error when waiting for batch: %s", err)
+                return False
+        else:
+            batch = hook.get_batch(batch_id=self.batch_id, region=self.region, project_id=self.project_id)
+
+        state = batch.state
+        if state == Batch.State.FAILED:
+            raise AirflowException("Batch failed")
+        elif state in {
+            Batch.State.CANCELLED,
+            Batch.State.CANCELLING,
+        }:
+            raise AirflowException("Batch was cancelled.")
+        elif state == Batch.State.SUCCEEDED:
+            self.log.debug("Batch %s completed successfully.", self.batch_id)
+            return True
+
+        self.log.info("Waiting for the batch %s to complete.", self.batch_id)
+        return False
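
Outside the diff, a minimal sketch of how the sensor might be wired up; the project, region, batch ID, and timing values are placeholders:

    from airflow.providers.google.cloud.sensors.dataproc import DataprocBatchSensor

    # Poll an asynchronously created batch every 10 seconds; tolerate
    # transient Dataproc API server errors for up to wait_timeout seconds.
    batch_async_sensor = DataprocBatchSensor(
        task_id="batch_async_sensor",
        project_id="my-project",  # placeholder
        region="europe-west1",  # placeholder
        batch_id="example-batch",
        poke_interval=10,
        wait_timeout=300,
    )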

docs/apache-airflow-providers-google/operators/cloud/dataproc.rst
Lines changed: 22 additions & 0 deletions

@@ -279,6 +279,16 @@ After Cluster was created you should add it to the Batch configuration.
     :start-after: [START how_to_cloud_dataproc_create_batch_operator_with_persistent_history_server]
     :end-before: [END how_to_cloud_dataproc_create_batch_operator_with_persistent_history_server]
 
+To check if an operation succeeded you can use
+:class:`~airflow.providers.google.cloud.sensors.dataproc.DataprocBatchSensor`.
+
+.. exampleinclude:: /../../tests/system/providers/google/cloud/dataproc/example_dataproc_batch.py
+    :language: python
+    :dedent: 4
+    :start-after: [START how_to_cloud_dataproc_batch_async_sensor]
+    :end-before: [END how_to_cloud_dataproc_batch_async_sensor]
+
 Get a Batch
 -----------

@@ -315,6 +325,18 @@ To delete a batch you can use:
     :start-after: [START how_to_cloud_dataproc_delete_batch_operator]
     :end-before: [END how_to_cloud_dataproc_delete_batch_operator]
 
+Cancel a Batch Operation
+------------------------
+
+To cancel an operation you can use
+:class:`~airflow.providers.google.cloud.operators.dataproc.DataprocCancelOperationOperator`.
+
+.. exampleinclude:: /../../tests/system/providers/google/cloud/dataproc/example_dataproc_batch.py
+    :language: python
+    :dedent: 4
+    :start-after: [START how_to_cloud_dataproc_cancel_operation_operator]
+    :end-before: [END how_to_cloud_dataproc_cancel_operation_operator]
+
 References
 ^^^^^^^^^^
 For further information, take a look at:

tests/providers/google/cloud/sensors/test_dataproc.py
Lines changed: 71 additions & 2 deletions

@@ -22,10 +22,10 @@
 
 import pytest
 from google.api_core.exceptions import ServerError
-from google.cloud.dataproc_v1.types import JobStatus
+from google.cloud.dataproc_v1.types import Batch, JobStatus
 
 from airflow import AirflowException
-from airflow.providers.google.cloud.sensors.dataproc import DataprocJobSensor
+from airflow.providers.google.cloud.sensors.dataproc import DataprocBatchSensor, DataprocJobSensor
 from airflow.version import version as airflow_version
 
 AIRFLOW_VERSION = "v" + airflow_version.replace(".", "-").replace("+", "-")

@@ -184,3 +184,72 @@ def test_wait_timeout_raise_exception(self, mock_hook):
 
         with pytest.raises(AirflowException, match="Timeout: dataproc job job_id is not ready after 300s"):
             sensor.poke(context={})
+
+
+class TestDataprocBatchSensor(unittest.TestCase):
+    def create_batch(self, state: int):
+        batch = mock.Mock()
+        batch.state = mock.Mock()
+        batch.state = state
+        return batch
+
+    @mock.patch(DATAPROC_PATH.format("DataprocHook"))
+    def test_succeeded(self, mock_hook):
+        batch = self.create_batch(Batch.State.SUCCEEDED)
+        mock_hook.return_value.get_batch.return_value = batch
+
+        sensor = DataprocBatchSensor(
+            task_id=TASK_ID,
+            region=GCP_LOCATION,
+            project_id=GCP_PROJECT,
+            batch_id="batch_id",
+            poke_interval=10,
+            gcp_conn_id=GCP_CONN_ID,
+            timeout=TIMEOUT,
+        )
+        ret = sensor.poke(context={})
+        mock_hook.return_value.get_batch.assert_called_once_with(
+            batch_id="batch_id", region=GCP_LOCATION, project_id=GCP_PROJECT
+        )
+        assert ret
+
+    @mock.patch(DATAPROC_PATH.format("DataprocHook"))
+    def test_cancelled(self, mock_hook):
+        batch = self.create_batch(Batch.State.CANCELLED)
+        mock_hook.return_value.get_batch.return_value = batch
+
+        sensor = DataprocBatchSensor(
+            task_id=TASK_ID,
+            region=GCP_LOCATION,
+            project_id=GCP_PROJECT,
+            batch_id="batch_id",
+            gcp_conn_id=GCP_CONN_ID,
+            timeout=TIMEOUT,
+        )
+        with pytest.raises(AirflowException, match="Batch was cancelled."):
+            sensor.poke(context={})
+
+        mock_hook.return_value.get_batch.assert_called_once_with(
+            batch_id="batch_id", region=GCP_LOCATION, project_id=GCP_PROJECT
+        )
+
+    @mock.patch(DATAPROC_PATH.format("DataprocHook"))
+    def test_error(self, mock_hook):
+        batch = self.create_batch(Batch.State.FAILED)
+        mock_hook.return_value.get_batch.return_value = batch
+
+        sensor = DataprocBatchSensor(
+            task_id=TASK_ID,
+            region=GCP_LOCATION,
+            project_id=GCP_PROJECT,
+            batch_id="batch_id",
+            gcp_conn_id=GCP_CONN_ID,
+            timeout=TIMEOUT,
+        )
+
+        with pytest.raises(AirflowException, match="Batch failed"):
+            sensor.poke(context={})
+
+        mock_hook.return_value.get_batch.assert_called_once_with(
+            batch_id="batch_id", region=GCP_LOCATION, project_id=GCP_PROJECT
+        )
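
One gap worth noting: the CANCELLING branch of the sensor is not exercised by these tests. A hypothetical companion test, following the same pattern (not part of this commit), could look like:

    @mock.patch(DATAPROC_PATH.format("DataprocHook"))
    def test_cancelling(self, mock_hook):
        # Hypothetical: CANCELLING sits in the same raising branch as CANCELLED.
        batch = self.create_batch(Batch.State.CANCELLING)
        mock_hook.return_value.get_batch.return_value = batch

        sensor = DataprocBatchSensor(
            task_id=TASK_ID,
            region=GCP_LOCATION,
            project_id=GCP_PROJECT,
            batch_id="batch_id",
            gcp_conn_id=GCP_CONN_ID,
            timeout=TIMEOUT,
        )
        with pytest.raises(AirflowException, match="Batch was cancelled."):
            sensor.poke(context={})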
