Commit 6ef5ba9

Refactor Dataproc Trigger (#29364)

This commit extracts a shared DataprocBaseTrigger that owns the common connection parameters (region, project_id, gcp_conn_id, delegate_to, impersonation_chain, polling_interval_seconds) and builds the async hook on demand via get_async_hook(). The submit, cluster, batch, delete-cluster, and workflow triggers now subclass it, and the delete-cluster trigger's polling_interval parameter is renamed to polling_interval_seconds.

1 parent ee0a56a commit 6ef5ba9

File tree: 3 files changed (+74 lines, -125 lines)

airflow/providers/google/cloud/operators/dataproc.py
1 addition & 3 deletions
```diff
@@ -877,12 +877,10 @@ def execute(self, context: Context) -> None:
                 project_id=self.project_id,
                 region=self.region,
                 cluster_name=self.cluster_name,
-                request_id=self.request_id,
-                retry=self.retry,
                 end_time=end_time,
                 metadata=self.metadata,
                 impersonation_chain=self.impersonation_chain,
-                polling_interval=self.polling_interval_seconds,
+                polling_interval_seconds=self.polling_interval_seconds,
             ),
             method_name="execute_complete",
         )
```
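The kwarg rename in this hunk is load-bearing: when the operator defers, the trigger is rebuilt on the triggerer from the kwargs returned by its serialize(), so the name passed here must match the trigger constructor's parameter. A minimal round-trip sketch (placeholder values throughout; assumes the Google provider package is importable and that serialize() returns the constructor kwargs shown in this diff):

```python
import time

from airflow.providers.google.cloud.triggers.dataproc import DataprocDeleteClusterTrigger

trigger = DataprocDeleteClusterTrigger(
    cluster_name="demo-cluster",      # placeholder values
    end_time=time.time() + 300,
    project_id="demo-project",
    region="us-central1",
    polling_interval_seconds=5.0,     # the renamed kwarg
)
classpath, kwargs = trigger.serialize()
# The triggerer effectively re-instantiates the trigger as TriggerClass(**kwargs),
# which only works if every serialized key is an accepted constructor argument.
rebuilt = DataprocDeleteClusterTrigger(**kwargs)
assert rebuilt.polling_interval_seconds == 5.0
```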

airflow/providers/google/cloud/triggers/dataproc.py
71 additions & 120 deletions
```diff
@@ -20,7 +20,6 @@
 
 import asyncio
 import time
-import warnings
 from typing import Any, AsyncIterator, Sequence
 
 from google.api_core.exceptions import NotFound
@@ -31,40 +30,58 @@
 from airflow.triggers.base import BaseTrigger, TriggerEvent
 
 
-class DataprocSubmitTrigger(BaseTrigger):
-    """
-    Trigger that periodically polls information from Dataproc API to verify job status.
-    Implementation leverages asynchronous transport.
-    """
+class DataprocBaseTrigger(BaseTrigger):
+    """Base class for Dataproc triggers"""
 
     def __init__(
         self,
-        job_id: str,
         region: str,
         project_id: str | None = None,
         gcp_conn_id: str = "google_cloud_default",
-        impersonation_chain: str | Sequence[str] | None = None,
         delegate_to: str | None = None,
+        impersonation_chain: str | Sequence[str] | None = None,
         polling_interval_seconds: int = 30,
     ):
         super().__init__()
+        self.region = region
+        self.project_id = project_id
         self.gcp_conn_id = gcp_conn_id
         self.impersonation_chain = impersonation_chain
-        self.job_id = job_id
-        self.project_id = project_id
-        self.region = region
         self.polling_interval_seconds = polling_interval_seconds
-        if delegate_to:
-            warnings.warn(
-                "'delegate_to' parameter is deprecated, please use 'impersonation_chain'", DeprecationWarning
-            )
         self.delegate_to = delegate_to
-        self.hook = DataprocAsyncHook(
-            delegate_to=self.delegate_to,
+
+    def get_async_hook(self):
+        return DataprocAsyncHook(
             gcp_conn_id=self.gcp_conn_id,
             impersonation_chain=self.impersonation_chain,
+            delegate_to=self.delegate_to,
         )
 
+
+class DataprocSubmitTrigger(DataprocBaseTrigger):
+    """
+    DataprocSubmitTrigger run on the trigger worker to perform create Build operation
+
+    :param job_id: The ID of a Dataproc job.
+    :param project_id: Google Cloud Project where the job is running
+    :param region: The Cloud Dataproc region in which to handle the request.
+    :param gcp_conn_id: Optional, the connection ID used to connect to Google Cloud Platform.
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
+    :param polling_interval_seconds: polling period in seconds to check for the status
+    """
+
+    def __init__(self, job_id: str, delegate_to: str | None = None, **kwargs):
+        self.job_id = job_id
+        self.delegate_to = delegate_to
+        super().__init__(delegate_to=self.delegate_to, **kwargs)
+
     def serialize(self):
         return (
             "airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitTrigger",
```
```diff
@@ -81,7 +98,9 @@ def serialize(self):
 
     async def run(self):
         while True:
-            job = await self.hook.get_job(project_id=self.project_id, region=self.region, job_id=self.job_id)
+            job = await self.get_async_hook().get_job(
+                project_id=self.project_id, region=self.region, job_id=self.job_id
+            )
             state = job.status.state
             self.log.info("Dataproc job: %s is in state: %s", self.job_id, state)
             if state in (JobStatus.State.ERROR, JobStatus.State.DONE, JobStatus.State.CANCELLED):
```
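The run() methods in this file are all variations of the same poll-until-terminal loop. A self-contained sketch of the pattern, with a stubbed state getter in place of the real hook (illustrative only, not the provider API):

```python
import asyncio

TERMINAL = {"ERROR", "DONE", "CANCELLED"}  # mirrors the JobStatus.State check above


async def poll_until_terminal(get_state, interval: float = 0.1) -> str:
    """Poll get_state() until it returns a terminal value, sleeping in between."""
    while True:
        state = await get_state()
        if state in TERMINAL:
            return state
        await asyncio.sleep(interval)


async def fake_get_state(_states=iter(["PENDING", "RUNNING", "DONE"])):
    # Stub standing in for DataprocAsyncHook.get_job(...).status.state
    return next(_states)


print(asyncio.run(poll_until_terminal(fake_get_state)))  # -> DONE
```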
```diff
@@ -93,28 +112,28 @@ async def run(self):
             yield TriggerEvent({"job_id": self.job_id, "job_state": state})
 
 
-class DataprocClusterTrigger(BaseTrigger):
+class DataprocClusterTrigger(DataprocBaseTrigger):
     """
-    Trigger that periodically polls information from Dataproc API to verify status.
-    Implementation leverages asynchronous transport.
+    DataprocClusterTrigger run on the trigger worker to perform create Build operation
+
+    :param cluster_name: The name of the cluster.
+    :param project_id: Google Cloud Project where the job is running
+    :param region: The Cloud Dataproc region in which to handle the request.
+    :param gcp_conn_id: Optional, the connection ID used to connect to Google Cloud Platform.
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
+    :param polling_interval_seconds: polling period in seconds to check for the status
     """
 
-    def __init__(
-        self,
-        cluster_name: str,
-        region: str,
-        project_id: str | None = None,
-        gcp_conn_id: str = "google_cloud_default",
-        impersonation_chain: str | Sequence[str] | None = None,
-        polling_interval_seconds: int = 10,
-    ):
-        super().__init__()
-        self.gcp_conn_id = gcp_conn_id
-        self.impersonation_chain = impersonation_chain
+    def __init__(self, cluster_name: str, **kwargs):
+        super().__init__(**kwargs)
         self.cluster_name = cluster_name
-        self.project_id = project_id
-        self.region = region
-        self.polling_interval_seconds = polling_interval_seconds
 
     def serialize(self) -> tuple[str, dict[str, Any]]:
         return (
@@ -130,9 +149,8 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
         )
 
     async def run(self) -> AsyncIterator["TriggerEvent"]:
-        hook = self._get_hook()
         while True:
-            cluster = await hook.get_cluster(
+            cluster = await self.get_async_hook().get_cluster(
                 project_id=self.project_id, region=self.region, cluster_name=self.cluster_name
             )
             state = cluster.status.state
@@ -146,14 +164,8 @@ async def run(self) -> AsyncIterator["TriggerEvent"]:
             await asyncio.sleep(self.polling_interval_seconds)
         yield TriggerEvent({"cluster_name": self.cluster_name, "cluster_state": state, "cluster": cluster})
 
-    def _get_hook(self) -> DataprocAsyncHook:
-        return DataprocAsyncHook(
-            gcp_conn_id=self.gcp_conn_id,
-            impersonation_chain=self.impersonation_chain,
-        )
-
 
-class DataprocBatchTrigger(BaseTrigger):
+class DataprocBatchTrigger(DataprocBaseTrigger):
     """
     DataprocCreateBatchTrigger run on the trigger worker to perform create Build operation
 
@@ -172,22 +184,9 @@ class DataprocBatchTrigger(BaseTrigger):
     :param polling_interval_seconds: polling period in seconds to check for the status
     """
 
-    def __init__(
-        self,
-        batch_id: str,
-        region: str,
-        project_id: str | None,
-        gcp_conn_id: str = "google_cloud_default",
-        impersonation_chain: str | Sequence[str] | None = None,
-        polling_interval_seconds: float = 5.0,
-    ):
-        super().__init__()
+    def __init__(self, batch_id: str, **kwargs):
+        super().__init__(**kwargs)
         self.batch_id = batch_id
-        self.project_id = project_id
-        self.region = region
-        self.gcp_conn_id = gcp_conn_id
-        self.impersonation_chain = impersonation_chain
-        self.polling_interval_seconds = polling_interval_seconds
 
     def serialize(self) -> tuple[str, dict[str, Any]]:
         """Serializes DataprocBatchTrigger arguments and classpath."""
@@ -204,13 +203,8 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
         )
 
     async def run(self):
-        hook = DataprocAsyncHook(
-            gcp_conn_id=self.gcp_conn_id,
-            impersonation_chain=self.impersonation_chain,
-        )
-
         while True:
-            batch = await hook.get_batch(
+            batch = await self.get_async_hook().get_batch(
                 project_id=self.project_id, region=self.region, batch_id=self.batch_id
             )
             state = batch.state
@@ -223,9 +217,9 @@ async def run(self):
             yield TriggerEvent({"batch_id": self.batch_id, "batch_state": state})
 
 
-class DataprocDeleteClusterTrigger(BaseTrigger):
+class DataprocDeleteClusterTrigger(DataprocBaseTrigger):
     """
-    Asynchronously checks the status of a cluster.
+    DataprocDeleteClusterTrigger run on the trigger worker to perform delete cluster operation.
 
     :param cluster_name: The name of the cluster
     :param end_time: Time in second left to check the cluster status
@@ -241,30 +235,20 @@ class DataprocDeleteClusterTrigger(BaseTrigger):
         If set as a sequence, the identities from the list must grant
         Service Account Token Creator IAM role to the directly preceding identity, with first
         account from the list granting this role to the originating account.
-    :param polling_interval: Time in seconds to sleep between checks of cluster status
+    :param polling_interval_seconds: Time in seconds to sleep between checks of cluster status
     """
 
     def __init__(
         self,
         cluster_name: str,
         end_time: float,
-        project_id: str | None = None,
-        region: str | None = None,
         metadata: Sequence[tuple[str, str]] = (),
-        gcp_conn_id: str = "google_cloud_default",
-        impersonation_chain: str | Sequence[str] | None = None,
-        polling_interval: float = 5.0,
         **kwargs: Any,
     ):
         super().__init__(**kwargs)
         self.cluster_name = cluster_name
         self.end_time = end_time
-        self.project_id = project_id
-        self.region = region
         self.metadata = metadata
-        self.gcp_conn_id = gcp_conn_id
-        self.impersonation_chain = impersonation_chain
-        self.polling_interval = polling_interval
 
     def serialize(self) -> tuple[str, dict[str, Any]]:
         """Serializes DataprocDeleteClusterTrigger arguments and classpath."""
@@ -278,16 +262,15 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
                 "metadata": self.metadata,
                 "gcp_conn_id": self.gcp_conn_id,
                 "impersonation_chain": self.impersonation_chain,
-                "polling_interval": self.polling_interval,
+                "polling_interval_seconds": self.polling_interval_seconds,
             },
         )
 
     async def run(self) -> AsyncIterator["TriggerEvent"]:
         """Wait until cluster is deleted completely"""
-        hook = self._get_hook()
         while self.end_time > time.time():
             try:
-                cluster = await hook.get_cluster(
+                cluster = await self.get_async_hook().get_cluster(
                     region=self.region,  # type: ignore[arg-type]
                     cluster_name=self.cluster_name,
                     project_id=self.project_id,  # type: ignore[arg-type]
@@ -296,52 +279,26 @@ async def run(self) -> AsyncIterator["TriggerEvent"]:
                 self.log.info(
                     "Cluster status is %s. Sleeping for %s seconds.",
                     cluster.status.state,
-                    self.polling_interval,
+                    self.polling_interval_seconds,
                 )
-                await asyncio.sleep(self.polling_interval)
+                await asyncio.sleep(self.polling_interval_seconds)
             except NotFound:
                 yield TriggerEvent({"status": "success", "message": ""})
             except Exception as e:
                 yield TriggerEvent({"status": "error", "message": str(e)})
         yield TriggerEvent({"status": "error", "message": "Timeout"})
 
-    def _get_hook(self) -> DataprocAsyncHook:
-        return DataprocAsyncHook(
-            gcp_conn_id=self.gcp_conn_id,
-            impersonation_chain=self.impersonation_chain,
-        )
-
 
-class DataprocWorkflowTrigger(BaseTrigger):
+class DataprocWorkflowTrigger(DataprocBaseTrigger):
     """
     Trigger that periodically polls information from Dataproc API to verify status.
    Implementation leverages asynchronous transport.
     """
 
-    def __init__(
-        self,
-        template_name: str,
-        name: str,
-        region: str,
-        project_id: str | None = None,
-        gcp_conn_id: str = "google_cloud_default",
-        impersonation_chain: str | Sequence[str] | None = None,
-        delegate_to: str | None = None,
-        polling_interval_seconds: int = 10,
-    ):
-        super().__init__()
-        self.gcp_conn_id = gcp_conn_id
+    def __init__(self, template_name: str, name: str, **kwargs: Any):
+        super().__init__(**kwargs)
         self.template_name = template_name
         self.name = name
-        self.impersonation_chain = impersonation_chain
-        self.project_id = project_id
-        self.region = region
-        self.polling_interval_seconds = polling_interval_seconds
-        self.delegate_to = delegate_to
-        if delegate_to:
-            warnings.warn(
-                "'delegate_to' parameter is deprecated, please use 'impersonation_chain'", DeprecationWarning
-            )
 
     def serialize(self):
         return (
@@ -359,7 +316,7 @@ def serialize(self):
         )
 
     async def run(self) -> AsyncIterator["TriggerEvent"]:
-        hook = self._get_hook()
+        hook = self.get_async_hook()
         while True:
             try:
                 operation = await hook.get_operation(region=self.region, operation_name=self.name)
@@ -394,9 +351,3 @@ async def run(self) -> AsyncIterator["TriggerEvent"]:
                         "message": str(e),
                     }
                 )
-
-    def _get_hook(self) -> DataprocAsyncHook:  # type: ignore[override]
-        return DataprocAsyncHook(
-            gcp_conn_id=self.gcp_conn_id,
-            impersonation_chain=self.impersonation_chain,
-        )
```
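DataprocDeleteClusterTrigger inverts the usual loop: success is the cluster disappearing (NotFound), the loop is bounded by end_time rather than running forever, and falling out of the loop yields a timeout error. A standalone sketch of that control flow, with stubs in place of the hook and the real exception type (illustrative only):

```python
import asyncio
import time


class NotFound(Exception):
    """Stub for google.api_core.exceptions.NotFound."""


async def wait_for_deletion(get_cluster, end_time: float, interval: float = 0.1) -> dict:
    while end_time > time.time():
        try:
            await get_cluster()          # cluster still exists: keep waiting
            await asyncio.sleep(interval)
        except NotFound:
            return {"status": "success", "message": ""}
        except Exception as e:
            return {"status": "error", "message": str(e)}
    return {"status": "error", "message": "Timeout"}


async def fake_get_cluster(_calls=iter([None, None])):
    # Returns twice, then raises NotFound to simulate the cluster vanishing.
    if next(_calls, "gone") == "gone":
        raise NotFound
    return {"state": "DELETING"}


print(asyncio.run(wait_for_deletion(fake_get_cluster, end_time=time.time() + 5)))
```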

tests/providers/google/cloud/triggers/test_dataproc.py
2 additions & 2 deletions
```diff
@@ -302,7 +302,7 @@ def test_async_cluster_trigger_serialization_should_execute_successfully(self, w
         }
 
     @pytest.mark.asyncio
-    @async_mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocWorkflowTrigger._get_hook")
+    @async_mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocBaseTrigger.get_async_hook")
     async def test_async_workflow_triggers_on_success_should_execute_successfully(
         self, mock_hook, workflow_trigger, async_get_operation
     ):
@@ -322,7 +322,7 @@ async def test_async_workflow_triggers_on_success_should_execute_successfully(
         assert expected_event == actual_event
 
     @pytest.mark.asyncio
-    @async_mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocWorkflowTrigger._get_hook")
+    @async_mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocBaseTrigger.get_async_hook")
     async def test_async_workflow_triggers_on_error(self, mock_hook, workflow_trigger, async_get_operation):
         mock_hook.return_value.get_operation.return_value = async_get_operation(
             name=TEST_OPERATION_NAME, done=True, response={}, error=Status(message="test_error")
```
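Because hook construction now lives on the base class, a single patch target covers every Dataproc trigger's tests. A minimal sketch with unittest.mock (assumes the provider package is installed; the mock's return value stands in for the DataprocAsyncHook instance):

```python
from unittest import mock

TARGET = "airflow.providers.google.cloud.triggers.dataproc.DataprocBaseTrigger.get_async_hook"

with mock.patch(TARGET) as mock_get_hook:
    hook = mock_get_hook.return_value
    # Configure hook methods as the tests above do, e.g.:
    # hook.get_operation.return_value = some_async_stub
```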
