Commit ec844ea

Add deferrable mode to BigQueryTablePartitionExistenceSensor. (#29735)

Authored by: Beata Kossakowska (bkossakowska)
Co-authored-by: Beata Kossakowska <bkossakowska@google.com>
1 parent 96dd371

File tree: 7 files changed, +275 -2 lines changed

- airflow/providers/google/cloud/hooks/bigquery.py
- airflow/providers/google/cloud/sensors/bigquery.py
- airflow/providers/google/cloud/triggers/bigquery.py
- docs/apache-airflow-providers-google/operators/cloud/bigquery.rst
- tests/providers/google/cloud/sensors/test_bigquery.py
- tests/providers/google/cloud/triggers/test_bigquery.py
- tests/system/providers/google/cloud/bigquery/example_bigquery_sensors.py

airflow/providers/google/cloud/hooks/bigquery.py (18 additions, 0 deletions)

```diff
@@ -3048,6 +3048,24 @@ async def get_job_output(
             job_query_response = await job_client.get_query_results(cast(Session, session))
             return job_query_response
 
+    async def create_job_for_partition_get(
+        self,
+        dataset_id: str | None,
+        project_id: str | None = None,
+    ):
+        """Create a new job and get the job_id using gcloud-aio."""
+        async with ClientSession() as session:
+            self.log.info("Executing create_job..")
+            job_client = await self.get_job_instance(project_id, "", session)
+
+            query_request = {
+                "query": "SELECT partition_id "
+                f"FROM `{project_id}.{dataset_id}.INFORMATION_SCHEMA.PARTITIONS`",
+                "useLegacySql": False,
+            }
+            job_query_resp = await job_client.query(query_request, cast(Session, session))
+            return job_query_resp["jobReference"]["jobId"]
+
     def get_records(self, query_results: dict[str, Any]) -> list[Any]:
         """
         Given the output query response from gcloud-aio bigquery, convert the response to records.
```
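The new hook method can also be exercised directly, which is handy when debugging the trigger's polling loop. A minimal sketch; the connection ID, project, and dataset names below are illustrative, and default credentials are assumed:

```python
import asyncio

from airflow.providers.google.cloud.hooks.bigquery import BigQueryAsyncHook


async def submit_partition_query() -> str:
    # Illustrative IDs; substitute your own connection, project, and dataset.
    hook = BigQueryAsyncHook(gcp_conn_id="google_cloud_default")
    # Submits the INFORMATION_SCHEMA.PARTITIONS query and returns the job id,
    # which can then be polled with hook.get_job_status(...).
    return await hook.create_job_for_partition_get("my_dataset", project_id="my-project")


job_id = asyncio.run(submit_partition_query())
```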

airflow/providers/google/cloud/sensors/bigquery.py (66 additions, 1 deletion)

```diff
@@ -24,7 +24,10 @@
 
 from airflow.exceptions import AirflowException
 from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook
-from airflow.providers.google.cloud.triggers.bigquery import BigQueryTableExistenceTrigger
+from airflow.providers.google.cloud.triggers.bigquery import (
+    BigQueryTableExistenceTrigger,
+    BigQueryTablePartitionExistenceTrigger,
+)
 from airflow.sensors.base import BaseSensorOperator
 
 if TYPE_CHECKING:
@@ -244,3 +247,65 @@ def execute_complete(self, context: dict[str, Any], event: dict[str, str] | None
                 return event["message"]
             raise AirflowException(event["message"])
         raise AirflowException("No event received in trigger callback")
+
+
+class BigQueryTableExistencePartitionAsyncSensor(BigQueryTablePartitionExistenceSensor):
+    """
+    Checks for the existence of a partition within a table in Google BigQuery.
+
+    :param project_id: The Google cloud project in which to look for the table.
+        The connection supplied to the hook must provide
+        access to the specified project.
+    :param dataset_id: The name of the dataset in which to look for the table.
+    :param partition_id: The name of the partition to check the existence of.
+    :param table_id: The name of the table to check the existence of.
+    :param gcp_conn_id: The connection ID used to connect to Google Cloud.
+    :param bigquery_conn_id: (Deprecated) The connection ID used to connect to Google Cloud.
+        This parameter has been deprecated. You should pass the gcp_conn_id parameter instead.
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
+    :param poke_interval: The interval in seconds to wait between checks for table existence.
+    """
+
+    def __init__(self, poke_interval: int = 5, **kwargs):
+        super().__init__(**kwargs)
+        self.poke_interval = poke_interval
+
+    def execute(self, context: Context) -> None:
+        """Airflow runs this method on the worker and defers using the trigger."""
+        self.defer(
+            timeout=timedelta(seconds=self.timeout),
+            trigger=BigQueryTablePartitionExistenceTrigger(
+                dataset_id=self.dataset_id,
+                table_id=self.table_id,
+                project_id=self.project_id,
+                partition_id=self.partition_id,
+                poll_interval=self.poke_interval,
+                gcp_conn_id=self.gcp_conn_id,
+                hook_params={
+                    "impersonation_chain": self.impersonation_chain,
+                },
+            ),
+            method_name="execute_complete",
+        )
+
+    def execute_complete(self, context: dict[str, Any], event: dict[str, str] | None = None) -> str:
+        """
+        Callback for when the trigger fires; returns immediately.
+        Relies on the trigger to throw an exception, otherwise it assumes execution was
+        successful.
+        """
+        table_uri = f"{self.project_id}:{self.dataset_id}.{self.table_id}"
+        self.log.info('Sensor checks existence of partition: "%s" in table: %s', self.partition_id, table_uri)
+        if event:
+            if event["status"] == "success":
+                return event["message"]
+            raise AirflowException(event["message"])
+        raise AirflowException("No event received in trigger callback")
```

airflow/providers/google/cloud/triggers/bigquery.py (69 additions, 0 deletions)

```diff
@@ -529,3 +529,72 @@ async def _table_exists(
             if err.status == 404:
                 return False
             raise err
+
+
+class BigQueryTablePartitionExistenceTrigger(BigQueryTableExistenceTrigger):
+    """
+    Initialize the BigQuery Table Partition Existence Trigger with needed parameters.
+    :param partition_id: The name of the partition to check the existence of.
+    :param project_id: Google Cloud Project where the job is running.
+    :param dataset_id: The dataset ID of the requested table.
+    :param table_id: The table ID of the requested table.
+    :param gcp_conn_id: Reference to the Google Cloud connection ID.
+    :param hook_params: params for the hook.
+    :param poll_interval: polling period in seconds to check for the status.
+    """
+
+    def __init__(self, partition_id: str, **kwargs):
+        super().__init__(**kwargs)
+        self.partition_id = partition_id
+
+    def serialize(self) -> tuple[str, dict[str, Any]]:
+        """Serializes BigQueryTablePartitionExistenceTrigger arguments and classpath."""
+        return (
+            "airflow.providers.google.cloud.triggers.bigquery.BigQueryTablePartitionExistenceTrigger",
+            {
+                "partition_id": self.partition_id,
+                "dataset_id": self.dataset_id,
+                "project_id": self.project_id,
+                "table_id": self.table_id,
+                "gcp_conn_id": self.gcp_conn_id,
+                "poll_interval": self.poll_interval,
+                "hook_params": self.hook_params,
+            },
+        )
+
+    async def run(self) -> AsyncIterator["TriggerEvent"]:  # type: ignore[override]
+        """Will run until the partition exists in the Google BigQuery table."""
+        hook = BigQueryAsyncHook(gcp_conn_id=self.gcp_conn_id)
+        job_id = None
+        while True:
+            if job_id is not None:
+                status = await hook.get_job_status(job_id=job_id, project_id=self.project_id)
+                if status == "success":
+                    is_partition = await self._partition_exists(
+                        hook=hook, job_id=job_id, project_id=self.project_id
+                    )
+                    if is_partition:
+                        yield TriggerEvent(
+                            {
+                                "status": "success",
+                                "message": f"Partition: {self.partition_id} in table: {self.table_id}",
+                            }
+                        )
+                    job_id = None
+                elif status == "error":
+                    yield TriggerEvent({"status": "error", "message": status})
+                    return
+                self.log.info("Sleeping for %s seconds.", self.poll_interval)
+                await asyncio.sleep(self.poll_interval)
+
+            else:
+                job_id = await hook.create_job_for_partition_get(self.dataset_id, project_id=self.project_id)
+                self.log.info("Sleeping for %s seconds.", self.poll_interval)
+                await asyncio.sleep(self.poll_interval)
+
+    async def _partition_exists(self, hook: BigQueryAsyncHook, job_id: str | None, project_id: str):
+        query_results = await hook.get_job_output(job_id=job_id, project_id=project_id)
+        records = hook.get_records(query_results)
+        if records:
+            records = [row[0] for row in records]
+            return self.partition_id in records
```
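Because the trigger is an async generator, it can be exercised without a triggerer by consuming `run()` directly, much as the provider's trigger tests do. A sketch with the async hook mocked out so no Google Cloud call is made (all IDs and return values are illustrative):

```python
import asyncio
from unittest import mock

from airflow.providers.google.cloud.triggers.bigquery import (
    BigQueryTablePartitionExistenceTrigger,
)


async def first_event(trigger):
    # Pull a single TriggerEvent off the async generator.
    return await trigger.run().__anext__()


with mock.patch(
    "airflow.providers.google.cloud.triggers.bigquery.BigQueryAsyncHook"
) as hook_cls:
    hook = hook_cls.return_value
    hook.create_job_for_partition_get = mock.AsyncMock(return_value="job_123")
    hook.get_job_status = mock.AsyncMock(return_value="success")
    hook.get_job_output = mock.AsyncMock(return_value={})
    hook.get_records = mock.Mock(return_value=[["20230101"]])

    trigger = BigQueryTablePartitionExistenceTrigger(
        dataset_id="my_dataset",
        table_id="my_table",
        project_id="my-project",
        partition_id="20230101",
        gcp_conn_id="google_cloud_default",
        hook_params={},
        poll_interval=0.1,
    )
    # The first loop iteration submits the job and sleeps; the second sees
    # "success", finds the partition id among the records, and yields the event.
    event = asyncio.run(first_event(trigger))
    print(event.payload)  # {'status': 'success', 'message': 'Partition: 20230101 in table: my_table'}
```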

docs/apache-airflow-providers-google/operators/cloud/bigquery.rst (11 additions, 0 deletions)

```diff
@@ -509,6 +509,17 @@ To check that a table exists and has a partition you can use.
 
 For DAY partitioned tables, the partition_id parameter is a string on the "%Y%m%d" format
 
+Use the :class:`~airflow.providers.google.cloud.sensors.bigquery.BigQueryTableExistencePartitionAsyncSensor`
+(deferrable version) if you would like to free up the worker slots while the sensor is running.
+
+:class:`~airflow.providers.google.cloud.sensors.bigquery.BigQueryTableExistencePartitionAsyncSensor`.
+
+.. exampleinclude:: /../../tests/system/providers/google/cloud/bigquery/example_bigquery_sensors.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_sensor_bigquery_table_partition_async]
+    :end-before: [END howto_sensor_bigquery_table_partition_async]
+
 Reference
 ^^^^^^^^^
```

tests/providers/google/cloud/sensors/test_bigquery.py (65 additions, 1 deletion)

```diff
@@ -23,10 +23,14 @@
 from airflow.exceptions import AirflowException, TaskDeferred
 from airflow.providers.google.cloud.sensors.bigquery import (
     BigQueryTableExistenceAsyncSensor,
+    BigQueryTableExistencePartitionAsyncSensor,
     BigQueryTableExistenceSensor,
     BigQueryTablePartitionExistenceSensor,
 )
-from airflow.providers.google.cloud.triggers.bigquery import BigQueryTableExistenceTrigger
+from airflow.providers.google.cloud.triggers.bigquery import (
+    BigQueryTableExistenceTrigger,
+    BigQueryTablePartitionExistenceTrigger,
+)
 
 TEST_PROJECT_ID = "test_project"
 TEST_DATASET_ID = "test_dataset"
@@ -156,3 +160,63 @@ def test_big_query_sensor_async_execute_complete_event_none(self):
         )
         with pytest.raises(AirflowException):
             task.execute_complete(context={}, event=None)
+
+
+class TestBigQueryTableExistencePartitionAsyncSensor(TestCase):
+    def test_big_query_table_existence_partition_sensor_async(self):
+        """
+        Asserts that a task is deferred and a BigQueryTablePartitionExistenceTrigger will be fired
+        when the BigQueryTableExistencePartitionAsyncSensor is executed.
+        """
+        task = BigQueryTableExistencePartitionAsyncSensor(
+            task_id="test_task_id",
+            project_id=TEST_PROJECT_ID,
+            dataset_id=TEST_DATASET_ID,
+            table_id=TEST_TABLE_ID,
+            partition_id=TEST_PARTITION_ID,
+        )
+        with pytest.raises(TaskDeferred) as exc:
+            task.execute(context={})
+        assert isinstance(
+            exc.value.trigger, BigQueryTablePartitionExistenceTrigger
+        ), "Trigger is not a BigQueryTablePartitionExistenceTrigger"
+
+    def test_big_query_table_existence_partition_sensor_async_execute_failure(self):
+        """Tests that an AirflowException is raised in case of an error event."""
+        task = BigQueryTableExistencePartitionAsyncSensor(
+            task_id="test_task_id",
+            project_id=TEST_PROJECT_ID,
+            dataset_id=TEST_DATASET_ID,
+            table_id=TEST_TABLE_ID,
+            partition_id=TEST_PARTITION_ID,
+        )
+        with pytest.raises(AirflowException):
+            task.execute_complete(context={}, event={"status": "error", "message": "test failure message"})
+
+    def test_big_query_table_existence_partition_sensor_async_execute_complete_event_none(self):
+        """Asserts that an AirflowException is raised when no event is received."""
+        task = BigQueryTableExistencePartitionAsyncSensor(
+            task_id="task-id",
+            project_id=TEST_PROJECT_ID,
+            dataset_id=TEST_DATASET_ID,
+            table_id=TEST_TABLE_ID,
+            partition_id=TEST_PARTITION_ID,
+        )
+        with pytest.raises(AirflowException, match="No event received in trigger callback"):
+            task.execute_complete(context={}, event=None)
+
+    def test_big_query_table_existence_partition_sensor_async_execute_complete(self):
+        """Asserts that logging occurs as expected on a success event."""
+        task = BigQueryTableExistencePartitionAsyncSensor(
+            task_id="task-id",
+            project_id=TEST_PROJECT_ID,
+            dataset_id=TEST_DATASET_ID,
+            table_id=TEST_TABLE_ID,
+            partition_id=TEST_PARTITION_ID,
+        )
+        table_uri = f"{TEST_PROJECT_ID}:{TEST_DATASET_ID}.{TEST_TABLE_ID}"
+        with mock.patch.object(task.log, "info") as mock_log_info:
+            task.execute_complete(context={}, event={"status": "success", "message": "test"})
+        mock_log_info.assert_called_with(
+            'Sensor checks existence of partition: "%s" in table: %s', TEST_PARTITION_ID, table_uri
+        )
```

tests/providers/google/cloud/triggers/test_bigquery.py (35 additions, 0 deletions)

```diff
@@ -33,6 +33,7 @@
     BigQueryInsertJobTrigger,
     BigQueryIntervalCheckTrigger,
     BigQueryTableExistenceTrigger,
+    BigQueryTablePartitionExistenceTrigger,
     BigQueryValueCheckTrigger,
 )
 from airflow.triggers.base import TriggerEvent
@@ -59,6 +60,7 @@
 TEST_IGNORE_ZERO = True
 TEST_GCP_CONN_ID = "TEST_GCP_CONN_ID"
 TEST_HOOK_PARAMS: dict[str, Any] = {}
+TEST_PARTITION_ID = "1234"
 
 
 def test_bigquery_insert_job_op_trigger_serialization():
@@ -1043,3 +1045,36 @@ async def test_table_exists_raise_exception(mock_get_table_client):
     )
     with pytest.raises(ClientResponseError):
         await trigger._table_exists(hook, TEST_DATASET_ID, TEST_TABLE_ID, TEST_GCP_PROJECT_ID)
+
+
+class TestBigQueryTablePartitionExistenceTrigger:
+    def test_big_query_table_existence_partition_trigger_serialization_should_execute_successfully(self):
+        """
+        Asserts that the BigQueryTablePartitionExistenceTrigger correctly serializes its arguments
+        and classpath.
+        """
+
+        trigger = BigQueryTablePartitionExistenceTrigger(
+            dataset_id=TEST_DATASET_ID,
+            table_id=TEST_TABLE_ID,
+            project_id=TEST_GCP_PROJECT_ID,
+            partition_id=TEST_PARTITION_ID,
+            poll_interval=POLLING_PERIOD_SECONDS,
+            gcp_conn_id=TEST_GCP_CONN_ID,
+            hook_params={},
+        )
+
+        classpath, kwargs = trigger.serialize()
+        assert (
+            classpath
+            == "airflow.providers.google.cloud.triggers.bigquery.BigQueryTablePartitionExistenceTrigger"
+        )
+        assert kwargs == {
+            "dataset_id": TEST_DATASET_ID,
+            "project_id": TEST_GCP_PROJECT_ID,
+            "table_id": TEST_TABLE_ID,
+            "partition_id": TEST_PARTITION_ID,
+            "gcp_conn_id": TEST_GCP_CONN_ID,
+            "poll_interval": POLLING_PERIOD_SECONDS,
+            "hook_params": TEST_HOOK_PARAMS,
+        }
```
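The serialized classpath-plus-kwargs pair is what the triggerer uses to recreate a deferred task's trigger. A rough sketch of that round trip, assuming `import_string` from `airflow.utils.module_loading` (the values mirror the test constants above):

```python
from airflow.providers.google.cloud.triggers.bigquery import (
    BigQueryTablePartitionExistenceTrigger,
)
from airflow.utils.module_loading import import_string

trigger = BigQueryTablePartitionExistenceTrigger(
    dataset_id="test_dataset",
    table_id="test_table",
    project_id="test-gcp-project",
    partition_id="1234",
    poll_interval=4.0,
    gcp_conn_id="TEST_GCP_CONN_ID",
    hook_params={},
)

# Rebuild the trigger roughly the way the triggerer would after deserialization.
classpath, kwargs = trigger.serialize()
rehydrated = import_string(classpath)(**kwargs)
assert isinstance(rehydrated, BigQueryTablePartitionExistenceTrigger)
assert rehydrated.partition_id == "1234"
```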

tests/system/providers/google/cloud/bigquery/example_bigquery_sensors.py (11 additions, 0 deletions)

```diff
@@ -33,6 +33,7 @@
 )
 from airflow.providers.google.cloud.sensors.bigquery import (
     BigQueryTableExistenceAsyncSensor,
+    BigQueryTableExistencePartitionAsyncSensor,
     BigQueryTableExistenceSensor,
     BigQueryTablePartitionExistenceSensor,
 )
@@ -117,6 +118,16 @@
     )
     # [END howto_sensor_bigquery_table_partition]
 
+    # [START howto_sensor_bigquery_table_partition_async]
+    check_table_partition_exists_async: BaseSensorOperator = BigQueryTableExistencePartitionAsyncSensor(
+        task_id="check_table_partition_exists_async",
+        partition_id=PARTITION_NAME,
+        project_id=PROJECT_ID,
+        dataset_id=DATASET_NAME,
+        table_id=TABLE_NAME,
+    )
+    # [END howto_sensor_bigquery_table_partition_async]
+
     delete_dataset = BigQueryDeleteDatasetOperator(
         task_id="delete_dataset",
         dataset_id=DATASET_NAME,
```
