Skip to content

Commit 5fcdd32

Browse files
authored
Add deferrable mode for Big Query Transfer operator (#27833)
1 parent 4b3a9ca commit 5fcdd32

File tree

7 files changed

+577
-13
lines changed

7 files changed

+577
-13
lines changed

β€Žairflow/providers/google/cloud/hooks/bigquery_dts.py

Lines changed: 73 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424
from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
2525
from google.api_core.retry import Retry
26-
from google.cloud.bigquery_datatransfer_v1 import DataTransferServiceClient
26+
from google.cloud.bigquery_datatransfer_v1 import DataTransferServiceAsyncClient, DataTransferServiceClient
2727
from google.cloud.bigquery_datatransfer_v1.types import (
2828
StartManualTransferRunsResponse,
2929
TransferConfig,
@@ -32,7 +32,11 @@
3232
from googleapiclient.discovery import Resource
3333

3434
from airflow.providers.google.common.consts import CLIENT_INFO
35-
from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID, GoogleBaseHook
35+
from airflow.providers.google.common.hooks.base_google import (
36+
PROVIDE_PROJECT_ID,
37+
GoogleBaseAsyncHook,
38+
GoogleBaseHook,
39+
)
3640

3741

3842
def get_object_id(obj: dict) -> str:
@@ -263,3 +267,70 @@ def get_transfer_run(
263267
return client.get_transfer_run(
264268
request={"name": name}, retry=retry, timeout=timeout, metadata=metadata or ()
265269
)
270+
271+
272+
class AsyncBiqQueryDataTransferServiceHook(GoogleBaseAsyncHook):
273+
"""Hook of the BigQuery service to be used with async client of the Google library."""
274+
275+
sync_hook_class = BiqQueryDataTransferServiceHook
276+
277+
def __init__(
278+
self,
279+
gcp_conn_id: str = "google_cloud_default",
280+
delegate_to: str | None = None,
281+
location: str | None = None,
282+
impersonation_chain: str | Sequence[str] | None = None,
283+
):
284+
super().__init__(
285+
gcp_conn_id=gcp_conn_id,
286+
delegate_to=delegate_to,
287+
location=location,
288+
impersonation_chain=impersonation_chain,
289+
)
290+
self._conn: DataTransferServiceAsyncClient | None = None
291+
292+
async def _get_conn(self) -> DataTransferServiceAsyncClient:
293+
if not self._conn:
294+
credentials = (await self.get_sync_hook()).get_credentials()
295+
self._conn = DataTransferServiceAsyncClient(credentials=credentials, client_info=CLIENT_INFO)
296+
return self._conn
297+
298+
async def _get_project_id(self) -> str:
299+
sync_hook = await self.get_sync_hook()
300+
return sync_hook.project_id
301+
302+
async def get_transfer_run(
303+
self,
304+
config_id: str,
305+
run_id: str,
306+
project_id: str | None,
307+
retry: Retry | _MethodDefault = DEFAULT,
308+
timeout: float | None = None,
309+
metadata: Sequence[tuple[str, str]] = (),
310+
):
311+
"""
312+
Returns information about the particular transfer run.
313+
314+
:param run_id: ID of the transfer run.
315+
:param config_id: ID of transfer config to be used.
316+
:param project_id: The BigQuery project id where the transfer configuration should be
317+
created. If set to None or missing, the default project_id from the Google Cloud connection
318+
is used.
319+
:param retry: A retry object used to retry requests. If `None` is
320+
specified, requests will not be retried.
321+
:param timeout: The amount of time, in seconds, to wait for the request to
322+
complete. Note that if retry is specified, the timeout applies to each individual
323+
attempt.
324+
:param metadata: Additional metadata that is provided to the method.
325+
:return: An ``google.cloud.bigquery_datatransfer_v1.types.TransferRun`` instance.
326+
"""
327+
project_id = project_id or (await self._get_project_id())
328+
client = await self._get_conn()
329+
name = f"projects/{project_id}/transferConfigs/{config_id}/runs/{run_id}"
330+
transfer_run = await client.get_transfer_run(
331+
name=name,
332+
retry=retry,
333+
timeout=timeout,
334+
metadata=metadata,
335+
)
336+
return transfer_run

β€Žairflow/providers/google/cloud/operators/bigquery_dts.py

Lines changed: 99 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,24 @@
1818
"""This module contains Google BigQuery Data Transfer Service operators."""
1919
from __future__ import annotations
2020

21+
import time
2122
from typing import TYPE_CHECKING, Sequence
2223

2324
from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
2425
from google.api_core.retry import Retry
25-
from google.cloud.bigquery_datatransfer_v1 import StartManualTransferRunsResponse, TransferConfig
26+
from google.cloud.bigquery_datatransfer_v1 import (
27+
StartManualTransferRunsResponse,
28+
TransferConfig,
29+
TransferRun,
30+
TransferState,
31+
)
2632

33+
from airflow import AirflowException
34+
from airflow.compat.functools import cached_property
2735
from airflow.models import BaseOperator
2836
from airflow.providers.google.cloud.hooks.bigquery_dts import BiqQueryDataTransferServiceHook, get_object_id
2937
from airflow.providers.google.cloud.links.bigquery_dts import BigQueryDataTransferConfigLink
38+
from airflow.providers.google.cloud.triggers.bigquery_dts import BigQueryDataTransferRunTrigger
3039

3140
if TYPE_CHECKING:
3241
from airflow.utils.context import Context
@@ -224,7 +233,7 @@ class BigQueryDataTransferServiceStartTransferRunsOperator(BaseOperator):
224233
must be of the same form as the protobuf message
225234
`~google.cloud.bigquery_datatransfer_v1.types.Timestamp`
226235
:param project_id: The BigQuery project id where the transfer configuration should be
227-
created. If set to None or missing, the default project_id from the Google Cloud connection is used.
236+
created.
228237
:param location: BigQuery Transfer Service location for regional transfers.
229238
:param retry: A retry object used to retry requests. If `None` is
230239
specified, requests will not be retried.
@@ -241,6 +250,7 @@ class BigQueryDataTransferServiceStartTransferRunsOperator(BaseOperator):
241250
If set as a sequence, the identities from the list must grant
242251
Service Account Token Creator IAM role to the directly preceding identity, with first
243252
account from the list granting this role to the originating account (templated).
253+
:param deferrable: Run operator in the deferrable mode.
244254
"""
245255

246256
template_fields: Sequence[str] = (
@@ -266,6 +276,7 @@ def __init__(
266276
metadata: Sequence[tuple[str, str]] = (),
267277
gcp_conn_id="google_cloud_default",
268278
impersonation_chain: str | Sequence[str] | None = None,
279+
deferrable: bool = False,
269280
**kwargs,
270281
) -> None:
271282
super().__init__(**kwargs)
@@ -279,13 +290,20 @@ def __init__(
279290
self.metadata = metadata
280291
self.gcp_conn_id = gcp_conn_id
281292
self.impersonation_chain = impersonation_chain
293+
self.deferrable = deferrable
282294

283-
def execute(self, context: Context):
295+
@cached_property
296+
def hook(self) -> BiqQueryDataTransferServiceHook:
284297
hook = BiqQueryDataTransferServiceHook(
285-
gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, location=self.location
298+
gcp_conn_id=self.gcp_conn_id,
299+
impersonation_chain=self.impersonation_chain,
300+
location=self.location,
286301
)
302+
return hook
303+
304+
def execute(self, context: Context):
287305
self.log.info("Submitting manual transfer for %s", self.transfer_config_id)
288-
response = hook.start_manual_transfer_runs(
306+
response = self.hook.start_manual_transfer_runs(
289307
transfer_config_id=self.transfer_config_id,
290308
requested_time_range=self.requested_time_range,
291309
requested_run_time=self.requested_run_time,
@@ -307,5 +325,79 @@ def execute(self, context: Context):
307325
result = StartManualTransferRunsResponse.to_dict(response)
308326
run_id = get_object_id(result["runs"][0])
309327
self.xcom_push(context, key="run_id", value=run_id)
310-
self.log.info("Transfer run %s submitted successfully.", run_id)
311-
return result
328+
329+
if not self.deferrable:
330+
result = self._wait_for_transfer_to_be_done(
331+
run_id=run_id,
332+
transfer_config_id=transfer_config["config_id"],
333+
)
334+
self.log.info("Transfer run %s submitted successfully.", run_id)
335+
return result
336+
337+
self.defer(
338+
trigger=BigQueryDataTransferRunTrigger(
339+
project_id=self.project_id,
340+
config_id=transfer_config["config_id"],
341+
run_id=run_id,
342+
gcp_conn_id=self.gcp_conn_id,
343+
location=self.location,
344+
impersonation_chain=self.impersonation_chain,
345+
),
346+
method_name="execute_completed",
347+
)
348+
349+
def _wait_for_transfer_to_be_done(self, run_id: str, transfer_config_id: str, interval: int = 10):
350+
if interval < 0:
351+
raise ValueError("Interval must be > 0")
352+
353+
while True:
354+
transfer_run: TransferRun = self.hook.get_transfer_run(
355+
run_id=run_id,
356+
transfer_config_id=transfer_config_id,
357+
project_id=self.project_id,
358+
retry=self.retry,
359+
timeout=self.timeout,
360+
metadata=self.metadata,
361+
)
362+
state = transfer_run.state
363+
364+
if self._job_is_done(state):
365+
if state == TransferState.FAILED or state == TransferState.CANCELLED:
366+
raise AirflowException(f"Transfer run was finished with {state} status.")
367+
368+
result = TransferRun.to_dict(transfer_run)
369+
return result
370+
371+
self.log.info("Transfer run is still working, waiting for %s seconds...", interval)
372+
self.log.info("Transfer run status: %s", state)
373+
time.sleep(interval)
374+
375+
@staticmethod
376+
def _job_is_done(state: TransferState) -> bool:
377+
finished_job_statuses = [
378+
state.SUCCEEDED,
379+
state.CANCELLED,
380+
state.FAILED,
381+
]
382+
383+
return state in finished_job_statuses
384+
385+
def execute_completed(self, context: Context, event: dict):
386+
"""Method to be executed after invoked trigger in defer method finishes its job."""
387+
if event["status"] == "failed" or event["status"] == "cancelled":
388+
self.log.error("Trigger finished its work with status: %s.", event["status"])
389+
raise AirflowException(event["message"])
390+
391+
transfer_run: TransferRun = self.hook.get_transfer_run(
392+
project_id=self.project_id,
393+
run_id=event["run_id"],
394+
transfer_config_id=event["config_id"],
395+
)
396+
397+
self.log.info(
398+
"%s finished with message: %s",
399+
event["run_id"],
400+
event["message"],
401+
)
402+
403+
return TransferRun.to_dict(transfer_run)

0 commit comments

Comments
 (0)