Skip to content

Commit e10aa6a

Browse files
authored
openlineage, bigquery: add openlineage method support for BigQueryExecuteQueryOperator (#31293)
Signed-off-by: Maciej Obuchowski <obuchowski.maciej@gmail.com>
1 parent af08392 commit e10aa6a

File tree

6 files changed

+428
-12
lines changed

6 files changed

+428
-12
lines changed

β€Žairflow/providers/google/cloud/hooks/bigquery.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2245,7 +2245,7 @@ def run_query(
22452245
self.running_job_id = job.job_id
22462246
return job.job_id
22472247

2248-
def generate_job_id(self, job_id, dag_id, task_id, logical_date, configuration, force_rerun=False):
2248+
def generate_job_id(self, job_id, dag_id, task_id, logical_date, configuration, force_rerun=False) -> str:
22492249
if force_rerun:
22502250
hash_base = str(uuid.uuid4())
22512251
else:

β€Žairflow/providers/google/cloud/operators/bigquery.py

Lines changed: 82 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,68 @@ def get_db_hook(self: BigQueryCheckOperator) -> BigQueryHook: # type:ignore[mis
133133
)
134134

135135

136+
class _BigQueryOpenLineageMixin:
137+
def get_openlineage_facets_on_complete(self, task_instance):
138+
"""
139+
Retrieve OpenLineage data for a COMPLETE BigQuery job.
140+
141+
This method retrieves statistics for the specified job_ids using the BigQueryDatasetsProvider.
142+
It calls BigQuery API, retrieving input and output dataset info from it, as well as run-level
143+
usage statistics.
144+
145+
Run facets should contain:
146+
- ExternalQueryRunFacet
147+
- BigQueryJobRunFacet
148+
149+
Job facets should contain:
150+
- SqlJobFacet if operator has self.sql
151+
152+
Input datasets should contain facets:
153+
- DataSourceDatasetFacet
154+
- SchemaDatasetFacet
155+
156+
Output datasets should contain facets:
157+
- DataSourceDatasetFacet
158+
- SchemaDatasetFacet
159+
- OutputStatisticsOutputDatasetFacet
160+
"""
161+
from openlineage.client.facet import SqlJobFacet
162+
from openlineage.common.provider.bigquery import BigQueryDatasetsProvider
163+
164+
from airflow.providers.openlineage.extractors import OperatorLineage
165+
from airflow.providers.openlineage.utils.utils import normalize_sql
166+
167+
if not self.job_id:
168+
return OperatorLineage()
169+
170+
client = self.hook.get_client(project_id=self.hook.project_id)
171+
job_ids = self.job_id
172+
if isinstance(self.job_id, str):
173+
job_ids = [self.job_id]
174+
inputs, outputs, run_facets = {}, {}, {}
175+
for job_id in job_ids:
176+
stats = BigQueryDatasetsProvider(client=client).get_facets(job_id=job_id)
177+
for input in stats.inputs:
178+
input = input.to_openlineage_dataset()
179+
inputs[input.name] = input
180+
if stats.output:
181+
output = stats.output.to_openlineage_dataset()
182+
outputs[output.name] = output
183+
for key, value in stats.run_facets.items():
184+
run_facets[key] = value
185+
186+
job_facets = {}
187+
if hasattr(self, "sql"):
188+
job_facets["sql"] = SqlJobFacet(query=normalize_sql(self.sql))
189+
190+
return OperatorLineage(
191+
inputs=list(inputs.values()),
192+
outputs=list(outputs.values()),
193+
run_facets=run_facets,
194+
job_facets=job_facets,
195+
)
196+
197+
136198
class BigQueryCheckOperator(_BigQueryDbHookMixin, SQLCheckOperator):
137199
"""Performs checks against BigQuery.
138200
@@ -1153,6 +1215,7 @@ def __init__(
11531215
self.encryption_configuration = encryption_configuration
11541216
self.hook: BigQueryHook | None = None
11551217
self.impersonation_chain = impersonation_chain
1218+
self.job_id: str | list[str] | None = None
11561219

11571220
def execute(self, context: Context):
11581221
if self.hook is None:
@@ -1164,7 +1227,7 @@ def execute(self, context: Context):
11641227
impersonation_chain=self.impersonation_chain,
11651228
)
11661229
if isinstance(self.sql, str):
1167-
job_id: str | list[str] = self.hook.run_query(
1230+
self.job_id = self.hook.run_query(
11681231
sql=self.sql,
11691232
destination_dataset_table=self.destination_dataset_table,
11701233
write_disposition=self.write_disposition,
@@ -1184,7 +1247,7 @@ def execute(self, context: Context):
11841247
encryption_configuration=self.encryption_configuration,
11851248
)
11861249
elif isinstance(self.sql, Iterable):
1187-
job_id = [
1250+
self.job_id = [
11881251
self.hook.run_query(
11891252
sql=s,
11901253
destination_dataset_table=self.destination_dataset_table,
@@ -1210,9 +1273,9 @@ def execute(self, context: Context):
12101273
raise AirflowException(f"argument 'sql' of type {type(str)} is neither a string nor an iterable")
12111274
project_id = self.hook.project_id
12121275
if project_id:
1213-
job_id_path = convert_job_id(job_id=job_id, project_id=project_id, location=self.location)
1276+
job_id_path = convert_job_id(job_id=self.job_id, project_id=project_id, location=self.location)
12141277
context["task_instance"].xcom_push(key="job_id_path", value=job_id_path)
1215-
return job_id
1278+
return self.job_id
12161279

12171280
def on_kill(self) -> None:
12181281
super().on_kill()
@@ -2562,7 +2625,7 @@ def execute(self, context: Context):
25622625
return table
25632626

25642627

2565-
class BigQueryInsertJobOperator(GoogleCloudBaseOperator):
2628+
class BigQueryInsertJobOperator(GoogleCloudBaseOperator, _BigQueryOpenLineageMixin):
25662629
"""Execute a BigQuery job.
25672630
25682631
Waits for the job to complete and returns job id.
@@ -2663,6 +2726,13 @@ def __init__(
26632726
self.deferrable = deferrable
26642727
self.poll_interval = poll_interval
26652728

2729+
@property
2730+
def sql(self) -> str | None:
2731+
try:
2732+
return self.configuration["query"]["query"]
2733+
except KeyError:
2734+
return None
2735+
26662736
def prepare_template(self) -> None:
26672737
# If .json is passed then we have to read the file
26682738
if isinstance(self.configuration, str) and self.configuration.endswith(".json"):
@@ -2697,7 +2767,7 @@ def execute(self, context: Any):
26972767
)
26982768
self.hook = hook
26992769

2700-
job_id = hook.generate_job_id(
2770+
self.job_id = hook.generate_job_id(
27012771
job_id=self.job_id,
27022772
dag_id=self.dag_id,
27032773
task_id=self.task_id,
@@ -2708,13 +2778,13 @@ def execute(self, context: Any):
27082778

27092779
try:
27102780
self.log.info("Executing: %s'", self.configuration)
2711-
job: BigQueryJob | UnknownJob = self._submit_job(hook, job_id)
2781+
job: BigQueryJob | UnknownJob = self._submit_job(hook, self.job_id)
27122782
except Conflict:
27132783
# If the job already exists retrieve it
27142784
job = hook.get_job(
27152785
project_id=self.project_id,
27162786
location=self.location,
2717-
job_id=job_id,
2787+
job_id=self.job_id,
27182788
)
27192789
if job.state in self.reattach_states:
27202790
# We are reattaching to a job
@@ -2723,7 +2793,7 @@ def execute(self, context: Any):
27232793
else:
27242794
# Same job configuration so we need force_rerun
27252795
raise AirflowException(
2726-
f"Job with id: {job_id} already exists and is in {job.state} state. If you "
2796+
f"Job with id: {self.job_id} already exists and is in {job.state} state. If you "
27272797
f"want to force rerun it consider setting `force_rerun=True`."
27282798
f"Or, if you want to reattach in this scenario add {job.state} to `reattach_states`"
27292799
)
@@ -2757,7 +2827,9 @@ def execute(self, context: Any):
27572827
self.job_id = job.job_id
27582828
project_id = self.project_id or self.hook.project_id
27592829
if project_id:
2760-
job_id_path = convert_job_id(job_id=job_id, project_id=project_id, location=self.location)
2830+
job_id_path = convert_job_id(
2831+
job_id=self.job_id, project_id=project_id, location=self.location # type: ignore[arg-type]
2832+
)
27612833
context["ti"].xcom_push(key="job_id_path", value=job_id_path)
27622834
# Wait for the job to complete
27632835
if not self.deferrable:

β€Žairflow/providers/openlineage/extractors/base.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,12 @@ def extract(self) -> OperatorLineage | None:
8686
# OpenLineage methods are optional - if there's no method, return None
8787
try:
8888
return self._get_openlineage_facets(self.operator.get_openlineage_facets_on_start) # type: ignore
89+
except ImportError:
90+
self.log.error(
91+
"OpenLineage provider method failed to import OpenLineage integration. "
92+
"This should not happen. Please report this bug to developers."
93+
)
94+
return None
8995
except AttributeError:
9096
return None
9197

β€Žairflow/providers/openlineage/utils/utils.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
import os
2424
from contextlib import suppress
2525
from functools import wraps
26-
from typing import TYPE_CHECKING, Any
26+
from typing import TYPE_CHECKING, Any, Iterable
2727
from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse
2828

2929
import attrs
@@ -414,3 +414,10 @@ def is_source_enabled() -> bool:
414414
def get_filtered_unknown_operator_keys(operator: BaseOperator) -> dict:
415415
not_required_keys = {"dag", "task_group"}
416416
return {attr: value for attr, value in operator.__dict__.items() if attr not in not_required_keys}
417+
418+
419+
def normalize_sql(sql: str | Iterable[str]):
420+
if isinstance(sql, str):
421+
sql = [stmt for stmt in sql.split(";") if stmt != ""]
422+
sql = [obj for stmt in sql for obj in stmt.split(";") if obj != ""]
423+
return ";\n".join(sql)

0 commit comments

Comments
 (0)