@@ -133,6 +133,68 @@ def get_db_hook(self: BigQueryCheckOperator) -> BigQueryHook:  # type:ignore[misc]
         )
 
 
+class _BigQueryOpenLineageMixin:
+    def get_openlineage_facets_on_complete(self, task_instance):
+        """
+        Retrieve OpenLineage data for a COMPLETE BigQuery job.
+
+        This method retrieves statistics for the specified job_ids using the BigQueryDatasetsProvider.
+        It calls the BigQuery API, retrieving input and output dataset info from it, as well as run-level
+        usage statistics.
+
+        Run facets should contain:
+            - ExternalQueryRunFacet
+            - BigQueryJobRunFacet
+
+        Job facets should contain:
+            - SqlJobFacet if operator has self.sql
+
+        Input datasets should contain facets:
+            - DataSourceDatasetFacet
+            - SchemaDatasetFacet
+
+        Output datasets should contain facets:
+            - DataSourceDatasetFacet
+            - SchemaDatasetFacet
+            - OutputStatisticsOutputDatasetFacet
+        """
+        from openlineage.client.facet import SqlJobFacet
+        from openlineage.common.provider.bigquery import BigQueryDatasetsProvider
+
+        from airflow.providers.openlineage.extractors import OperatorLineage
+        from airflow.providers.openlineage.utils.utils import normalize_sql
+
+        if not self.job_id:
+            return OperatorLineage()
+
+        client = self.hook.get_client(project_id=self.hook.project_id)
+        job_ids = self.job_id
+        if isinstance(self.job_id, str):
+            job_ids = [self.job_id]
+        inputs, outputs, run_facets = {}, {}, {}
+        for job_id in job_ids:
+            stats = BigQueryDatasetsProvider(client=client).get_facets(job_id=job_id)
+            for input in stats.inputs:
+                input = input.to_openlineage_dataset()
+                inputs[input.name] = input
+            if stats.output:
+                output = stats.output.to_openlineage_dataset()
+                outputs[output.name] = output
+            for key, value in stats.run_facets.items():
+                run_facets[key] = value
+
+        job_facets = {}
+        if hasattr(self, "sql"):
+            job_facets["sql"] = SqlJobFacet(query=normalize_sql(self.sql))
+
+        return OperatorLineage(
+            inputs=list(inputs.values()),
+            outputs=list(outputs.values()),
+            run_facets=run_facets,
+            job_facets=job_facets,
+        )
+
+
 class BigQueryCheckOperator(_BigQueryDbHookMixin, SQLCheckOperator):
     """Performs checks against BigQuery.
 
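A quick way to sanity-check the mixin without touching BigQuery is to stub the provider it imports at call time. A minimal test sketch, assuming `apache-airflow-providers-openlineage` and `openlineage-integration-common` are installed; the job IDs, SQL, and facet shapes are made up for illustration, and `BigQueryInsertJobOperator` only gains the mixin in a later hunk of this patch:

```python
# Test sketch (hypothetical IDs and SQL): stubs BigQueryDatasetsProvider so no
# BigQuery call happens; the function-level import inside the mixin resolves
# the patched attribute at call time.
from unittest import mock

from airflow.providers.google.cloud.operators.bigquery import BigQueryInsertJobOperator


def test_collects_facets_per_job_id():
    op = BigQueryInsertJobOperator(
        task_id="insert_job",
        configuration={"query": {"query": "SELECT 1", "useLegacySql": False}},
    )
    # A list job_id is artificial for this operator but exercises the mixin's
    # multi-job branch (the shape BigQueryExecuteQueryOperator produces).
    op.job_id = ["job-1", "job-2"]
    op.hook = mock.MagicMock()  # stand-in for BigQueryHook; only get_client() is touched

    stats = mock.MagicMock(inputs=[], output=None, run_facets={})
    with mock.patch("openlineage.common.provider.bigquery.BigQueryDatasetsProvider") as provider_cls:
        provider_cls.return_value.get_facets.return_value = stats
        lineage = op.get_openlineage_facets_on_complete(task_instance=None)

    assert provider_cls.return_value.get_facets.call_count == 2  # one lookup per job_id
    assert lineage.job_facets["sql"].query == "SELECT 1"
```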
@@ -1153,6 +1215,7 @@ def __init__(
         self.encryption_configuration = encryption_configuration
         self.hook: BigQueryHook | None = None
         self.impersonation_chain = impersonation_chain
+        self.job_id: str | list[str] | None = None
 
     def execute(self, context: Context):
         if self.hook is None:
@@ -1164,7 +1227,7 @@ def execute(self, context: Context):
                 impersonation_chain=self.impersonation_chain,
             )
         if isinstance(self.sql, str):
-            job_id: str | list[str] = self.hook.run_query(
+            self.job_id = self.hook.run_query(
                 sql=self.sql,
                 destination_dataset_table=self.destination_dataset_table,
                 write_disposition=self.write_disposition,
@@ -1184,7 +1247,7 @@ def execute(self, context: Context):
                 encryption_configuration=self.encryption_configuration,
             )
         elif isinstance(self.sql, Iterable):
-            job_id = [
+            self.job_id = [
                 self.hook.run_query(
                     sql=s,
                     destination_dataset_table=self.destination_dataset_table,
@@ -1210,9 +1273,9 @@ def execute(self, context: Context):
             raise AirflowException(f"argument 'sql' of type {type(self.sql)} is neither a string nor an iterable")
         project_id = self.hook.project_id
         if project_id:
-            job_id_path = convert_job_id(job_id=job_id, project_id=project_id, location=self.location)
+            job_id_path = convert_job_id(job_id=self.job_id, project_id=project_id, location=self.location)
             context["task_instance"].xcom_push(key="job_id_path", value=job_id_path)
-        return job_id
+        return self.job_id
 
     def on_kill(self) -> None:
         super().on_kill()
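For reference, `self.job_id` now holds either a single job ID (one SQL string) or a list of IDs (an iterable of statements), and `convert_job_id` fans out accordingly. A simplified stand-in for the helper, assuming the `project:location:job_id` path format with `US` as the default location (the real implementation lives in `airflow.providers.google.cloud.utils.bigquery`):

```python
# Simplified stand-in for convert_job_id, shown only to illustrate the two
# shapes self.job_id can take; the path format and default location are
# assumptions, not a copy of the real helper.
def convert_job_id(job_id, project_id, location):
    location = location or "US"
    if isinstance(job_id, list):
        return [f"{project_id}:{location}:{j}" for j in job_id]
    return f"{project_id}:{location}:{job_id}"

print(convert_job_id("job-1", "my-project", None))             # 'my-project:US:job-1'
print(convert_job_id(["job-1", "job-2"], "my-project", "EU"))  # one path per job
```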
@@ -2562,7 +2625,7 @@ def execute(self, context: Context):
         return table
 
 
-class BigQueryInsertJobOperator(GoogleCloudBaseOperator):
+class BigQueryInsertJobOperator(GoogleCloudBaseOperator, _BigQueryOpenLineageMixin):
     """Execute a BigQuery job.
 
     Waits for the job to complete and returns job id.
@@ -2663,6 +2726,13 @@ def __init__(
         self.deferrable = deferrable
         self.poll_interval = poll_interval
 
+    @property
+    def sql(self) -> str | None:
+        try:
+            return self.configuration["query"]["query"]
+        except KeyError:
+            return None
+
     def prepare_template(self) -> None:
         # If .json is passed then we have to read the file
         if isinstance(self.configuration, str) and self.configuration.endswith(".json"):
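This property is what lets the mixin's `hasattr(self, "sql")` check attach a `SqlJobFacet` for this operator. A minimal sketch of its behaviour against two hypothetical job configurations (the dict shapes follow the BigQuery jobs API):

```python
# Hypothetical configurations: query jobs expose their SQL, other job types
# (e.g. a load job) yield None. extract_sql mirrors the property body above.
def extract_sql(configuration):
    try:
        return configuration["query"]["query"]
    except KeyError:
        return None

assert extract_sql({"query": {"query": "SELECT 1", "useLegacySql": False}}) == "SELECT 1"
assert extract_sql({"load": {"sourceUris": ["gs://bucket/data.csv"]}}) is None
```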
@@ -2697,7 +2767,7 @@ def execute(self, context: Any):
         )
         self.hook = hook
 
-        job_id = hook.generate_job_id(
+        self.job_id = hook.generate_job_id(
             job_id=self.job_id,
             dag_id=self.dag_id,
             task_id=self.task_id,
@@ -2708,13 +2778,13 @@ def execute(self, context: Any):
 
         try:
             self.log.info("Executing: %s", self.configuration)
-            job: BigQueryJob | UnknownJob = self._submit_job(hook, job_id)
+            job: BigQueryJob | UnknownJob = self._submit_job(hook, self.job_id)
         except Conflict:
             # If the job already exists retrieve it
             job = hook.get_job(
                 project_id=self.project_id,
                 location=self.location,
-                job_id=job_id,
+                job_id=self.job_id,
             )
             if job.state in self.reattach_states:
                 # We are reattaching to a job
@@ -2723,7 +2793,7 @@ def execute(self, context: Any):
             else:
                 # Same job configuration so we need force_rerun
                 raise AirflowException(
-                    f"Job with id: {job_id} already exists and is in {job.state} state. If you "
+                    f"Job with id: {self.job_id} already exists and is in {job.state} state. If you "
                     f"want to force rerun it consider setting `force_rerun=True`. "
                     f"Or, if you want to reattach in this scenario add {job.state} to `reattach_states`"
                 )
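For context on the error message above, a hedged usage sketch (hypothetical task IDs and SQL) of the two ways around the `Conflict`: reattaching via `reattach_states`, or forcing a fresh job ID with `force_rerun`:

```python
from airflow.providers.google.cloud.operators.bigquery import BigQueryInsertJobOperator

# Reattach to an existing job in one of these states instead of raising...
op_reattach = BigQueryInsertJobOperator(
    task_id="insert_job",
    configuration={"query": {"query": "SELECT 1", "useLegacySql": False}},
    reattach_states={"PENDING", "RUNNING"},
)
# ...or have generate_job_id() produce a unique ID so a fresh job is always submitted.
op_rerun = BigQueryInsertJobOperator(
    task_id="insert_job_rerun",
    configuration={"query": {"query": "SELECT 1", "useLegacySql": False}},
    force_rerun=True,
)
```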
@@ -2757,7 +2827,9 @@ def execute(self, context: Any):
         self.job_id = job.job_id
         project_id = self.project_id or self.hook.project_id
         if project_id:
-            job_id_path = convert_job_id(job_id=job_id, project_id=project_id, location=self.location)
+            job_id_path = convert_job_id(
+                job_id=self.job_id, project_id=project_id, location=self.location  # type: ignore[arg-type]
+            )
             context["ti"].xcom_push(key="job_id_path", value=job_id_path)
         # Wait for the job to complete
         if not self.deferrable:
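Putting it together, a minimal DAG sketch (assumed project, dataset, and schedule) of the operator whose completed runs the new mixin reports on:

```python
import pendulum

from airflow import DAG
from airflow.providers.google.cloud.operators.bigquery import BigQueryInsertJobOperator

with DAG(
    dag_id="bq_lineage_demo",
    start_date=pendulum.datetime(2023, 1, 1, tz="UTC"),
    schedule=None,
) as dag:
    # On completion, an OpenLineage-enabled deployment can call
    # get_openlineage_facets_on_complete() on this task to emit lineage.
    BigQueryInsertJobOperator(
        task_id="run_query",
        configuration={
            "query": {
                "query": "SELECT * FROM `my-project.my_dataset.src`",
                "useLegacySql": False,
            }
        },
        location="US",
    )
```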