
Commit 8baf657

Author: Tobiasz Kędzierski
Fix regression in DataflowTemplatedJobStartOperator (#11167)
1 parent 422b61a commit 8baf657

File tree

4 files changed, +103 -2 lines changed
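What changed, in user terms: before this commit the template launch body's "environment" field was populated straight from the merged job options ("variables"), which conflated pipeline options with the Dataflow runtime environment. The commit adds an explicit environment argument to the hook and operator and merges only the recognized runtime keys from the options into it. A minimal, hypothetical DAG fragment showing the new argument (task_id, the template path, and the bucket names are placeholders, and the template argument name is assumed from this provider version; the remaining arguments follow the diffs below):

    from airflow.providers.google.cloud.operators.dataflow import DataflowTemplatedJobStartOperator

    start_template_job = DataflowTemplatedJobStartOperator(
        task_id="start_template_job",                          # placeholder
        job_name="example-wordcount",
        template="gs://dataflow-templates/latest/Word_Count",  # placeholder template path
        parameters={"inputFile": "gs://my-bucket/input.txt"},  # placeholder
        location="europe-west3",
        # New in this commit: runtime environment options are passed explicitly.
        environment={"maxWorkers": 2, "tempLocation": "gs://my-bucket/tmp"},
    )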

airflow/providers/google/cloud/hooks/dataflow.py

Lines changed: 42 additions & 1 deletion
@@ -530,13 +530,15 @@ def start_template_dataflow(
         append_job_name: bool = True,
         on_new_job_id_callback: Optional[Callable[[str], None]] = None,
         location: str = DEFAULT_DATAFLOW_LOCATION,
+        environment: Optional[Dict] = None,
     ) -> Dict:
         """
         Starts Dataflow template job.

         :param job_name: The name of the job.
         :type job_name: str
         :param variables: Map of job runtime environment options.
+            It will update environment argument if passed.

             .. seealso::
                 For more information on possible configurations, look at the API documentation
@@ -556,9 +558,48 @@ def start_template_dataflow(
         :type on_new_job_id_callback: callable
         :param location: Job location.
         :type location: str
+        :param environment: Optional, Map of job runtime environment options.
+
+            .. seealso::
+                For more information on possible configurations, look at the API documentation
+                `https://cloud.google.com/dataflow/pipelines/specifying-exec-params
+                <https://cloud.google.com/dataflow/docs/reference/rest/v1b3/RuntimeEnvironment>`__
+
+        :type environment: Optional[dict]
         """
         name = self._build_dataflow_job_name(job_name, append_job_name)

+        environment = environment or {}
+        # available keys for runtime environment are listed here:
+        # https://cloud.google.com/dataflow/docs/reference/rest/v1b3/RuntimeEnvironment
+        environment_keys = [
+            'numWorkers',
+            'maxWorkers',
+            'zone',
+            'serviceAccountEmail',
+            'tempLocation',
+            'bypassTempDirValidation',
+            'machineType',
+            'additionalExperiments',
+            'network',
+            'subnetwork',
+            'additionalUserLabels',
+            'kmsKeyName',
+            'ipConfiguration',
+            'workerRegion',
+            'workerZone',
+        ]
+
+        for key in variables:
+            if key in environment_keys:
+                if key in environment:
+                    self.log.warning(
+                        "'%s' parameter in 'variables' will override the one "
+                        "passed in 'environment'!",
+                        key,
+                    )
+                environment.update({key: variables[key]})
+
         service = self.get_conn()
         # pylint: disable=no-member
         request = (
@@ -569,7 +610,7 @@ def start_template_dataflow(
                 projectId=project_id,
                 location=location,
                 gcsPath=dataflow_template,
-                body={"jobName": name, "parameters": parameters, "environment": variables},
+                body={"jobName": name, "parameters": parameters, "environment": environment},
             )
         )
         response = request.execute(num_retries=self.num_retries)
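To make the merge semantics of the new hunk concrete: the hook starts from the caller-supplied environment, then copies any recognized runtime-environment key found in variables over it, logging a warning on conflicts, so values from variables win. A simplified, standalone restatement of that logic (not the hook itself; the key list is abbreviated):

    import logging
    from typing import Optional

    log = logging.getLogger(__name__)

    # Abbreviated key list; the full list is in the hunk above.
    ENVIRONMENT_KEYS = {"numWorkers", "maxWorkers", "zone", "tempLocation", "machineType"}

    def merge_environment(variables: dict, environment: Optional[dict] = None) -> dict:
        """Mirror the hook's merge: keys from 'variables' win on conflict."""
        merged = dict(environment or {})
        for key in variables:
            if key in ENVIRONMENT_KEYS:
                if key in merged:
                    log.warning("'%s' in 'variables' overrides the value in 'environment'", key)
                merged[key] = variables[key]
        return merged

    # 'variables' takes precedence, as the warning path documents:
    print(merge_environment({"zone": "us-central1-f"}, {"zone": "europe-west3-c", "maxWorkers": 2}))
    # -> {'zone': 'us-central1-f', 'maxWorkers': 2}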

airflow/providers/google/cloud/operators/dataflow.py

Lines changed: 12 additions & 1 deletion
@@ -281,6 +281,7 @@ class DataflowTemplatedJobStartOperator(BaseOperator):
     :param job_name: The 'jobName' to use when executing the DataFlow template
         (templated).
     :param options: Map of job runtime environment options.
+        It will update environment argument if passed.

         .. seealso::
             For more information on possible configurations, look at the API documentation
@@ -316,6 +317,13 @@ class DataflowTemplatedJobStartOperator(BaseOperator):
         Service Account Token Creator IAM role to the directly preceding identity, with first
         account from the list granting this role to the originating account (templated).
     :type impersonation_chain: Union[str, Sequence[str]]
+    :param environment: Optional, Map of job runtime environment options.
+
+        .. seealso::
+            For more information on possible configurations, look at the API documentation
+            `https://cloud.google.com/dataflow/pipelines/specifying-exec-params
+            <https://cloud.google.com/dataflow/docs/reference/rest/v1b3/RuntimeEnvironment>`__
+    :type environment: Optional[dict]

     It's a good practice to define dataflow_* parameters in the default_args of the dag
     like the project, zone and staging location.
@@ -373,6 +381,7 @@ class DataflowTemplatedJobStartOperator(BaseOperator):
         'location',
         'gcp_conn_id',
         'impersonation_chain',
+        'environment',
     ]
     ui_color = '#0273d4'

@@ -391,6 +400,7 @@ def __init__( # pylint: disable=too-many-arguments
         delegate_to: Optional[str] = None,
         poll_sleep: int = 10,
         impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
+        environment: Optional[Dict] = None,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -407,6 +417,7 @@ def __init__( # pylint: disable=too-many-arguments
         self.job_id = None
         self.hook: Optional[DataflowHook] = None
         self.impersonation_chain = impersonation_chain
+        self.environment = environment

     def execute(self, context):
         self.hook = DataflowHook(
@@ -421,7 +432,6 @@ def set_current_job_id(job_id):

         options = self.dataflow_default_options
         options.update(self.options)
-
         job = self.hook.start_template_dataflow(
             job_name=self.job_name,
             variables=options,
@@ -430,6 +440,7 @@ def set_current_job_id(job_id):
             on_new_job_id_callback=set_current_job_id,
             project_id=self.project_id,
             location=self.location,
+            environment=self.environment,
         )

         return job
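One side effect worth noting from this hunk: because 'environment' is appended to template_fields, its values are rendered with Jinja like the other templated arguments. A hedged sketch (placeholder names; assumes Airflow's usual recursive rendering of dict values in templated fields):

    job = DataflowTemplatedJobStartOperator(
        task_id="templated_env_job",                       # placeholder
        job_name="job-{{ ds_nodash }}",                    # already templated
        template="gs://my-bucket/templates/my-template",   # placeholder
        # Now templated too, since 'environment' is in template_fields:
        environment={"tempLocation": "gs://my-bucket/tmp/{{ ds_nodash }}"},
    )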

tests/providers/google/cloud/hooks/test_dataflow.py

Lines changed: 47 additions & 0 deletions
@@ -746,6 +746,7 @@ def test_start_template_dataflow_with_runtime_env(self, mock_conn, mock_dataflowjob, mock_uuid):
             parameters=PARAMETERS,
             dataflow_template=TEST_TEMPLATE,
             project_id=TEST_PROJECT,
+            environment={"numWorkers": 17},
         )
         body = {"jobName": mock.ANY, "parameters": PARAMETERS, "environment": RUNTIME_ENV}
         method.assert_called_once_with(
@@ -765,6 +766,52 @@ def test_start_template_dataflow_with_runtime_env(self, mock_conn, mock_dataflowjob, mock_uuid):
         )
         mock_uuid.assert_called_once_with()

+    @mock.patch(DATAFLOW_STRING.format('uuid.uuid4'), return_value=MOCK_UUID)
+    @mock.patch(DATAFLOW_STRING.format('_DataflowJobsController'))
+    @mock.patch(DATAFLOW_STRING.format('DataflowHook.get_conn'))
+    def test_start_template_dataflow_update_runtime_env(self, mock_conn, mock_dataflowjob, mock_uuid):
+        options_with_runtime_env = copy.deepcopy(RUNTIME_ENV)
+        del options_with_runtime_env["numWorkers"]
+        runtime_env = {"numWorkers": 17}
+        expected_runtime_env = copy.deepcopy(RUNTIME_ENV)
+        expected_runtime_env.update(runtime_env)
+
+        dataflowjob_instance = mock_dataflowjob.return_value
+        dataflowjob_instance.wait_for_done.return_value = None
+        # fmt: off
+        method = (mock_conn.return_value
+                  .projects.return_value
+                  .locations.return_value
+                  .templates.return_value
+                  .launch)
+        # fmt: on
+        method.return_value.execute.return_value = {'job': {'id': TEST_JOB_ID}}
+        self.dataflow_hook.start_template_dataflow(  # pylint: disable=no-value-for-parameter
+            job_name=JOB_NAME,
+            variables=options_with_runtime_env,
+            parameters=PARAMETERS,
+            dataflow_template=TEST_TEMPLATE,
+            project_id=TEST_PROJECT,
+            environment=runtime_env,
+        )
+        body = {"jobName": mock.ANY, "parameters": PARAMETERS, "environment": expected_runtime_env}
+        method.assert_called_once_with(
+            projectId=TEST_PROJECT,
+            location=DEFAULT_DATAFLOW_LOCATION,
+            gcsPath=TEST_TEMPLATE,
+            body=body,
+        )
+        mock_dataflowjob.assert_called_once_with(
+            dataflow=mock_conn.return_value,
+            job_id=TEST_JOB_ID,
+            location=DEFAULT_DATAFLOW_LOCATION,
+            name='test-dataflow-pipeline-{}'.format(MOCK_UUID),
+            num_retries=5,
+            poll_sleep=10,
+            project_number=TEST_PROJECT,
+        )
+        mock_uuid.assert_called_once_with()
+
     @mock.patch(DATAFLOW_STRING.format('_DataflowJobsController'))
     @mock.patch(DATAFLOW_STRING.format('DataflowHook.get_conn'))
     def test_cancel_job(self, mock_get_conn, jobs_controller):

tests/providers/google/cloud/operators/test_dataflow.py

Lines changed: 2 additions & 0 deletions
@@ -261,6 +261,7 @@ def setUp(self):
             dataflow_default_options={"EXTRA_OPTION": "TEST_A"},
             poll_sleep=POLL_SLEEP,
             location=TEST_LOCATION,
+            environment={"maxWorkers": 2},
         )

     @mock.patch('airflow.providers.google.cloud.operators.dataflow.DataflowHook')
@@ -287,4 +288,5 @@ def test_exec(self, dataflow_mock):
             on_new_job_id_callback=mock.ANY,
             project_id=None,
             location=TEST_LOCATION,
+            environment={'maxWorkers': 2},
         )
