25
25
import subprocess
26
26
import time
27
27
import uuid
28
+ import warnings
28
29
from copy import deepcopy
29
30
from tempfile import TemporaryDirectory
30
31
from typing import Any , Callable , Dict , List , Optional , TypeVar
49
50
RT = TypeVar('RT')  # pylint: disable=invalid-name


def _fallback_variable_parameter(
    parameter_name: str,
    variable_key_name: str
) -> Callable[[Callable[..., RT]], Callable[..., RT]]:
    """
    Build a decorator that lets a keyword argument fall back to a key of ``variables``.

    The produced decorator looks up ``parameter_name`` in the keyword arguments of the
    wrapped method. When it is missing or falsy, the value stored under
    ``variable_key_name`` in the ``variables`` keyword argument is used instead, and that
    key is removed from a deep copy of ``variables`` so the caller's dictionary is not
    modified.

    :param parameter_name: name of the keyword argument to populate, e.g. ``location``
    :param variable_key_name: key of ``variables`` used as the fallback, e.g. ``region``
    :return: decorator to apply to a hook method
    """
    def _wrapper(func: Callable[..., RT]) -> Callable[..., RT]:
        """
        Decorator that provides fallback for a parameter from a key in `variables` parameter.

        :param func: function to wrap
        :return: result of the function call
        """
        @functools.wraps(func)
        def inner_wrapper(self: "DataflowHook", *args, **kwargs) -> RT:
            # Positional arguments cannot be matched to parameter names here,
            # so the wrapped method is keyword-only in practice.
            if args:
                raise AirflowException(
                    "You must use keyword arguments in this methods rather than positional")

            parameter_value = kwargs.get(parameter_name)
            # ``variables`` may be passed explicitly as None (it is Optional on some
            # methods); ``kwargs.get('variables', {})`` would then return None and the
            # chained ``.get`` would raise AttributeError.
            variables = kwargs.get('variables') or {}
            variables_value = variables.get(variable_key_name)

            if parameter_value and variables_value:
                raise AirflowException(
                    f"The mutually exclusive parameter `{parameter_name}` and `{variable_key_name}` key "
                    "in `variables` parameter are both present. Please remove one."
                )
            # Only set the parameter when a value was found, so the wrapped method's
            # own default still applies otherwise.
            if parameter_value or variables_value:
                kwargs[parameter_name] = parameter_value or variables_value
            if variables_value:
                # Work on a copy: the caller's dictionary must keep its key.
                copy_variables = deepcopy(kwargs['variables'])
                del copy_variables[variable_key_name]
                kwargs['variables'] = copy_variables

            return func(self, *args, **kwargs)
        return inner_wrapper

    return _wrapper


