26
26
import uuid
27
27
import warnings
28
28
from datetime import datetime , timedelta
29
- from typing import TYPE_CHECKING , Sequence
29
+ from typing import TYPE_CHECKING , Any , Sequence
30
30
31
31
from google .api_core import operation # type: ignore
32
32
from google .api_core .exceptions import AlreadyExists , NotFound
33
33
from google .api_core .gapic_v1 .method import DEFAULT , _MethodDefault
34
34
from google .api_core .retry import Retry , exponential_sleep_generator
35
- from google .cloud .dataproc_v1 import Batch , Cluster , JobStatus
35
+ from google .cloud .dataproc_v1 import Batch , Cluster , ClusterStatus , JobStatus
36
36
from google .protobuf .duration_pb2 import Duration
37
37
from google .protobuf .field_mask_pb2 import FieldMask
38
38
50
50
DataprocLink ,
51
51
DataprocListLink ,
52
52
)
53
- from airflow .providers .google .cloud .triggers .dataproc import DataprocBaseTrigger
53
+ from airflow .providers .google .cloud .triggers .dataproc import DataprocClusterTrigger , DataprocSubmitTrigger
54
54
from airflow .utils import timezone
55
55
56
56
if TYPE_CHECKING :
@@ -438,6 +438,8 @@ class DataprocCreateClusterOperator(BaseOperator):
438
438
If set as a sequence, the identities from the list must grant
439
439
Service Account Token Creator IAM role to the directly preceding identity, with first
440
440
account from the list granting this role to the originating account (templated).
441
+ :param deferrable: Run operator in the deferrable mode.
442
+ :param polling_interval_seconds: Time (seconds) to wait between calls to check the run status.
441
443
"""
442
444
443
445
template_fields : Sequence [str ] = (
@@ -470,6 +472,8 @@ def __init__(
470
472
metadata : Sequence [tuple [str , str ]] = (),
471
473
gcp_conn_id : str = "google_cloud_default" ,
472
474
impersonation_chain : str | Sequence [str ] | None = None ,
475
+ deferrable : bool = False ,
476
+ polling_interval_seconds : int = 10 ,
473
477
** kwargs ,
474
478
) -> None :
475
479
@@ -502,7 +506,8 @@ def __init__(
502
506
del kwargs [arg ]
503
507
504
508
super ().__init__ (** kwargs )
505
-
509
+ if deferrable and polling_interval_seconds <= 0 :
510
+ raise ValueError ("Invalid value for polling_interval_seconds. Expected value greater than 0" )
506
511
self .cluster_config = cluster_config
507
512
self .cluster_name = cluster_name
508
513
self .labels = labels
@@ -517,9 +522,11 @@ def __init__(
517
522
self .use_if_exists = use_if_exists
518
523
self .impersonation_chain = impersonation_chain
519
524
self .virtual_cluster_config = virtual_cluster_config
525
+ self .deferrable = deferrable
526
+ self .polling_interval_seconds = polling_interval_seconds
520
527
521
528
def _create_cluster (self , hook : DataprocHook ):
522
- operation = hook .create_cluster (
529
+ return hook .create_cluster (
523
530
project_id = self .project_id ,
524
531
region = self .region ,
525
532
cluster_name = self .cluster_name ,
@@ -531,9 +538,6 @@ def _create_cluster(self, hook: DataprocHook):
531
538
timeout = self .timeout ,
532
539
metadata = self .metadata ,
533
540
)
534
- cluster = operation .result ()
535
- self .log .info ("Cluster created." )
536
- return cluster
537
541
538
542
def _delete_cluster (self , hook ):
539
543
self .log .info ("Deleting the cluster" )
@@ -596,7 +600,25 @@ def execute(self, context: Context) -> dict:
596
600
)
597
601
try :
598
602
# First try to create a new cluster
599
- cluster = self ._create_cluster (hook )
603
+ operation = self ._create_cluster (hook )
604
+ if not self .deferrable :
605
+ cluster = hook .wait_for_operation (
606
+ timeout = self .timeout , result_retry = self .retry , operation = operation
607
+ )
608
+ self .log .info ("Cluster created." )
609
+ return Cluster .to_dict (cluster )
610
+ else :
611
+ self .defer (
612
+ trigger = DataprocClusterTrigger (
613
+ cluster_name = self .cluster_name ,
614
+ project_id = self .project_id ,
615
+ region = self .region ,
616
+ gcp_conn_id = self .gcp_conn_id ,
617
+ impersonation_chain = self .impersonation_chain ,
618
+ polling_interval_seconds = self .polling_interval_seconds ,
619
+ ),
620
+ method_name = "execute_complete" ,
621
+ )
600
622
except AlreadyExists :
601
623
if not self .use_if_exists :
602
624
raise
@@ -618,6 +640,21 @@ def execute(self, context: Context) -> dict:
618
640
619
641
return Cluster .to_dict (cluster )
620
642
643
+ def execute_complete (self , context : Context , event : dict [str , Any ]) -> Any :
644
+ """
645
+ Callback for when the trigger fires - returns immediately.
646
+ Relies on trigger to throw an exception, otherwise it assumes execution was
647
+ successful.
648
+ """
649
+ cluster_state = event ["cluster_state" ]
650
+ cluster_name = event ["cluster_name" ]
651
+
652
+ if cluster_state == ClusterStatus .State .ERROR :
653
+ raise AirflowException (f"Cluster is in ERROR state:\n { cluster_name } " )
654
+
655
+ self .log .info ("%s completed successfully." , self .task_id )
656
+ return event ["cluster" ]
657
+
621
658
622
659
class DataprocScaleClusterOperator (BaseOperator ):
623
660
"""
@@ -974,7 +1011,7 @@ def execute(self, context: Context):
974
1011
975
1012
if self .deferrable :
976
1013
self .defer (
977
- trigger = DataprocBaseTrigger (
1014
+ trigger = DataprocSubmitTrigger (
978
1015
job_id = job_id ,
979
1016
project_id = self .project_id ,
980
1017
region = self .region ,
@@ -1888,7 +1925,7 @@ def execute(self, context: Context):
1888
1925
self .job_id = new_job_id
1889
1926
if self .deferrable :
1890
1927
self .defer (
1891
- trigger = DataprocBaseTrigger (
1928
+ trigger = DataprocSubmitTrigger (
1892
1929
job_id = self .job_id ,
1893
1930
project_id = self .project_id ,
1894
1931
region = self .region ,
@@ -1964,6 +2001,8 @@ class DataprocUpdateClusterOperator(BaseOperator):
1964
2001
If set as a sequence, the identities from the list must grant
1965
2002
Service Account Token Creator IAM role to the directly preceding identity, with first
1966
2003
account from the list granting this role to the originating account (templated).
2004
+ :param deferrable: Run operator in the deferrable mode.
2005
+ :param polling_interval_seconds: Time (seconds) to wait between calls to check the run status.
1967
2006
"""
1968
2007
1969
2008
template_fields : Sequence [str ] = (
@@ -1991,9 +2030,13 @@ def __init__(
1991
2030
metadata : Sequence [tuple [str , str ]] = (),
1992
2031
gcp_conn_id : str = "google_cloud_default" ,
1993
2032
impersonation_chain : str | Sequence [str ] | None = None ,
2033
+ deferrable : bool = False ,
2034
+ polling_interval_seconds : int = 10 ,
1994
2035
** kwargs ,
1995
2036
):
1996
2037
super ().__init__ (** kwargs )
2038
+ if deferrable and polling_interval_seconds <= 0 :
2039
+ raise ValueError ("Invalid value for polling_interval_seconds. Expected value greater than 0" )
1997
2040
self .project_id = project_id
1998
2041
self .region = region
1999
2042
self .cluster_name = cluster_name
@@ -2006,6 +2049,8 @@ def __init__(
2006
2049
self .metadata = metadata
2007
2050
self .gcp_conn_id = gcp_conn_id
2008
2051
self .impersonation_chain = impersonation_chain
2052
+ self .deferrable = deferrable
2053
+ self .polling_interval_seconds = polling_interval_seconds
2009
2054
2010
2055
def execute (self , context : Context ):
2011
2056
hook = DataprocHook (gcp_conn_id = self .gcp_conn_id , impersonation_chain = self .impersonation_chain )
@@ -2026,9 +2071,36 @@ def execute(self, context: Context):
2026
2071
timeout = self .timeout ,
2027
2072
metadata = self .metadata ,
2028
2073
)
2029
- operation .result ()
2074
+
2075
+ if not self .deferrable :
2076
+ hook .wait_for_operation (timeout = self .timeout , result_retry = self .retry , operation = operation )
2077
+ else :
2078
+ self .defer (
2079
+ trigger = DataprocClusterTrigger (
2080
+ cluster_name = self .cluster_name ,
2081
+ project_id = self .project_id ,
2082
+ region = self .region ,
2083
+ gcp_conn_id = self .gcp_conn_id ,
2084
+ impersonation_chain = self .impersonation_chain ,
2085
+ polling_interval_seconds = self .polling_interval_seconds ,
2086
+ ),
2087
+ method_name = "execute_complete" ,
2088
+ )
2030
2089
self .log .info ("Updated %s cluster." , self .cluster_name )
2031
2090
2091
+ def execute_complete (self , context : Context , event : dict [str , Any ]) -> Any :
2092
+ """
2093
+ Callback for when the trigger fires - returns immediately.
2094
+ Relies on trigger to throw an exception, otherwise it assumes execution was
2095
+ successful.
2096
+ """
2097
+ cluster_state = event ["cluster_state" ]
2098
+ cluster_name = event ["cluster_name" ]
2099
+
2100
+ if cluster_state == ClusterStatus .State .ERROR :
2101
+ raise AirflowException (f"Cluster is in ERROR state:\n { cluster_name } " )
2102
+ self .log .info ("%s completed successfully." , self .task_id )
2103
+
2032
2104
2033
2105
class DataprocCreateBatchOperator (BaseOperator ):
2034
2106
"""
0 commit comments