@@ -18,15 +18,24 @@
 """This module contains Google BigQuery Data Transfer Service operators."""
 from __future__ import annotations
 
+import time
 from typing import TYPE_CHECKING, Sequence
 
 from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
 from google.api_core.retry import Retry
-from google.cloud.bigquery_datatransfer_v1 import StartManualTransferRunsResponse, TransferConfig
+from google.cloud.bigquery_datatransfer_v1 import (
+    StartManualTransferRunsResponse,
+    TransferConfig,
+    TransferRun,
+    TransferState,
+)
 
+from airflow import AirflowException
+from airflow.compat.functools import cached_property
 from airflow.models import BaseOperator
 from airflow.providers.google.cloud.hooks.bigquery_dts import BiqQueryDataTransferServiceHook, get_object_id
 from airflow.providers.google.cloud.links.bigquery_dts import BigQueryDataTransferConfigLink
+from airflow.providers.google.cloud.triggers.bigquery_dts import BigQueryDataTransferRunTrigger
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
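Aside on the new imports: `TransferState` is an integer-valued enum exported by the client library, which is what makes the terminal-state checks added later in this diff possible. A minimal standalone sketch (not part of the diff itself):

```python
from google.cloud.bigquery_datatransfer_v1 import TransferState

# TransferState members are hashable and comparable, so a set membership
# test is enough to detect terminal runs (mirrors _job_is_done below).
TERMINAL_STATES = {TransferState.SUCCEEDED, TransferState.FAILED, TransferState.CANCELLED}

print(TransferState.RUNNING in TERMINAL_STATES)    # False
print(TransferState.SUCCEEDED in TERMINAL_STATES)  # True
```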
@@ -224,7 +233,7 @@ class BigQueryDataTransferServiceStartTransferRunsOperator(BaseOperator):
         must be of the same form as the protobuf message
         `~google.cloud.bigquery_datatransfer_v1.types.Timestamp`
     :param project_id: The BigQuery project id where the transfer configuration should be
-        created. If set to None or missing, the default project_id from the Google Cloud connection is used.
+        created.
     :param location: BigQuery Transfer Service location for regional transfers.
     :param retry: A retry object used to retry requests. If `None` is
         specified, requests will not be retried.
@@ -241,6 +250,7 @@ class BigQueryDataTransferServiceStartTransferRunsOperator(BaseOperator):
         If set as a sequence, the identities from the list must grant
         Service Account Token Creator IAM role to the directly preceding identity, with first
         account from the list granting this role to the originating account (templated).
+    :param deferrable: Run the operator in deferrable mode.
     """
 
     template_fields: Sequence[str] = (
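To illustrate the new flag, a sketch of a DAG task using it; the task id, config id, and project are placeholders, not values from this PR:

```python
from airflow.providers.google.cloud.operators.bigquery_dts import (
    BigQueryDataTransferServiceStartTransferRunsOperator,
)

# Hypothetical task: transfer_config_id and project_id are placeholders.
start_runs = BigQueryDataTransferServiceStartTransferRunsOperator(
    task_id="start_transfer_runs",
    transfer_config_id="my-transfer-config-id",
    project_id="my-gcp-project",
    deferrable=True,  # free the worker slot while the triggerer watches the run
)
```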
@@ -266,6 +276,7 @@ def __init__(
         metadata: Sequence[tuple[str, str]] = (),
         gcp_conn_id="google_cloud_default",
         impersonation_chain: str | Sequence[str] | None = None,
+        deferrable: bool = False,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -279,13 +290,20 @@ def __init__(
         self.metadata = metadata
         self.gcp_conn_id = gcp_conn_id
         self.impersonation_chain = impersonation_chain
+        self.deferrable = deferrable
 
-    def execute(self, context: Context):
+    @cached_property
+    def hook(self) -> BiqQueryDataTransferServiceHook:
         hook = BiqQueryDataTransferServiceHook(
-            gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, location=self.location
+            gcp_conn_id=self.gcp_conn_id,
+            impersonation_chain=self.impersonation_chain,
+            location=self.location,
         )
+        return hook
+
+    def execute(self, context: Context):
         self.log.info("Submitting manual transfer for %s", self.transfer_config_id)
-        response = hook.start_manual_transfer_runs(
+        response = self.hook.start_manual_transfer_runs(
             transfer_config_id=self.transfer_config_id,
             requested_time_range=self.requested_time_range,
             requested_run_time=self.requested_run_time,
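The refactor above moves hook construction into a `cached_property` so that `execute` and the post-deferral callback share one client instance. A minimal standalone sketch of that pattern, using the stdlib `functools` version (the diff uses Airflow's compat shim):

```python
from functools import cached_property


class Example:
    @cached_property
    def hook(self) -> object:
        print("building client once")  # runs only on first access
        return object()


ex = Example()
assert ex.hook is ex.hook  # later accesses return the cached instance
```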
@@ -307,5 +325,79 @@ def execute(self, context: Context):
         result = StartManualTransferRunsResponse.to_dict(response)
         run_id = get_object_id(result["runs"][0])
         self.xcom_push(context, key="run_id", value=run_id)
-        self.log.info("Transfer run %s submitted successfully.", run_id)
-        return result
+
+        if not self.deferrable:
+            result = self._wait_for_transfer_to_be_done(
+                run_id=run_id,
+                transfer_config_id=transfer_config["config_id"],
+            )
+            self.log.info("Transfer run %s submitted successfully.", run_id)
+            return result
+
+        self.defer(
+            trigger=BigQueryDataTransferRunTrigger(
+                project_id=self.project_id,
+                config_id=transfer_config["config_id"],
+                run_id=run_id,
+                gcp_conn_id=self.gcp_conn_id,
+                location=self.location,
+                impersonation_chain=self.impersonation_chain,
+            ),
+            method_name="execute_completed",
+        )
+
+    def _wait_for_transfer_to_be_done(self, run_id: str, transfer_config_id: str, interval: int = 10):
+        if interval <= 0:
+            raise ValueError("Interval must be > 0")
+
+        while True:
+            transfer_run: TransferRun = self.hook.get_transfer_run(
+                run_id=run_id,
+                transfer_config_id=transfer_config_id,
+                project_id=self.project_id,
+                retry=self.retry,
+                timeout=self.timeout,
+                metadata=self.metadata,
+            )
+            state = transfer_run.state
+
+            if self._job_is_done(state):
+                if state in (TransferState.FAILED, TransferState.CANCELLED):
+                    raise AirflowException(f"Transfer run finished with {state} status.")
+
+                result = TransferRun.to_dict(transfer_run)
+                return result
+
+            self.log.info("Transfer run is still in progress, waiting for %s seconds...", interval)
+            self.log.info("Transfer run status: %s", state)
+            time.sleep(interval)
+
+    @staticmethod
+    def _job_is_done(state: TransferState) -> bool:
+        finished_job_statuses = [
+            TransferState.SUCCEEDED,
+            TransferState.CANCELLED,
+            TransferState.FAILED,
+        ]
+
+        return state in finished_job_statuses
+
+    def execute_completed(self, context: Context, event: dict):
+        """Execute after the trigger invoked by the defer method finishes its job."""
+        if event["status"] in ("failed", "cancelled"):
+            self.log.error("Trigger finished its work with status: %s.", event["status"])
+            raise AirflowException(event["message"])
+
+        transfer_run: TransferRun = self.hook.get_transfer_run(
+            project_id=self.project_id,
+            run_id=event["run_id"],
+            transfer_config_id=event["config_id"],
+        )
+
+        self.log.info(
+            "%s finished with message: %s",
+            event["run_id"],
+            event["message"],
+        )
+
+        return TransferRun.to_dict(transfer_run)
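Finally, a sketch of the event contract implied by `execute_completed`: the trigger is expected to emit a dict with `status`, `run_id`, `config_id`, and `message` keys. This shape is inferred from the operator code above, not from the trigger's own source:

```python
# Hedged example payloads matching the keys execute_completed reads.
success_event = {
    "status": "success",
    "run_id": "abc123",
    "config_id": "my-transfer-config-id",
    "message": "Transfer run finished successfully.",
}
failure_event = {
    "status": "failed",
    "run_id": "abc123",
    "config_id": "my-transfer-config-id",
    "message": "Transfer run failed.",
}
# execute_completed raises AirflowException for "failed"/"cancelled" statuses;
# otherwise it re-fetches the run and returns TransferRun.to_dict(...).
```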