_fallback_to_location_from_variables = _fallback_variable_parameter('location', 'region')
_fallback_to_project_id_from_variables = _fallback_variable_parameter('project_id', 'project')
82
91
83
92
84
93
class DataflowJobStatus :
@@ -425,9 +434,9 @@ def _start_dataflow(
425
434
label_formatter : Callable [[Dict ], List [str ]],
426
435
project_id : str ,
427
436
multiple_jobs : bool = False ,
428
- on_new_job_id_callback : Optional [Callable [[str ], None ]] = None
437
+ on_new_job_id_callback : Optional [Callable [[str ], None ]] = None ,
438
+ location : str = DEFAULT_DATAFLOW_LOCATION
429
439
) -> None :
430
- variables = self ._set_variables (variables )
431
440
cmd = command_prefix + self ._build_cmd (variables , label_formatter , project_id )
432
441
runner = _DataflowRunner (
433
442
cmd = cmd ,
@@ -438,20 +447,15 @@ def _start_dataflow(
438
447
dataflow = self .get_conn (),
439
448
project_number = project_id ,
440
449
name = name ,
441
- location = variables [ 'region' ] ,
450
+ location = location ,
442
451
poll_sleep = self .poll_sleep ,
443
452
job_id = job_id ,
444
453
num_retries = self .num_retries ,
445
454
multiple_jobs = multiple_jobs
446
455
)
447
456
job_controller .wait_for_done ()
448
457
449
- @staticmethod
450
- def _set_variables (variables : Dict ) -> Dict :
451
- if 'region' not in variables .keys ():
452
- variables ['region' ] = DEFAULT_DATAFLOW_LOCATION
453
- return variables
454
-
458
+ @_fallback_to_location_from_variables
455
459
@_fallback_to_project_id_from_variables
456
460
@GoogleBaseHook .fallback_to_default_project_id
457
461
def start_java_dataflow (
@@ -463,7 +467,8 @@ def start_java_dataflow(
463
467
job_class : Optional [str ] = None ,
464
468
append_job_name : bool = True ,
465
469
multiple_jobs : bool = False ,
466
- on_new_job_id_callback : Optional [Callable [[str ], None ]] = None
470
+ on_new_job_id_callback : Optional [Callable [[str ], None ]] = None ,
471
+ location : str = DEFAULT_DATAFLOW_LOCATION
467
472
) -> None :
468
473
"""
469
474
Starts Dataflow java job.
@@ -484,9 +489,12 @@ def start_java_dataflow(
484
489
:type multiple_jobs: bool
485
490
:param on_new_job_id_callback: Callback called when the job ID is known.
486
491
:type on_new_job_id_callback: callable
492
+ :param location: Job location.
493
+ :type location: str
487
494
"""
488
495
name = self ._build_dataflow_job_name (job_name , append_job_name )
489
496
variables ['jobName' ] = name
497
+ variables ['region' ] = location
490
498
491
499
def label_formatter (labels_dict ):
492
500
return ['--labels={}' .format (
@@ -501,9 +509,11 @@ def label_formatter(labels_dict):
501
509
label_formatter = label_formatter ,
502
510
project_id = project_id ,
503
511
multiple_jobs = multiple_jobs ,
504
- on_new_job_id_callback = on_new_job_id_callback
512
+ on_new_job_id_callback = on_new_job_id_callback ,
513
+ location = location
505
514
)
506
515
516
+ @_fallback_to_location_from_variables
507
517
@_fallback_to_project_id_from_variables
508
518
@GoogleBaseHook .fallback_to_default_project_id
509
519
def start_template_dataflow (
@@ -514,7 +524,8 @@ def start_template_dataflow(
514
524
dataflow_template : str ,
515
525
project_id : str ,
516
526
append_job_name : bool = True ,
517
- on_new_job_id_callback : Optional [Callable [[str ], None ]] = None
527
+ on_new_job_id_callback : Optional [Callable [[str ], None ]] = None ,
528
+ location : str = DEFAULT_DATAFLOW_LOCATION
518
529
) -> Dict :
519
530
"""
520
531
Starts Dataflow template job.
@@ -533,8 +544,9 @@ def start_template_dataflow(
533
544
:type append_job_name: bool
534
545
:param on_new_job_id_callback: Callback called when the job ID is known.
535
546
:type on_new_job_id_callback: callable
547
+ :param location: Job location.
548
+ :type location: str
536
549
"""
537
- variables = self ._set_variables (variables )
538
550
name = self ._build_dataflow_job_name (job_name , append_job_name )
539
551
# Builds RuntimeEnvironment from variables dictionary
540
552
# https://cloud.google.com/dataflow/docs/reference/rest/v1b3/RuntimeEnvironment
@@ -550,7 +562,7 @@ def start_template_dataflow(
550
562
service = self .get_conn ()
551
563
request = service .projects ().locations ().templates ().launch ( # pylint: disable=no-member
552
564
projectId = project_id ,
553
- location = variables [ 'region' ] ,
565
+ location = location ,
554
566
gcsPath = dataflow_template ,
555
567
body = body
556
568
)
@@ -560,18 +572,18 @@ def start_template_dataflow(
560
572
if on_new_job_id_callback :
561
573
on_new_job_id_callback (job_id )
562
574
563
- variables = self ._set_variables (variables )
564
575
jobs_controller = _DataflowJobsController (
565
576
dataflow = self .get_conn (),
566
577
project_number = project_id ,
567
578
name = name ,
568
579
job_id = job_id ,
569
- location = variables [ 'region' ] ,
580
+ location = location ,
570
581
poll_sleep = self .poll_sleep ,
571
582
num_retries = self .num_retries )
572
583
jobs_controller .wait_for_done ()
573
584
return response ["job" ]
574
585
586
+ @_fallback_to_location_from_variables
575
587
@_fallback_to_project_id_from_variables
576
588
@GoogleBaseHook .fallback_to_default_project_id
577
589
def start_python_dataflow ( # pylint: disable=too-many-arguments
@@ -585,7 +597,8 @@ def start_python_dataflow( # pylint: disable=too-many-arguments
585
597
py_requirements : Optional [List [str ]] = None ,
586
598
py_system_site_packages : bool = False ,
587
599
append_job_name : bool = True ,
588
- on_new_job_id_callback : Optional [Callable [[str ], None ]] = None
600
+ on_new_job_id_callback : Optional [Callable [[str ], None ]] = None ,
601
+ location : str = DEFAULT_DATAFLOW_LOCATION
589
602
):
590
603
"""
591
604
Starts Dataflow job.
@@ -620,9 +633,12 @@ def start_python_dataflow( # pylint: disable=too-many-arguments
620
633
If set to None or missing, the default project_id from the GCP connection is used.
621
634
:param on_new_job_id_callback: Callback called when the job ID is known.
622
635
:type on_new_job_id_callback: callable
636
+ :param location: Job location.
637
+ :type location: str
623
638
"""
624
639
name = self ._build_dataflow_job_name (job_name , append_job_name )
625
640
variables ['job_name' ] = name
641
+ variables ['region' ] = location
626
642
627
643
def label_formatter (labels_dict ):
628
644
return ['--labels={}={}' .format (key , value )
@@ -644,7 +660,8 @@ def label_formatter(labels_dict):
644
660
command_prefix = command_prefix ,
645
661
label_formatter = label_formatter ,
646
662
project_id = project_id ,
647
- on_new_job_id_callback = on_new_job_id_callback
663
+ on_new_job_id_callback = on_new_job_id_callback ,
664
+ location = location
648
665
)
649
666
else :
650
667
command_prefix = [py_interpreter ] + py_options + [dataflow ]
@@ -655,7 +672,8 @@ def label_formatter(labels_dict):
655
672
command_prefix = command_prefix ,
656
673
label_formatter = label_formatter ,
657
674
project_id = project_id ,
658
- on_new_job_id_callback = on_new_job_id_callback
675
+ on_new_job_id_callback = on_new_job_id_callback ,
676
+ location = location
659
677
)
660
678
661
679
@staticmethod
@@ -700,27 +718,38 @@ def _build_cmd(variables: Dict, label_formatter: Callable, project_id: str) -> L
700
718
command .append (f"--{ attr } ={ value } " )
701
719
return command
702
720
721
@_fallback_to_location_from_variables
@_fallback_to_project_id_from_variables
@GoogleBaseHook.fallback_to_default_project_id
def is_job_dataflow_running(
    self,
    name: str,
    project_id: str,
    location: str = DEFAULT_DATAFLOW_LOCATION,
    variables: Optional[Dict] = None
) -> bool:
    """
    Helper method to check if job is still running in dataflow

    :param name: The name of the job.
    :type name: str
    :param project_id: Optional, the GCP project ID in which to start a job.
        If set to None or missing, the default project_id from the GCP connection is used.
    :type project_id: str
    :param location: Job location.
    :type location: str
    :param variables: Deprecated. Only consulted by the fallback decorators
        (``region`` / ``project`` keys); pass ``location`` and ``project_id`` instead.
    :type variables: dict
    :return: True if job is running.
    :rtype: bool
    """
    # The location fallback decorator removes the ``region`` key before this body
    # runs, so a call with ``variables={'region': ...}`` arrives here as ``{}``.
    # Testing for None (rather than truthiness) keeps the deprecation warning
    # visible for exactly that call style.
    if variables is not None:
        # NOTE(review): stacklevel=4 presumably skips the decorator wrapper frames so
        # the warning points closer to the caller - confirm against the wrapper depth.
        warnings.warn(
            "The variables parameter has been deprecated. You should pass location using "
            "the location parameter.", DeprecationWarning, stacklevel=4)
    jobs_controller = _DataflowJobsController(
        dataflow=self.get_conn(),
        project_number=project_id,
        name=name,
        location=location,
        poll_sleep=self.poll_sleep
    )
    return jobs_controller.is_job_running()
@@ -731,7 +760,7 @@ def cancel_job(
731
760
project_id : str ,
732
761
job_name : Optional [str ] = None ,
733
762
job_id : Optional [str ] = None ,
734
- location : Optional [ str ] = None ,
763
+ location : str = DEFAULT_DATAFLOW_LOCATION ,
735
764
) -> None :
736
765
"""
737
766
Cancels the job with the specified name prefix or Job ID.
@@ -753,7 +782,7 @@ def cancel_job(
753
782
project_number = project_id ,
754
783
name = job_name ,
755
784
job_id = job_id ,
756
- location = location or DEFAULT_DATAFLOW_LOCATION ,
785
+ location = location ,
757
786
poll_sleep = self .poll_sleep
758
787
)
759
788
jobs_controller .cancel ()
0 commit comments