
Commit 72ddc94

Pass location using parameter in Dataflow integration (#8382)
1 parent 912aa4b commit 72ddc94

6 files changed: +371 −82 lines changed


β€Žairflow/providers/google/cloud/example_dags/example_dataflow.py

Lines changed: 4 additions & 3 deletions
@@ -30,7 +30,6 @@
 from airflow.providers.google.cloud.operators.gcs import GCSToLocalOperator
 from airflow.utils.dates import days_ago

-GCP_PROJECT_ID = os.environ.get('GCP_PROJECT_ID', 'example-project')
 GCS_TMP = os.environ.get('GCP_DATAFLOW_GCS_TMP', 'gs://test-dataflow-example/temp/')
 GCS_STAGING = os.environ.get('GCP_DATAFLOW_GCS_STAGING', 'gs://test-dataflow-example/staging/')
 GCS_OUTPUT = os.environ.get('GCP_DATAFLOW_GCS_OUTPUT', 'gs://test-dataflow-example/output')
@@ -44,7 +43,6 @@
 default_args = {
     "start_date": days_ago(1),
     'dataflow_default_options': {
-        'project': GCP_PROJECT_ID,
         'tempLocation': GCS_TMP,
         'stagingLocation': GCS_STAGING,
     }
@@ -68,6 +66,7 @@
         poll_sleep=10,
         job_class='org.apache.beam.examples.WordCount',
         check_if_running=CheckJobRunning.IgnoreJob,
+        location='europe-west3'
     )
     # [END howto_operator_start_java_job]

@@ -104,7 +103,8 @@
             'apache-beam[gcp]>=2.14.0'
         ],
         py_interpreter='python3',
-        py_system_site_packages=False
+        py_system_site_packages=False,
+        location='europe-west3'
     )
     # [END howto_operator_start_python_job]

@@ -130,4 +130,5 @@
             'inputFile': "gs://dataflow-samples/shakespeare/kinglear.txt",
             'output': GCS_OUTPUT
         },
+        location='europe-west3'
     )
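The operator-facing effect of the commit is visible above: the region no longer rides along in dataflow_default_options['region'] but is passed as a first-class location argument. A minimal sketch of the new calling style, modeled on the template job in this example DAG (the DAG boilerplate, bucket names, and region are illustrative, not part of the commit):

    import os

    from airflow import models
    from airflow.providers.google.cloud.operators.dataflow import DataflowTemplateOperator
    from airflow.utils.dates import days_ago

    GCS_OUTPUT = os.environ.get('GCP_DATAFLOW_GCS_OUTPUT', 'gs://test-dataflow-example/output')

    with models.DAG(
        "example_dataflow_location",
        start_date=days_ago(1),
        schedule_interval=None,
    ) as dag:
        start_template_job = DataflowTemplateOperator(
            task_id="start-template-job",
            template="gs://dataflow-templates/latest/Word_Count",
            parameters={
                'inputFile': "gs://dataflow-samples/shakespeare/kinglear.txt",
                'output': GCS_OUTPUT,
            },
            # New in this commit: explicit parameter instead of options['region'].
            location='europe-west3',
        )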

β€Žairflow/providers/google/cloud/hooks/dataflow.py

Lines changed: 82 additions & 53 deletions
@@ -25,6 +25,7 @@
 import subprocess
 import time
 import uuid
+import warnings
 from copy import deepcopy
 from tempfile import TemporaryDirectory
 from typing import Any, Callable, Dict, List, Optional, TypeVar
@@ -49,36 +50,44 @@
 RT = TypeVar('RT')  # pylint: disable=invalid-name


-def _fallback_to_project_id_from_variables(func: Callable[..., RT]) -> Callable[..., RT]:
-    """
-    Decorator that provides fallback for Google Cloud Platform project id.
+def _fallback_variable_parameter(parameter_name, variable_key_name):

-    :param func: function to wrap
-    :return: result of the function call
-    """
-    @functools.wraps(func)
-    def inner_wrapper(self: "DataflowHook", *args, **kwargs) -> RT:
-        if args:
-            raise AirflowException(
-                "You must use keyword arguments in this methods rather than positional")
-
-        parameter_project_id = kwargs.get('project_id')
-        variables_project_id = kwargs.get('variables', {}).get('project')
-
-        if parameter_project_id and variables_project_id:
-            raise AirflowException(
-                "The mutually exclusive parameter `project_id` and `project` key in `variables` parameters "
-                "are both present. Please remove one."
-            )
+    def _wrapper(func: Callable[..., RT]) -> Callable[..., RT]:
+        """
+        Decorator that provides fallback for location from `region` key in `variables` parameters.
+
+        :param func: function to wrap
+        :return: result of the function call
+        """
+        @functools.wraps(func)
+        def inner_wrapper(self: "DataflowHook", *args, **kwargs) -> RT:
+            if args:
+                raise AirflowException(
+                    "You must use keyword arguments in this methods rather than positional")
+
+            parameter_location = kwargs.get(parameter_name)
+            variables_location = kwargs.get('variables', {}).get(variable_key_name)
+
+            if parameter_location and variables_location:
+                raise AirflowException(
+                    f"The mutually exclusive parameter `{parameter_name}` and `{variable_key_name}` key "
+                    f"in `variables` parameter are both present. Please remove one."
+                )
+            if parameter_location or variables_location:
+                kwargs[parameter_name] = parameter_location or variables_location
+            if variables_location:
+                copy_variables = deepcopy(kwargs['variables'])
+                del copy_variables[variable_key_name]
+                kwargs['variables'] = copy_variables
+
+            return func(self, *args, **kwargs)
+        return inner_wrapper

-        kwargs['project_id'] = parameter_project_id or variables_project_id
-        if variables_project_id:
-            copy_variables = deepcopy(kwargs['variables'])
-            del copy_variables['project']
-            kwargs['variables'] = copy_variables
+    return _wrapper

-        return func(self, *args, **kwargs)
-    return inner_wrapper
+
+_fallback_to_location_from_variables = _fallback_variable_parameter('location', 'region')
+_fallback_to_project_id_from_variables = _fallback_variable_parameter('project_id', 'project')


 class DataflowJobStatus:
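The refactor above turns the old single-purpose project-id decorator into a decorator factory: _fallback_variable_parameter(parameter_name, variable_key_name) builds a decorator that promotes a legacy key inside `variables` to a proper keyword argument, rejects the ambiguous case where both are supplied, and scrubs the legacy key so the wrapped method never sees it twice. A self-contained sketch of the same pattern (simplified: a plain ValueError stands in for AirflowException, and the class is a stand-in for the hook):

    import functools
    from copy import deepcopy


    def fallback_variable_parameter(parameter_name, variable_key_name):
        """Build a decorator that maps variables[variable_key_name] onto parameter_name."""
        def _wrapper(func):
            @functools.wraps(func)
            def inner_wrapper(self, *args, **kwargs):
                if args:
                    raise ValueError("Use keyword arguments only")
                parameter_value = kwargs.get(parameter_name)
                variable_value = kwargs.get('variables', {}).get(variable_key_name)
                if parameter_value and variable_value:
                    raise ValueError(
                        f"`{parameter_name}` and the `{variable_key_name}` key in "
                        f"`variables` are mutually exclusive")
                if parameter_value or variable_value:
                    kwargs[parameter_name] = parameter_value or variable_value
                if variable_value:
                    # Scrub the legacy key so it is not passed through twice.
                    copy_variables = deepcopy(kwargs['variables'])
                    del copy_variables[variable_key_name]
                    kwargs['variables'] = copy_variables
                return func(self, *args, **kwargs)
            return inner_wrapper
        return _wrapper


    class FakeHook:
        @fallback_variable_parameter('location', 'region')
        def start(self, variables=None, location=None):
            return location, variables


    # The legacy `region` key is promoted to `location` and removed from variables:
    assert FakeHook().start(variables={'region': 'europe-west3'}) == ('europe-west3', {})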
@@ -425,9 +434,9 @@ def _start_dataflow(
         label_formatter: Callable[[Dict], List[str]],
         project_id: str,
         multiple_jobs: bool = False,
-        on_new_job_id_callback: Optional[Callable[[str], None]] = None
+        on_new_job_id_callback: Optional[Callable[[str], None]] = None,
+        location: str = DEFAULT_DATAFLOW_LOCATION
     ) -> None:
-        variables = self._set_variables(variables)
         cmd = command_prefix + self._build_cmd(variables, label_formatter, project_id)
         runner = _DataflowRunner(
             cmd=cmd,
@@ -438,20 +447,15 @@
             dataflow=self.get_conn(),
             project_number=project_id,
             name=name,
-            location=variables['region'],
+            location=location,
             poll_sleep=self.poll_sleep,
             job_id=job_id,
             num_retries=self.num_retries,
             multiple_jobs=multiple_jobs
         )
         job_controller.wait_for_done()

-    @staticmethod
-    def _set_variables(variables: Dict) -> Dict:
-        if 'region' not in variables.keys():
-            variables['region'] = DEFAULT_DATAFLOW_LOCATION
-        return variables
-
+    @_fallback_to_location_from_variables
     @_fallback_to_project_id_from_variables
     @GoogleBaseHook.fallback_to_default_project_id
     def start_java_dataflow(
@@ -463,7 +467,8 @@ def start_java_dataflow(
         job_class: Optional[str] = None,
         append_job_name: bool = True,
         multiple_jobs: bool = False,
-        on_new_job_id_callback: Optional[Callable[[str], None]] = None
+        on_new_job_id_callback: Optional[Callable[[str], None]] = None,
+        location: str = DEFAULT_DATAFLOW_LOCATION
     ) -> None:
         """
         Starts Dataflow java job.
@@ -484,9 +489,12 @@
         :type multiple_jobs: bool
         :param on_new_job_id_callback: Callback called when the job ID is known.
         :type on_new_job_id_callback: callable
+        :param location: Job location.
+        :type location: str
         """
         name = self._build_dataflow_job_name(job_name, append_job_name)
         variables['jobName'] = name
+        variables['region'] = location

         def label_formatter(labels_dict):
             return ['--labels={}'.format(
@@ -501,9 +509,11 @@ def label_formatter(labels_dict):
             label_formatter=label_formatter,
             project_id=project_id,
             multiple_jobs=multiple_jobs,
-            on_new_job_id_callback=on_new_job_id_callback
+            on_new_job_id_callback=on_new_job_id_callback,
+            location=location
         )

+    @_fallback_to_location_from_variables
     @_fallback_to_project_id_from_variables
     @GoogleBaseHook.fallback_to_default_project_id
     def start_template_dataflow(
@@ -514,7 +524,8 @@
         dataflow_template: str,
         project_id: str,
         append_job_name: bool = True,
-        on_new_job_id_callback: Optional[Callable[[str], None]] = None
+        on_new_job_id_callback: Optional[Callable[[str], None]] = None,
+        location: str = DEFAULT_DATAFLOW_LOCATION
     ) -> Dict:
         """
         Starts Dataflow template job.
@@ -533,8 +544,9 @@
         :type append_job_name: bool
         :param on_new_job_id_callback: Callback called when the job ID is known.
         :type on_new_job_id_callback: callable
+        :param location: Job location.
+        :type location: str
         """
-        variables = self._set_variables(variables)
         name = self._build_dataflow_job_name(job_name, append_job_name)
         # Builds RuntimeEnvironment from variables dictionary
         # https://cloud.google.com/dataflow/docs/reference/rest/v1b3/RuntimeEnvironment
@@ -550,7 +562,7 @@
         service = self.get_conn()
         request = service.projects().locations().templates().launch(  # pylint: disable=no-member
             projectId=project_id,
-            location=variables['region'],
+            location=location,
             gcsPath=dataflow_template,
             body=body
         )
@@ -560,18 +572,18 @@
         if on_new_job_id_callback:
             on_new_job_id_callback(job_id)

-        variables = self._set_variables(variables)
         jobs_controller = _DataflowJobsController(
             dataflow=self.get_conn(),
             project_number=project_id,
             name=name,
             job_id=job_id,
-            location=variables['region'],
+            location=location,
             poll_sleep=self.poll_sleep,
             num_retries=self.num_retries)
         jobs_controller.wait_for_done()
         return response["job"]

+    @_fallback_to_location_from_variables
     @_fallback_to_project_id_from_variables
     @GoogleBaseHook.fallback_to_default_project_id
     def start_python_dataflow(  # pylint: disable=too-many-arguments
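With @_fallback_to_location_from_variables applied, hook callers can now pass `location` directly, while the legacy `region` key inside `variables` is still honored via the fallback. A hedged usage sketch of the template method above (assumes a configured google_cloud_default connection; buckets, project, and the keyword layout around `parameters` reflect the hook's signature at the time of this commit):

    from airflow.providers.google.cloud.hooks.dataflow import DataflowHook

    hook = DataflowHook()  # uses the google_cloud_default connection

    # Preferred after this commit: explicit location parameter.
    hook.start_template_dataflow(
        job_name='wordcount',
        variables={'tempLocation': 'gs://test-dataflow-example/temp/'},
        parameters={
            'inputFile': 'gs://dataflow-samples/shakespeare/kinglear.txt',
            'output': 'gs://test-dataflow-example/output',
        },
        dataflow_template='gs://dataflow-templates/latest/Word_Count',
        project_id='example-project',
        location='europe-west3',
    )

    # Still accepted: variables={'region': 'europe-west3'} is promoted to
    # location='europe-west3' by the fallback decorator before the method runs.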
@@ -585,7 +597,8 @@ def start_python_dataflow(  # pylint: disable=too-many-arguments
         py_requirements: Optional[List[str]] = None,
         py_system_site_packages: bool = False,
         append_job_name: bool = True,
-        on_new_job_id_callback: Optional[Callable[[str], None]] = None
+        on_new_job_id_callback: Optional[Callable[[str], None]] = None,
+        location: str = DEFAULT_DATAFLOW_LOCATION
     ):
         """
         Starts Dataflow job.
@@ -620,9 +633,12 @@ def start_python_dataflow(  # pylint: disable=too-many-arguments
             If set to None or missing, the default project_id from the GCP connection is used.
         :param on_new_job_id_callback: Callback called when the job ID is known.
         :type on_new_job_id_callback: callable
+        :param location: Job location.
+        :type location: str
         """
         name = self._build_dataflow_job_name(job_name, append_job_name)
         variables['job_name'] = name
+        variables['region'] = location

         def label_formatter(labels_dict):
             return ['--labels={}={}'.format(key, value)
@@ -644,7 +660,8 @@ def label_formatter(labels_dict):
                 command_prefix=command_prefix,
                 label_formatter=label_formatter,
                 project_id=project_id,
-                on_new_job_id_callback=on_new_job_id_callback
+                on_new_job_id_callback=on_new_job_id_callback,
+                location=location
             )
         else:
             command_prefix = [py_interpreter] + py_options + [dataflow]
@@ -655,7 +672,8 @@ def label_formatter(labels_dict):
                 command_prefix=command_prefix,
                 label_formatter=label_formatter,
                 project_id=project_id,
-                on_new_job_id_callback=on_new_job_id_callback
+                on_new_job_id_callback=on_new_job_id_callback,
+                location=location
             )

     @staticmethod
@@ -700,27 +718,38 @@ def _build_cmd(variables: Dict, label_formatter: Callable, project_id: str) -> L
             command.append(f"--{attr}={value}")
         return command

+    @_fallback_to_location_from_variables
     @_fallback_to_project_id_from_variables
     @GoogleBaseHook.fallback_to_default_project_id
-    def is_job_dataflow_running(self, name: str, variables: Dict, project_id: str) -> bool:
+    def is_job_dataflow_running(
+        self,
+        name: str,
+        project_id: str,
+        location: str = DEFAULT_DATAFLOW_LOCATION,
+        variables: Optional[Dict] = None
+    ) -> bool:
         """
         Helper method to check if a job is still running in Dataflow

         :param name: The name of the job.
         :type name: str
-        :param variables: Variables passed to the job.
-        :type variables: dict
         :param project_id: Optional, the GCP project ID in which to start a job.
             If set to None or missing, the default project_id from the GCP connection is used.
+        :type project_id: str
+        :param location: Job location.
+        :type location: str
         :return: True if job is running.
         :rtype: bool
         """
-        variables = self._set_variables(variables)
+        if variables:
+            warnings.warn(
+                "The variables parameter has been deprecated. You should pass location using "
+                "the location parameter.", DeprecationWarning, stacklevel=4)
         jobs_controller = _DataflowJobsController(
             dataflow=self.get_conn(),
             project_number=project_id,
             name=name,
-            location=variables['region'],
+            location=location,
             poll_sleep=self.poll_sleep
         )
         return jobs_controller.is_job_running()
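Two compatibility notes on the hunk above: the running-job check now takes `location` directly, and `variables` survives only as a deprecated parameter. Note that a bare `region` key is consumed by the fallback decorator before the method body runs, so the DeprecationWarning actually fires only when other keys remain in `variables`. A minimal call sketch under the same connection assumptions as the example above:

    from airflow.providers.google.cloud.hooks.dataflow import DataflowHook

    hook = DataflowHook()

    # New calling convention: location as a first-class parameter.
    is_running = hook.is_job_dataflow_running(
        name='example-wordcount',
        project_id='example-project',
        location='europe-west3',  # previously variables={'region': 'europe-west3'}
    )
    print(is_running)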
@@ -731,7 +760,7 @@ def cancel_job(
         project_id: str,
         job_name: Optional[str] = None,
         job_id: Optional[str] = None,
-        location: Optional[str] = None,
+        location: str = DEFAULT_DATAFLOW_LOCATION,
     ) -> None:
         """
         Cancels the job with the specified name prefix or Job ID.
@@ -753,7 +782,7 @@
             project_number=project_id,
             name=job_name,
             job_id=job_id,
-            location=location or DEFAULT_DATAFLOW_LOCATION,
+            location=location,
             poll_sleep=self.poll_sleep
         )
         jobs_controller.cancel()
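cancel_job follows the same pattern: the default moves from a `location or DEFAULT_DATAFLOW_LOCATION` expression at the call site into the signature itself, so omitting `location` resolves to the hook's default region ('us-central1'). A short sketch, same assumptions as above:

    from airflow.providers.google.cloud.hooks.dataflow import DataflowHook

    # Cancels by job-name prefix; location defaults to DEFAULT_DATAFLOW_LOCATION.
    DataflowHook().cancel_job(
        job_name='example-wordcount',
        project_id='example-project',
    )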
