
Commit cfa4ecf

Authored by Tobiasz Kędzierski
Add DataflowJobStatusSensor and support non-blocking execution of jobs (#11726)
1 parent cbd6daf commit cfa4ecf

File tree: 8 files changed, +602 -89 lines

airflow/providers/google/cloud/example_dags/example_dataflow.py

Lines changed: 34 additions & 0 deletions
@@ -23,12 +23,14 @@
 from urllib.parse import urlparse
 
 from airflow import models
+from airflow.providers.google.cloud.hooks.dataflow import DataflowJobStatus
 from airflow.providers.google.cloud.operators.dataflow import (
     CheckJobRunning,
     DataflowCreateJavaJobOperator,
     DataflowCreatePythonJobOperator,
     DataflowTemplatedJobStartOperator,
 )
+from airflow.providers.google.cloud.sensors.dataflow import DataflowJobStatusSensor
 from airflow.providers.google.cloud.transfers.gcs_to_local import GCSToLocalFilesystemOperator
 from airflow.utils.dates import days_ago

@@ -128,6 +130,38 @@
         py_system_site_packages=False,
     )
 
+with models.DAG(
+    "example_gcp_dataflow_native_python_async",
+    default_args=default_args,
+    start_date=days_ago(1),
+    schedule_interval=None,  # Override to match your needs
+    tags=['example'],
+) as dag_native_python_async:
+    start_python_job_async = DataflowCreatePythonJobOperator(
+        task_id="start-python-job-async",
+        py_file=GCS_PYTHON,
+        py_options=[],
+        job_name='{{task.task_id}}',
+        options={
+            'output': GCS_OUTPUT,
+        },
+        py_requirements=['apache-beam[gcp]==2.25.0'],
+        py_interpreter='python3',
+        py_system_site_packages=False,
+        location='europe-west3',
+        wait_until_finished=False,
+    )
+
+    wait_for_python_job_async_done = DataflowJobStatusSensor(
+        task_id="wait-for-python-job-async-done",
+        job_id="{{task_instance.xcom_pull('start-python-job-async')['job_id']}}",
+        expected_statuses={DataflowJobStatus.JOB_STATE_DONE},
+        location='europe-west3',
+    )
+
+    start_python_job_async >> wait_for_python_job_async_done
+
+
 with models.DAG(
     "example_gcp_dataflow_template",
     default_args=default_args,
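Note on the pattern above: with wait_until_finished=False the operator returns as soon as the job is submitted and pushes the job metadata (including job_id) to XCom, and DataflowJobStatusSensor then polls that job until it reaches one of the expected_statuses. As a rough sketch, not part of this commit (the task id below is hypothetical), the same wiring could instead wait for a streaming job to reach the running state, since streaming pipelines never reach JOB_STATE_DONE:

# Illustrative sketch only, not part of this commit: wait until the job is
# RUNNING rather than DONE (useful for streaming pipelines).
wait_for_python_job_async_running = DataflowJobStatusSensor(
    task_id="wait-for-python-job-async-running",  # hypothetical task id
    job_id="{{task_instance.xcom_pull('start-python-job-async')['job_id']}}",
    expected_statuses={DataflowJobStatus.JOB_STATE_RUNNING},
    location='europe-west3',
)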

airflow/providers/google/cloud/hooks/dataflow.py

Lines changed: 67 additions & 13 deletions
@@ -149,6 +149,13 @@ class _DataflowJobsController(LoggingMixin):
     :param drain_pipeline: Optional, set to True if want to stop streaming job by draining it
         instead of canceling.
     :param cancel_timeout: wait time in seconds for successful job canceling
+    :param wait_until_finished: If True, wait for the end of pipeline execution before exiting. If False,
+        it only submits the job and checks once whether the job is not in a terminal state.
+
+        The default behavior depends on the type of pipeline:
+
+        * for the streaming pipeline, wait for jobs to start,
+        * for the batch pipeline, wait for the jobs to complete.
     """
 
     def __init__(  # pylint: disable=too-many-arguments
@@ -163,6 +170,7 @@ def __init__(  # pylint: disable=too-many-arguments
         multiple_jobs: bool = False,
         drain_pipeline: bool = False,
         cancel_timeout: Optional[int] = 5 * 60,
+        wait_until_finished: Optional[bool] = None,
     ) -> None:
 
         super().__init__()
@@ -177,6 +185,8 @@
         self._cancel_timeout = cancel_timeout
         self._jobs: Optional[List[dict]] = None
         self.drain_pipeline = drain_pipeline
+        self._wait_until_finished = wait_until_finished
+        self._jobs: Optional[List[dict]] = None
 
     def is_job_running(self) -> bool:
         """
@@ -203,7 +213,7 @@ def _get_current_jobs(self) -> List[dict]:
         :rtype: list
         """
         if not self._multiple_jobs and self._job_id:
-            return [self._fetch_job_by_id(self._job_id)]
+            return [self.fetch_job_by_id(self._job_id)]
         elif self._job_name:
             jobs = self._fetch_jobs_by_prefix_name(self._job_name.lower())
             if len(jobs) == 1:
@@ -212,7 +222,15 @@
         else:
             raise Exception("Missing both dataflow job ID and name.")
 
-    def _fetch_job_by_id(self, job_id: str) -> dict:
+    def fetch_job_by_id(self, job_id: str) -> dict:
+        """
+        Helper method to fetch the job with the specified Job ID.
+
+        :param job_id: Job ID to get.
+        :type job_id: str
+        :return: the Job
+        :rtype: dict
+        """
         return (
             self._dataflow.projects()
             .locations()
@@ -278,19 +296,25 @@ def _check_dataflow_job_state(self, job) -> bool:
         :rtype: bool
         :raise: Exception
         """
-        if DataflowJobStatus.JOB_STATE_DONE == job["currentState"]:
+        if self._wait_until_finished is None:
+            wait_for_running = job['type'] == DataflowJobType.JOB_TYPE_STREAMING
+        else:
+            wait_for_running = not self._wait_until_finished
+
+        if job['currentState'] == DataflowJobStatus.JOB_STATE_DONE:
             return True
-        elif DataflowJobStatus.JOB_STATE_FAILED == job["currentState"]:
-            raise Exception("Google Cloud Dataflow job {} has failed.".format(job["name"]))
-        elif DataflowJobStatus.JOB_STATE_CANCELLED == job["currentState"]:
-            raise Exception("Google Cloud Dataflow job {} was cancelled.".format(job["name"]))
-        elif (
-            DataflowJobStatus.JOB_STATE_RUNNING == job["currentState"]
-            and DataflowJobType.JOB_TYPE_STREAMING == job["type"]
-        ):
+        elif job['currentState'] == DataflowJobStatus.JOB_STATE_FAILED:
+            raise Exception("Google Cloud Dataflow job {} has failed.".format(job['name']))
+        elif job['currentState'] == DataflowJobStatus.JOB_STATE_CANCELLED:
+            raise Exception("Google Cloud Dataflow job {} was cancelled.".format(job['name']))
+        elif job['currentState'] == DataflowJobStatus.JOB_STATE_DRAINED:
+            raise Exception("Google Cloud Dataflow job {} was drained.".format(job['name']))
+        elif job['currentState'] == DataflowJobStatus.JOB_STATE_UPDATED:
+            raise Exception("Google Cloud Dataflow job {} was updated.".format(job['name']))
+        elif job['currentState'] == DataflowJobStatus.JOB_STATE_RUNNING and wait_for_running:
             return True
-        elif job["currentState"] in DataflowJobStatus.AWAITING_STATES:
-            return False
+        elif job['currentState'] in DataflowJobStatus.AWAITING_STATES:
+            return self._wait_until_finished is False
         self.log.debug("Current job: %s", str(job))
         raise Exception(
             "Google Cloud Dataflow job {} was unknown state: {}".format(job["name"], job["currentState"])
@@ -487,10 +511,12 @@ def __init__(
         impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
         drain_pipeline: bool = False,
         cancel_timeout: Optional[int] = 5 * 60,
+        wait_until_finished: Optional[bool] = None,
     ) -> None:
         self.poll_sleep = poll_sleep
         self.drain_pipeline = drain_pipeline
         self.cancel_timeout = cancel_timeout
+        self.wait_until_finished = wait_until_finished
         super().__init__(
             gcp_conn_id=gcp_conn_id,
             delegate_to=delegate_to,
@@ -532,6 +558,7 @@ def _start_dataflow(
             multiple_jobs=multiple_jobs,
             drain_pipeline=self.drain_pipeline,
             cancel_timeout=self.cancel_timeout,
+            wait_until_finished=self.wait_until_finished,
         )
         job_controller.wait_for_done()

@@ -1047,3 +1074,30 @@ def start_sql_job(
         jobs_controller.wait_for_done()
 
         return jobs_controller.get_jobs(refresh=True)[0]
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    def get_job(
+        self,
+        job_id: str,
+        project_id: str,
+        location: str = DEFAULT_DATAFLOW_LOCATION,
+    ) -> dict:
+        """
+        Gets the job with the specified Job ID.
+
+        :param job_id: Job ID to get.
+        :type job_id: str
+        :param project_id: Optional, the Google Cloud project ID in which to start a job.
+            If set to None or missing, the default project_id from the Google Cloud connection is used.
+        :type project_id: str
+        :param location: The location of the Dataflow job (for example europe-west1). See:
+            https://cloud.google.com/dataflow/docs/concepts/regional-endpoints
+        :return: the Job
+        :rtype: dict
+        """
+        jobs_controller = _DataflowJobsController(
+            dataflow=self.get_conn(),
+            project_number=project_id,
+            location=location,
+        )
+        return jobs_controller.fetch_job_by_id(job_id)
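The new get_job hook method is the piece that makes a lightweight status sensor possible: it looks a job up by ID and returns the raw job dict, whose currentState can be compared against a set of expected statuses. The committed sensor module itself is not shown in this excerpt; as a rough sketch of how a poke method could be built on this hook (the class name and constructor arguments below are assumptions, not the actual DataflowJobStatusSensor):

# Rough sketch only; the committed DataflowJobStatusSensor may differ in details.
from airflow.providers.google.cloud.hooks.dataflow import DataflowHook
from airflow.sensors.base import BaseSensorOperator  # import path assumes Airflow 2.x


class SimpleDataflowJobStatusSensor(BaseSensorOperator):  # hypothetical class
    def __init__(self, *, job_id, expected_statuses, location, project_id=None, **kwargs):
        super().__init__(**kwargs)
        self.job_id = job_id
        self.expected_statuses = expected_statuses
        self.location = location
        self.project_id = project_id

    def poke(self, context) -> bool:
        hook = DataflowHook()
        # get_job falls back to the connection's default project when project_id is None.
        job = hook.get_job(job_id=self.job_id, project_id=self.project_id, location=self.location)
        return job["currentState"] in self.expected_statuses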
