@@ -20,7 +20,6 @@
 
 import asyncio
 import time
-import warnings
 from typing import Any, AsyncIterator, Sequence
 
 from google.api_core.exceptions import NotFound
@@ -31,40 +30,58 @@
 from airflow.triggers.base import BaseTrigger, TriggerEvent
 
 
-class DataprocSubmitTrigger(BaseTrigger):
-    """
-    Trigger that periodically polls information from Dataproc API to verify job status.
-    Implementation leverages asynchronous transport.
-    """
+class DataprocBaseTrigger(BaseTrigger):
+    """Base class for Dataproc triggers."""
 
     def __init__(
         self,
-        job_id: str,
         region: str,
         project_id: str | None = None,
         gcp_conn_id: str = "google_cloud_default",
-        impersonation_chain: str | Sequence[str] | None = None,
         delegate_to: str | None = None,
+        impersonation_chain: str | Sequence[str] | None = None,
        polling_interval_seconds: int = 30,
     ):
         super().__init__()
+        self.region = region
+        self.project_id = project_id
         self.gcp_conn_id = gcp_conn_id
         self.impersonation_chain = impersonation_chain
-        self.job_id = job_id
-        self.project_id = project_id
-        self.region = region
         self.polling_interval_seconds = polling_interval_seconds
-        if delegate_to:
-            warnings.warn(
-                "'delegate_to' parameter is deprecated, please use 'impersonation_chain'", DeprecationWarning
-            )
         self.delegate_to = delegate_to
-        self.hook = DataprocAsyncHook(
-            delegate_to=self.delegate_to,
+
+    def get_async_hook(self):
+        return DataprocAsyncHook(
             gcp_conn_id=self.gcp_conn_id,
             impersonation_chain=self.impersonation_chain,
+            delegate_to=self.delegate_to,
         )
 
+
+class DataprocSubmitTrigger(DataprocBaseTrigger):
+    """
+    DataprocSubmitTrigger runs on the trigger worker to poll the status of a submitted Dataproc job.
+
+    :param job_id: The ID of a Dataproc job.
+    :param project_id: Google Cloud Project where the job is running
+    :param region: The Cloud Dataproc region in which to handle the request.
+    :param gcp_conn_id: Optional, the connection ID used to connect to Google Cloud Platform.
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
+    :param polling_interval_seconds: polling period in seconds to check for the status
+    """
+
+    def __init__(self, job_id: str, delegate_to: str | None = None, **kwargs):
+        self.job_id = job_id
+        self.delegate_to = delegate_to
+        super().__init__(delegate_to=self.delegate_to, **kwargs)
+
     def serialize(self):
         return (
             "airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitTrigger",
@@ -81,7 +98,9 @@ def serialize(self):
 
     async def run(self):
         while True:
-            job = await self.hook.get_job(project_id=self.project_id, region=self.region, job_id=self.job_id)
+            job = await self.get_async_hook().get_job(
+                project_id=self.project_id, region=self.region, job_id=self.job_id
+            )
             state = job.status.state
             self.log.info("Dataproc job: %s is in state: %s", self.job_id, state)
             if state in (JobStatus.State.ERROR, JobStatus.State.DONE, JobStatus.State.CANCELLED):
@@ -93,28 +112,28 @@ async def run(self):
             yield TriggerEvent({"job_id": self.job_id, "job_state": state})
 
 
-class DataprocClusterTrigger(BaseTrigger):
+class DataprocClusterTrigger(DataprocBaseTrigger):
     """
-    Trigger that periodically polls information from Dataproc API to verify status.
-    Implementation leverages asynchronous transport.
+    DataprocClusterTrigger runs on the trigger worker to poll the status of a Dataproc cluster.
+
+    :param cluster_name: The name of the cluster.
+    :param project_id: Google Cloud Project where the job is running
+    :param region: The Cloud Dataproc region in which to handle the request.
+    :param gcp_conn_id: Optional, the connection ID used to connect to Google Cloud Platform.
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
+    :param polling_interval_seconds: polling period in seconds to check for the status
     """
 
-    def __init__(
-        self,
-        cluster_name: str,
-        region: str,
-        project_id: str | None = None,
-        gcp_conn_id: str = "google_cloud_default",
-        impersonation_chain: str | Sequence[str] | None = None,
-        polling_interval_seconds: int = 10,
-    ):
-        super().__init__()
-        self.gcp_conn_id = gcp_conn_id
-        self.impersonation_chain = impersonation_chain
+    def __init__(self, cluster_name: str, **kwargs):
+        super().__init__(**kwargs)
         self.cluster_name = cluster_name
-        self.project_id = project_id
-        self.region = region
-        self.polling_interval_seconds = polling_interval_seconds
 
     def serialize(self) -> tuple[str, dict[str, Any]]:
         return (
@@ -130,9 +149,8 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
         )
 
     async def run(self) -> AsyncIterator["TriggerEvent"]:
-        hook = self._get_hook()
         while True:
-            cluster = await hook.get_cluster(
+            cluster = await self.get_async_hook().get_cluster(
                 project_id=self.project_id, region=self.region, cluster_name=self.cluster_name
             )
             state = cluster.status.state
@@ -146,14 +164,8 @@ async def run(self) -> AsyncIterator["TriggerEvent"]:
             await asyncio.sleep(self.polling_interval_seconds)
         yield TriggerEvent({"cluster_name": self.cluster_name, "cluster_state": state, "cluster": cluster})
 
-    def _get_hook(self) -> DataprocAsyncHook:
-        return DataprocAsyncHook(
-            gcp_conn_id=self.gcp_conn_id,
-            impersonation_chain=self.impersonation_chain,
-        )
-
 
-class DataprocBatchTrigger(BaseTrigger):
+class DataprocBatchTrigger(DataprocBaseTrigger):
     """
     DataprocCreateBatchTrigger run on the trigger worker to perform create Build operation
 
@@ -172,22 +184,9 @@ class DataprocBatchTrigger(BaseTrigger):
     :param polling_interval_seconds: polling period in seconds to check for the status
     """
 
-    def __init__(
-        self,
-        batch_id: str,
-        region: str,
-        project_id: str | None,
-        gcp_conn_id: str = "google_cloud_default",
-        impersonation_chain: str | Sequence[str] | None = None,
-        polling_interval_seconds: float = 5.0,
-    ):
-        super().__init__()
+    def __init__(self, batch_id: str, **kwargs):
+        super().__init__(**kwargs)
         self.batch_id = batch_id
-        self.project_id = project_id
-        self.region = region
-        self.gcp_conn_id = gcp_conn_id
-        self.impersonation_chain = impersonation_chain
-        self.polling_interval_seconds = polling_interval_seconds
 
     def serialize(self) -> tuple[str, dict[str, Any]]:
         """Serializes DataprocBatchTrigger arguments and classpath."""
@@ -204,13 +203,8 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
         )
 
     async def run(self):
-        hook = DataprocAsyncHook(
-            gcp_conn_id=self.gcp_conn_id,
-            impersonation_chain=self.impersonation_chain,
-        )
-
         while True:
-            batch = await hook.get_batch(
+            batch = await self.get_async_hook().get_batch(
                 project_id=self.project_id, region=self.region, batch_id=self.batch_id
             )
             state = batch.state
@@ -223,9 +217,9 @@ async def run(self):
             yield TriggerEvent({"batch_id": self.batch_id, "batch_state": state})
 
 
-class DataprocDeleteClusterTrigger(BaseTrigger):
+class DataprocDeleteClusterTrigger(DataprocBaseTrigger):
     """
-    Asynchronously checks the status of a cluster.
+    DataprocDeleteClusterTrigger runs on the trigger worker to perform the delete cluster operation.
 
     :param cluster_name: The name of the cluster
     :param end_time: Time in second left to check the cluster status
@@ -241,30 +235,20 @@ class DataprocDeleteClusterTrigger(BaseTrigger):
         If set as a sequence, the identities from the list must grant
         Service Account Token Creator IAM role to the directly preceding identity, with first
         account from the list granting this role to the originating account.
-    :param polling_interval: Time in seconds to sleep between checks of cluster status
+    :param polling_interval_seconds: Time in seconds to sleep between checks of cluster status
     """
 
     def __init__(
         self,
         cluster_name: str,
         end_time: float,
-        project_id: str | None = None,
-        region: str | None = None,
         metadata: Sequence[tuple[str, str]] = (),
-        gcp_conn_id: str = "google_cloud_default",
-        impersonation_chain: str | Sequence[str] | None = None,
-        polling_interval: float = 5.0,
         **kwargs: Any,
     ):
         super().__init__(**kwargs)
         self.cluster_name = cluster_name
         self.end_time = end_time
-        self.project_id = project_id
-        self.region = region
         self.metadata = metadata
-        self.gcp_conn_id = gcp_conn_id
-        self.impersonation_chain = impersonation_chain
-        self.polling_interval = polling_interval
 
     def serialize(self) -> tuple[str, dict[str, Any]]:
         """Serializes DataprocDeleteClusterTrigger arguments and classpath."""
@@ -278,16 +262,15 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
                 "metadata": self.metadata,
                 "gcp_conn_id": self.gcp_conn_id,
                 "impersonation_chain": self.impersonation_chain,
-                "polling_interval": self.polling_interval,
+                "polling_interval_seconds": self.polling_interval_seconds,
             },
         )
 
     async def run(self) -> AsyncIterator["TriggerEvent"]:
         """Wait until cluster is deleted completely"""
-        hook = self._get_hook()
         while self.end_time > time.time():
             try:
-                cluster = await hook.get_cluster(
+                cluster = await self.get_async_hook().get_cluster(
                     region=self.region,  # type: ignore[arg-type]
                     cluster_name=self.cluster_name,
                     project_id=self.project_id,  # type: ignore[arg-type]
@@ -296,52 +279,26 @@ async def run(self) -> AsyncIterator["TriggerEvent"]:
                 self.log.info(
                     "Cluster status is %s. Sleeping for %s seconds.",
                     cluster.status.state,
-                    self.polling_interval,
+                    self.polling_interval_seconds,
                 )
-                await asyncio.sleep(self.polling_interval)
+                await asyncio.sleep(self.polling_interval_seconds)
             except NotFound:
                 yield TriggerEvent({"status": "success", "message": ""})
             except Exception as e:
                 yield TriggerEvent({"status": "error", "message": str(e)})
         yield TriggerEvent({"status": "error", "message": "Timeout"})
 
-    def _get_hook(self) -> DataprocAsyncHook:
-        return DataprocAsyncHook(
-            gcp_conn_id=self.gcp_conn_id,
-            impersonation_chain=self.impersonation_chain,
-        )
-
 
-class DataprocWorkflowTrigger(BaseTrigger):
+class DataprocWorkflowTrigger(DataprocBaseTrigger):
     """
     Trigger that periodically polls information from Dataproc API to verify status.
     Implementation leverages asynchronous transport.
     """
 
-    def __init__(
-        self,
-        template_name: str,
-        name: str,
-        region: str,
-        project_id: str | None = None,
-        gcp_conn_id: str = "google_cloud_default",
-        impersonation_chain: str | Sequence[str] | None = None,
-        delegate_to: str | None = None,
-        polling_interval_seconds: int = 10,
-    ):
-        super().__init__()
-        self.gcp_conn_id = gcp_conn_id
+    def __init__(self, template_name: str, name: str, **kwargs: Any):
+        super().__init__(**kwargs)
         self.template_name = template_name
         self.name = name
-        self.impersonation_chain = impersonation_chain
-        self.project_id = project_id
-        self.region = region
-        self.polling_interval_seconds = polling_interval_seconds
-        self.delegate_to = delegate_to
-        if delegate_to:
-            warnings.warn(
-                "'delegate_to' parameter is deprecated, please use 'impersonation_chain'", DeprecationWarning
-            )
 
     def serialize(self):
         return (
@@ -359,7 +316,7 @@ def serialize(self):
         )
 
     async def run(self) -> AsyncIterator["TriggerEvent"]:
-        hook = self._get_hook()
+        hook = self.get_async_hook()
         while True:
             try:
                 operation = await hook.get_operation(region=self.region, operation_name=self.name)
@@ -394,9 +351,3 @@ async def run(self) -> AsyncIterator["TriggerEvent"]:
                         "message": str(e),
                     }
                 )
-
-    def _get_hook(self) -> DataprocAsyncHook:  # type: ignore[override]
-        return DataprocAsyncHook(
-            gcp_conn_id=self.gcp_conn_id,
-            impersonation_chain=self.impersonation_chain,
-        )
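For context, a minimal sketch of how a deferrable operator hands off to one of these triggers after this refactor. The operator below is hypothetical and not part of this change; `BaseOperator.defer` and `TriggerEvent` are the standard Airflow deferral APIs, and the event payload matches what `DataprocSubmitTrigger.run` yields above.

from __future__ import annotations

from typing import Any

from airflow.models import BaseOperator
from airflow.providers.google.cloud.triggers.dataproc import DataprocSubmitTrigger


class MySubmitJobOperator(BaseOperator):
    """Hypothetical deferrable operator, used only to illustrate the hand-off."""

    def __init__(self, *, job_id: str, region: str, project_id: str | None = None, **kwargs):
        super().__init__(**kwargs)
        self.job_id = job_id
        self.region = region
        self.project_id = project_id

    def execute(self, context: Any):
        # Free the worker slot: the triggerer polls Dataproc asynchronously and
        # resumes the task via `execute_complete` once run() yields a TriggerEvent.
        self.defer(
            trigger=DataprocSubmitTrigger(
                job_id=self.job_id,
                region=self.region,
                project_id=self.project_id,
                polling_interval_seconds=30,
            ),
            method_name="execute_complete",
        )

    def execute_complete(self, context: Any, event: dict[str, Any]):
        # `event` is the TriggerEvent payload, e.g. {"job_id": ..., "job_state": ...}.
        self.log.info("Job %s finished in state %s", event["job_id"], event["job_state"])
        return event["job_id"]

Note that serialize() must return exactly the classpath plus the constructor kwargs, because the triggerer re-instantiates each trigger from that tuple; this is why DataprocSubmitTrigger still accepts delegate_to explicitly and forwards it to DataprocBaseTrigger rather than relying only on **kwargs.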