@@ -149,6 +149,13 @@ class _DataflowJobsController(LoggingMixin):
     :param drain_pipeline: Optional, set to True if want to stop streaming job by draining it
         instead of canceling.
     :param cancel_timeout: wait time in seconds for successful job canceling
+    :param wait_until_finished: If True, wait for the end of pipeline execution before exiting. If False,
+        it only submits the job and checks once whether the job is in a terminal state.
+
+        The default behavior depends on the type of pipeline:
+
+        * for a streaming pipeline, wait for the jobs to start,
+        * for a batch pipeline, wait for the jobs to complete.
     """

     def __init__(  # pylint: disable=too-many-arguments
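
A minimal sketch (not part of the diff) of how the new default resolves: wait_until_finished=None defers to the pipeline type, matching the DataflowJobType.JOB_TYPE_STREAMING comparison introduced further down. The helper name and the raw state string are illustrative assumptions only.

    def _effective_wait(wait_until_finished, job_type):
        # None -> batch jobs are awaited to completion, streaming jobs
        # only until they are up and running.
        if wait_until_finished is None:
            return job_type != "JOB_TYPE_STREAMING"
        return wait_until_finished
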
@@ -163,6 +170,7 @@ def __init__(  # pylint: disable=too-many-arguments
         multiple_jobs: bool = False,
         drain_pipeline: bool = False,
         cancel_timeout: Optional[int] = 5 * 60,
+        wait_until_finished: Optional[bool] = None,
     ) -> None:

         super().__init__()
@@ -177,6 +185,8 @@ def __init__(  # pylint: disable=too-many-arguments
         self._cancel_timeout = cancel_timeout
         self._jobs: Optional[List[dict]] = None
         self.drain_pipeline = drain_pipeline
+        self._wait_until_finished = wait_until_finished
+        self._jobs: Optional[List[dict]] = None

     def is_job_running(self) -> bool:
         """
@@ -203,7 +213,7 @@ def _get_current_jobs(self) -> List[dict]:
         :rtype: list
         """
         if not self._multiple_jobs and self._job_id:
-            return [self._fetch_job_by_id(self._job_id)]
+            return [self.fetch_job_by_id(self._job_id)]
         elif self._job_name:
             jobs = self._fetch_jobs_by_prefix_name(self._job_name.lower())
             if len(jobs) == 1:
@@ -212,7 +222,15 @@ def _get_current_jobs(self) -> List[dict]:
         else:
             raise Exception("Missing both dataflow job ID and name.")

-    def _fetch_job_by_id(self, job_id: str) -> dict:
+    def fetch_job_by_id(self, job_id: str) -> dict:
+        """
+        Helper method to fetch the job with the specified Job ID.
+
+        :param job_id: Job ID to get.
+        :type job_id: str
+        :return: the Job
+        :rtype: dict
+        """
         return (
             self._dataflow.projects()
             .locations()
@@ -278,19 +296,25 @@ def _check_dataflow_job_state(self, job) -> bool:
         :rtype: bool
         :raise: Exception
         """
-        if DataflowJobStatus.JOB_STATE_DONE == job["currentState"]:
+        if self._wait_until_finished is None:
+            wait_for_running = job['type'] == DataflowJobType.JOB_TYPE_STREAMING
+        else:
+            wait_for_running = not self._wait_until_finished
+
+        if job['currentState'] == DataflowJobStatus.JOB_STATE_DONE:
             return True
-        elif DataflowJobStatus.JOB_STATE_FAILED == job["currentState"]:
-            raise Exception("Google Cloud Dataflow job {} has failed.".format(job["name"]))
-        elif DataflowJobStatus.JOB_STATE_CANCELLED == job["currentState"]:
-            raise Exception("Google Cloud Dataflow job {} was cancelled.".format(job["name"]))
-        elif (
-            DataflowJobStatus.JOB_STATE_RUNNING == job["currentState"]
-            and DataflowJobType.JOB_TYPE_STREAMING == job["type"]
-        ):
+        elif job['currentState'] == DataflowJobStatus.JOB_STATE_FAILED:
+            raise Exception("Google Cloud Dataflow job {} has failed.".format(job['name']))
+        elif job['currentState'] == DataflowJobStatus.JOB_STATE_CANCELLED:
+            raise Exception("Google Cloud Dataflow job {} was cancelled.".format(job['name']))
+        elif job['currentState'] == DataflowJobStatus.JOB_STATE_DRAINED:
+            raise Exception("Google Cloud Dataflow job {} was drained.".format(job['name']))
+        elif job['currentState'] == DataflowJobStatus.JOB_STATE_UPDATED:
+            raise Exception("Google Cloud Dataflow job {} was updated.".format(job['name']))
+        elif job['currentState'] == DataflowJobStatus.JOB_STATE_RUNNING and wait_for_running:
             return True
-        elif job["currentState"] in DataflowJobStatus.AWAITING_STATES:
-            return False
+        elif job['currentState'] in DataflowJobStatus.AWAITING_STATES:
+            return self._wait_until_finished is False
         self.log.debug("Current job: %s", str(job))
         raise Exception(
             "Google Cloud Dataflow job {} was unknown state: {}".format(job["name"], job["currentState"])
@@ -487,10 +511,12 @@ def __init__(
         impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
         drain_pipeline: bool = False,
         cancel_timeout: Optional[int] = 5 * 60,
+        wait_until_finished: Optional[bool] = None,
     ) -> None:
         self.poll_sleep = poll_sleep
         self.drain_pipeline = drain_pipeline
         self.cancel_timeout = cancel_timeout
+        self.wait_until_finished = wait_until_finished
         super().__init__(
             gcp_conn_id=gcp_conn_id,
             delegate_to=delegate_to,
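
At the hook level the three accepted values then mean the following (a usage sketch; the import path is the module this diff touches):

    from airflow.providers.google.cloud.hooks.dataflow import DataflowHook

    DataflowHook()                           # None: wait according to job type
    DataflowHook(wait_until_finished=True)   # always poll until a terminal state
    DataflowHook(wait_until_finished=False)  # submit, check the state once, return
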
@@ -532,6 +558,7 @@ def _start_dataflow(
             multiple_jobs=multiple_jobs,
             drain_pipeline=self.drain_pipeline,
             cancel_timeout=self.cancel_timeout,
+            wait_until_finished=self.wait_until_finished,
         )
         job_controller.wait_for_done()

@@ -1047,3 +1074,30 @@ def start_sql_job(
         jobs_controller.wait_for_done()

         return jobs_controller.get_jobs(refresh=True)[0]
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    def get_job(
+        self,
+        job_id: str,
+        project_id: str,
+        location: str = DEFAULT_DATAFLOW_LOCATION,
+    ) -> dict:
+        """
+        Gets the job with the specified Job ID.
+
+        :param job_id: Job ID to get.
+        :type job_id: str
+        :param project_id: Optional, the Google Cloud project ID in which the job was started.
+            If set to None or missing, the default project_id from the Google Cloud connection is used.
+        :type project_id: str
+        :param location: The location of the Dataflow job (for example europe-west1). See:
+            https://cloud.google.com/dataflow/docs/concepts/regional-endpoints
+        :return: the Job
+        :rtype: dict
+        """
+        jobs_controller = _DataflowJobsController(
+            dataflow=self.get_conn(),
+            project_number=project_id,
+            location=location,
+        )
+        return jobs_controller.fetch_job_by_id(job_id)
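
A hedged usage sketch of the new get_job method; the connection, project, and job ID below are placeholders:

    from airflow.providers.google.cloud.hooks.dataflow import DataflowHook

    hook = DataflowHook()  # uses the default Google Cloud connection
    job = hook.get_job(
        job_id="2015-07-29_14_39_52-1234567890123456789",  # placeholder job ID
        project_id="my-gcp-project",                       # placeholder project
        location="europe-west1",
    )
    print(job["name"], job["currentState"])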