Skip to content

Commit d757f6a

Browse files
authored
Fix BigQueryValueCheckOperator deferrable mode optimisation (#34018)
PR #31872 tried to optimise the deferrable mode in BigQueryValueCheckOperator. However, when deciding whether to defer, it only checked the job status; it did not actually verify the query result against the pass value, and so returned success prematurely. This PR adds the missing logic to the optimisation: the result is now checked and compared against the pass value and tolerance. Closes: #34010
1 parent 6ef80e8 commit d757f6a

File tree

3 files changed

+13
-7
lines changed

3 files changed

+13
-7
lines changed

airflow/providers/common/sql/operators/sql.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -827,10 +827,7 @@ def __init__(
827827
self.tol = tol if isinstance(tol, float) else None
828828
self.has_tolerance = self.tol is not None
829829

830-
def execute(self, context: Context):
831-
self.log.info("Executing SQL check: %s", self.sql)
832-
records = self.get_db_hook().get_first(self.sql)
833-
830+
def check_value(self, records):
834831
if not records:
835832
self._raise_exception(f"The following query returned zero rows: {self.sql}")
836833

@@ -862,6 +859,11 @@ def execute(self, context: Context):
862859
if not all(tests):
863860
self._raise_exception(error_msg)
864861

862+
def execute(self, context: Context):
863+
self.log.info("Executing SQL check: %s", self.sql)
864+
records = self.get_db_hook().get_first(self.sql)
865+
self.check_value(records)
866+
865867
def _to_float(self, records):
866868
return [float(record) for record in records]
867869

airflow/providers/google/cloud/operators/bigquery.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -443,6 +443,10 @@ def execute(self, context: Context) -> None: # type: ignore[override]
443443
method_name="execute_complete",
444444
)
445445
self._handle_job_error(job)
446+
# job.result() returns a RowIterator. Mypy expects an instance of SupportsNext[Any] for
447+
# the next() call which the RowIterator does not resemble to. Hence, ignore the arg-type error.
448+
records = next(job.result()) # type: ignore[arg-type]
449+
self.check_value(records)
446450
self.log.info("Current state of job %s is %s", job.job_id, job.state)
447451

448452
@staticmethod

tests/providers/google/cloud/operators/test_bigquery.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1919,11 +1919,11 @@ def test_bigquery_value_check_async(self, mock_hook, create_task_instance_of_ope
19191919
exc.value.trigger, BigQueryValueCheckTrigger
19201920
), "Trigger is not a BigQueryValueCheckTrigger"
19211921

1922-
@mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryValueCheckOperator.execute")
19231922
@mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryValueCheckOperator.defer")
1923+
@mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryValueCheckOperator.check_value")
19241924
@mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook")
19251925
def test_bigquery_value_check_operator_async_finish_before_deferred(
1926-
self, mock_hook, mock_defer, mock_execute, create_task_instance_of_operator
1926+
self, mock_hook, mock_check_value, mock_defer, create_task_instance_of_operator
19271927
):
19281928
job_id = "123456"
19291929
hash_ = "hash"
@@ -1944,7 +1944,7 @@ def test_bigquery_value_check_operator_async_finish_before_deferred(
19441944

19451945
ti.task.execute(MagicMock())
19461946
assert not mock_defer.called
1947-
assert mock_execute.called
1947+
assert mock_check_value.called
19481948

19491949
@pytest.mark.parametrize(
19501950
"kwargs, expected",

0 commit comments

Comments
 (0)