Skip to content

Commit b0a40bb

Browse files
authored
Optimize deferred execution mode (#30946)
1 parent f3e82b2 commit b0a40bb

File tree

2 files changed

+84
-36
lines changed

2 files changed

+84
-36
lines changed

β€Žairflow/providers/google/cloud/sensors/bigquery.py

Lines changed: 34 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -110,20 +110,24 @@ def poke(self, context: Context) -> bool:
110110

111111
def execute(self, context: Context) -> None:
112112
"""Airflow runs this method on the worker and defers using the trigger."""
113-
self.defer(
114-
timeout=timedelta(seconds=self.timeout),
115-
trigger=BigQueryTableExistenceTrigger(
116-
dataset_id=self.dataset_id,
117-
table_id=self.table_id,
118-
project_id=self.project_id,
119-
poll_interval=self.poke_interval,
120-
gcp_conn_id=self.gcp_conn_id,
121-
hook_params={
122-
"impersonation_chain": self.impersonation_chain,
123-
},
124-
),
125-
method_name="execute_complete",
126-
)
113+
if not self.deferrable:
114+
super().execute(context)
115+
else:
116+
if not self.poke(context=context):
117+
self.defer(
118+
timeout=timedelta(seconds=self.timeout),
119+
trigger=BigQueryTableExistenceTrigger(
120+
dataset_id=self.dataset_id,
121+
table_id=self.table_id,
122+
project_id=self.project_id,
123+
poll_interval=self.poke_interval,
124+
gcp_conn_id=self.gcp_conn_id,
125+
hook_params={
126+
"impersonation_chain": self.impersonation_chain,
127+
},
128+
),
129+
method_name="execute_complete",
130+
)
127131

128132
def execute_complete(self, context: dict[str, Any], event: dict[str, str] | None = None) -> str:
129133
"""
@@ -218,21 +222,22 @@ def execute(self, context: Context) -> None:
218222
if not self.deferrable:
219223
super().execute(context)
220224
else:
221-
self.defer(
222-
timeout=timedelta(seconds=self.timeout),
223-
trigger=BigQueryTablePartitionExistenceTrigger(
224-
dataset_id=self.dataset_id,
225-
table_id=self.table_id,
226-
project_id=self.project_id,
227-
partition_id=self.partition_id,
228-
poll_interval=self.poke_interval,
229-
gcp_conn_id=self.gcp_conn_id,
230-
hook_params={
231-
"impersonation_chain": self.impersonation_chain,
232-
},
233-
),
234-
method_name="execute_complete",
235-
)
225+
if not self.poke(context=context):
226+
self.defer(
227+
timeout=timedelta(seconds=self.timeout),
228+
trigger=BigQueryTablePartitionExistenceTrigger(
229+
dataset_id=self.dataset_id,
230+
table_id=self.table_id,
231+
project_id=self.project_id,
232+
partition_id=self.partition_id,
233+
poll_interval=self.poke_interval,
234+
gcp_conn_id=self.gcp_conn_id,
235+
hook_params={
236+
"impersonation_chain": self.impersonation_chain,
237+
},
238+
),
239+
method_name="execute_complete",
240+
)
236241

237242
def execute_complete(self, context: dict[str, Any], event: dict[str, str] | None = None) -> str:
238243
"""

β€Žtests/providers/google/cloud/sensors/test_bigquery.py

Lines changed: 50 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,24 @@ def test_passing_arguments_to_hook(self, mock_hook):
6464
project_id=TEST_PROJECT_ID, dataset_id=TEST_DATASET_ID, table_id=TEST_TABLE_ID
6565
)
6666

67-
def test_execute_defered(self):
67+
@mock.patch("airflow.providers.google.cloud.sensors.bigquery.BigQueryHook")
68+
@mock.patch("airflow.providers.google.cloud.sensors.bigquery.BigQueryTableExistenceSensor.defer")
69+
def test_table_existence_sensor_finish_before_deferred(self, mock_defer, mock_hook):
70+
task = BigQueryTableExistenceSensor(
71+
task_id="task-id",
72+
project_id=TEST_PROJECT_ID,
73+
dataset_id=TEST_DATASET_ID,
74+
table_id=TEST_TABLE_ID,
75+
gcp_conn_id=TEST_GCP_CONN_ID,
76+
impersonation_chain=TEST_IMPERSONATION_CHAIN,
77+
deferrable=True,
78+
)
79+
mock_hook.return_value.table_exists.return_value = True
80+
task.execute(mock.MagicMock())
81+
assert not mock_defer.called
82+
83+
@mock.patch("airflow.providers.google.cloud.sensors.bigquery.BigQueryHook")
84+
def test_execute_deferred(self, mock_hook):
6885
"""
6986
Asserts that a task is deferred and a BigQueryTableExistenceTrigger will be fired
7087
when the BigQueryTableExistenceAsyncSensor is executed.
@@ -76,13 +93,14 @@ def test_execute_defered(self):
7693
table_id=TEST_TABLE_ID,
7794
deferrable=True,
7895
)
96+
mock_hook.return_value.table_exists.return_value = False
7997
with pytest.raises(TaskDeferred) as exc:
80-
task.execute(context={})
98+
task.execute(mock.MagicMock())
8199
assert isinstance(
82100
exc.value.trigger, BigQueryTableExistenceTrigger
83101
), "Trigger is not a BigQueryTableExistenceTrigger"
84102

85-
def test_excute_defered_failure(self):
103+
def test_execute_deferred_failure(self):
86104
"""Tests that an AirflowException is raised in case of error event"""
87105
task = BigQueryTableExistenceSensor(
88106
task_id="task-id",
@@ -148,7 +166,9 @@ def test_passing_arguments_to_hook(self, mock_hook):
148166
partition_id=TEST_PARTITION_ID,
149167
)
150168

151-
def test_execute_with_deferrable_mode(self):
169+
@mock.patch("airflow.providers.google.cloud.sensors.bigquery.BigQueryHook")
170+
@mock.patch("airflow.providers.google.cloud.sensors.bigquery.BigQueryTablePartitionExistenceSensor.defer")
171+
def test_table_partition_existence_sensor_finish_before_deferred(self, mock_defer, mock_hook):
152172
"""
153173
Asserts that a task is deferred and a BigQueryTablePartitionExistenceTrigger will be fired
154174
when the BigQueryTablePartitionExistenceSensor is executed and deferrable is set to True.
@@ -161,6 +181,25 @@ def test_execute_with_deferrable_mode(self):
161181
partition_id=TEST_PARTITION_ID,
162182
deferrable=True,
163183
)
184+
mock_hook.return_value.table_partition_exists.return_value = True
185+
task.execute(mock.MagicMock())
186+
assert not mock_defer.called
187+
188+
@mock.patch("airflow.providers.google.cloud.sensors.bigquery.BigQueryHook")
189+
def test_execute_with_deferrable_mode(self, mock_hook):
190+
"""
191+
Asserts that a task is deferred and a BigQueryTablePartitionExistenceTrigger will be fired
192+
when the BigQueryTablePartitionExistenceSensor is executed and deferrable is set to True.
193+
"""
194+
task = BigQueryTablePartitionExistenceSensor(
195+
task_id="test_task_id",
196+
project_id=TEST_PROJECT_ID,
197+
dataset_id=TEST_DATASET_ID,
198+
table_id=TEST_TABLE_ID,
199+
partition_id=TEST_PARTITION_ID,
200+
deferrable=True,
201+
)
202+
mock_hook.return_value.table_partition_exists.return_value = False
164203
with pytest.raises(TaskDeferred) as exc:
165204
task.execute(context={})
166205
assert isinstance(
@@ -228,7 +267,8 @@ class TestBigQueryTableExistenceAsyncSensor:
228267
"set `deferrable` attribute to `True` instead"
229268
)
230269

231-
def test_big_query_table_existence_sensor_async(self):
270+
@mock.patch("airflow.providers.google.cloud.sensors.bigquery.BigQueryHook")
271+
def test_big_query_table_existence_sensor_async(self, mock_hook):
232272
"""
233273
Asserts that a task is deferred and a BigQueryTableExistenceTrigger will be fired
234274
when the BigQueryTableExistenceAsyncSensor is executed.
@@ -240,6 +280,7 @@ def test_big_query_table_existence_sensor_async(self):
240280
dataset_id=TEST_DATASET_ID,
241281
table_id=TEST_TABLE_ID,
242282
)
283+
mock_hook.return_value.table_exists.return_value = False
243284
with pytest.raises(TaskDeferred) as exc:
244285
task.execute(context={})
245286
assert isinstance(
@@ -293,7 +334,8 @@ class TestBigQueryTableExistencePartitionAsyncSensor:
293334
"set `deferrable` attribute to `True` instead"
294335
)
295336

296-
def test_big_query_table_existence_partition_sensor_async(self):
337+
@mock.patch("airflow.providers.google.cloud.sensors.bigquery.BigQueryHook")
338+
def test_big_query_table_existence_partition_sensor_async(self, mock_hook):
297339
"""
298340
Asserts that a task is deferred and a BigQueryTablePartitionExistenceTrigger will be fired
299341
when the BigQueryTableExistencePartitionAsyncSensor is executed.
@@ -306,8 +348,9 @@ def test_big_query_table_existence_partition_sensor_async(self):
306348
table_id=TEST_TABLE_ID,
307349
partition_id=TEST_PARTITION_ID,
308350
)
351+
mock_hook.return_value.table_partition_exists.return_value = False
309352
with pytest.raises(TaskDeferred) as exc:
310-
task.execute(context={})
353+
task.execute(mock.MagicMock())
311354
assert isinstance(
312355
exc.value.trigger, BigQueryTablePartitionExistenceTrigger
313356
), "Trigger is not a BigQueryTablePartitionExistenceTrigger"

0 commit comments

Comments
 (0)