
Commit 309788e

Refactor DataprocOperators to support google-cloud-dataproc 2.0 (#13256)
1 parent f74da50 commit 309788e

8 files changed: +157 -144 lines changed


airflow/providers/google/ADDITIONAL_INFO.md

Lines changed: 2 additions & 0 deletions
@@ -32,11 +32,13 @@ Details are covered in the UPDATING.md files for each library, but there are som
 | [``google-cloud-automl``](https://pypi.org/project/google-cloud-automl/) | ``>=0.4.0,<2.0.0`` | ``>=2.1.0,<3.0.0`` | [`UPGRADING.md`](https://github.com/googleapis/python-bigquery-automl/blob/master/UPGRADING.md) |
 | [``google-cloud-bigquery-datatransfer``](https://pypi.org/project/google-cloud-bigquery-datatransfer/) | ``>=0.4.0,<2.0.0`` | ``>=3.0.0,<4.0.0`` | [`UPGRADING.md`](https://github.com/googleapis/python-bigquery-datatransfer/blob/master/UPGRADING.md) |
 | [``google-cloud-datacatalog``](https://pypi.org/project/google-cloud-datacatalog/) | ``>=0.5.0,<0.8`` | ``>=3.0.0,<4.0.0`` | [`UPGRADING.md`](https://github.com/googleapis/python-datacatalog/blob/master/UPGRADING.md) |
+| [``google-cloud-dataproc``](https://pypi.org/project/google-cloud-dataproc/) | ``>=1.0.1,<2.0.0`` | ``>=2.2.0,<3.0.0`` | [`UPGRADING.md`](https://github.com/googleapis/python-dataproc/blob/master/UPGRADING.md) |
 | [``google-cloud-kms``](https://pypi.org/project/google-cloud-os-login/) | ``>=1.2.1,<2.0.0`` | ``>=2.0.0,<3.0.0`` | [`UPGRADING.md`](https://github.com/googleapis/python-kms/blob/master/UPGRADING.md) |
 | [``google-cloud-os-login``](https://pypi.org/project/google-cloud-os-login/) | ``>=1.0.0,<2.0.0`` | ``>=2.0.0,<3.0.0`` | [`UPGRADING.md`](https://github.com/googleapis/python-oslogin/blob/master/UPGRADING.md) |
 | [``google-cloud-pubsub``](https://pypi.org/project/google-cloud-pubsub/) | ``>=1.0.0,<2.0.0`` | ``>=2.0.0,<3.0.0`` | [`UPGRADING.md`](https://github.com/googleapis/python-pubsub/blob/master/UPGRADING.md) |
 | [``google-cloud-tasks``](https://pypi.org/project/google-cloud-tasks/) | ``>=1.2.1,<2.0.0`` | ``>=2.0.0,<3.0.0`` | [`UPGRADING.md`](https://github.com/googleapis/python-tasks/blob/master/UPGRADING.md) |
 
+
 ### The field names use the snake_case convention
 
 If your DAG uses an object from the above mentioned libraries passed by XCom, it is necessary to update the naming convention of the fields that are read. Previously, the fields used the CamelSnake convention, now the snake_case convention is used.
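For example, a task that pulls a Dataproc cluster description from XCom now reads snake_case keys instead of the camel-cased keys produced by the old serialization. A minimal sketch, assuming a hypothetical upstream task id and field path (not part of this commit):

```python
def read_worker_count(**context):
    # Cluster dict pushed to XCom by an upstream Dataproc operator
    # ("create_cluster" is an illustrative task id).
    cluster = context["ti"].xcom_pull(task_ids="create_cluster")
    # Old field naming:  cluster["config"]["workerConfig"]["numInstances"]
    # New field naming:
    return cluster["config"]["worker_config"]["num_instances"]
```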

airflow/providers/google/cloud/hooks/dataproc.py

Lines changed: 46 additions & 58 deletions
@@ -26,18 +26,16 @@
 from google.api_core.exceptions import ServerError
 from google.api_core.retry import Retry
 from google.cloud.dataproc_v1beta2 import ( # pylint: disable=no-name-in-module
-    ClusterControllerClient,
-    JobControllerClient,
-    WorkflowTemplateServiceClient,
-)
-from google.cloud.dataproc_v1beta2.types import ( # pylint: disable=no-name-in-module
     Cluster,
-    Duration,
-    FieldMask,
+    ClusterControllerClient,
     Job,
+    JobControllerClient,
     JobStatus,
     WorkflowTemplate,
+    WorkflowTemplateServiceClient,
 )
+from google.protobuf.duration_pb2 import Duration
+from google.protobuf.field_mask_pb2 import FieldMask
 
 from airflow.exceptions import AirflowException
 from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
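The `Duration` and `FieldMask` helpers used by the hook no longer come from `google.cloud.dataproc_v1beta2.types`; they are the plain protobuf well-known types. A small sketch of constructing them under the new imports (the values are illustrative only):

```python
from google.protobuf.duration_pb2 import Duration
from google.protobuf.field_mask_pb2 import FieldMask

# e.g. a 10-minute graceful decommission timeout and a mask limited to the worker count
graceful_decommission_timeout = Duration(seconds=600)
update_mask = FieldMask(paths=["config.worker_config.num_instances"])
```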
@@ -291,10 +289,12 @@ def create_cluster(
 
         client = self.get_cluster_client(location=region)
         result = client.create_cluster(
-            project_id=project_id,
-            region=region,
-            cluster=cluster,
-            request_id=request_id,
+            request={
+                'project_id': project_id,
+                'region': region,
+                'cluster': cluster,
+                'request_id': request_id,
+            },
             retry=retry,
             timeout=timeout,
             metadata=metadata,
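This hunk shows the pattern applied throughout the hook: the 2.x generated clients take a single `request` mapping (or request object) plus `retry`, `timeout`, and `metadata`, instead of flattened keyword arguments. A hedged before/after sketch against the raw client, outside Airflow (client construction and values are illustrative):

```python
from google.cloud.dataproc_v1beta2 import ClusterControllerClient

client = ClusterControllerClient()  # credentials/endpoint options omitted

# google-cloud-dataproc < 2.0:
#   client.get_cluster(project_id="my-project", region="us-central1", cluster_name="my-cluster")

# google-cloud-dataproc >= 2.0:
cluster = client.get_cluster(
    request={"project_id": "my-project", "region": "us-central1", "cluster_name": "my-cluster"}
)
```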
@@ -340,11 +340,13 @@ def delete_cluster(
         """
         client = self.get_cluster_client(location=region)
         result = client.delete_cluster(
-            project_id=project_id,
-            region=region,
-            cluster_name=cluster_name,
-            cluster_uuid=cluster_uuid,
-            request_id=request_id,
+            request={
+                'project_id': project_id,
+                'region': region,
+                'cluster_name': cluster_name,
+                'cluster_uuid': cluster_uuid,
+                'request_id': request_id,
+            },
             retry=retry,
             timeout=timeout,
             metadata=metadata,
@@ -382,9 +384,7 @@ def diagnose_cluster(
         """
         client = self.get_cluster_client(location=region)
         operation = client.diagnose_cluster(
-            project_id=project_id,
-            region=region,
-            cluster_name=cluster_name,
+            request={'project_id': project_id, 'region': region, 'cluster_name': cluster_name},
             retry=retry,
             timeout=timeout,
             metadata=metadata,
@@ -423,9 +423,7 @@ def get_cluster(
         """
         client = self.get_cluster_client(location=region)
         result = client.get_cluster(
-            project_id=project_id,
-            region=region,
-            cluster_name=cluster_name,
+            request={'project_id': project_id, 'region': region, 'cluster_name': cluster_name},
             retry=retry,
             timeout=timeout,
             metadata=metadata,
@@ -467,10 +465,7 @@ def list_clusters(
         """
         client = self.get_cluster_client(location=region)
         result = client.list_clusters(
-            project_id=project_id,
-            region=region,
-            filter_=filter_,
-            page_size=page_size,
+            request={'project_id': project_id, 'region': region, 'filter': filter_, 'page_size': page_size},
             retry=retry,
             timeout=timeout,
             metadata=metadata,
@@ -551,13 +546,15 @@ def update_cluster( # pylint: disable=too-many-arguments
         """
         client = self.get_cluster_client(location=location)
         operation = client.update_cluster(
-            project_id=project_id,
-            region=location,
-            cluster_name=cluster_name,
-            cluster=cluster,
-            update_mask=update_mask,
-            graceful_decommission_timeout=graceful_decommission_timeout,
-            request_id=request_id,
+            request={
+                'project_id': project_id,
+                'region': location,
+                'cluster_name': cluster_name,
+                'cluster': cluster,
+                'update_mask': update_mask,
+                'graceful_decommission_timeout': graceful_decommission_timeout,
+                'request_id': request_id,
+            },
             retry=retry,
             timeout=timeout,
             metadata=metadata,
@@ -593,10 +590,11 @@ def create_workflow_template(
         :param metadata: Additional metadata that is provided to the method.
         :type metadata: Sequence[Tuple[str, str]]
         """
+        metadata = metadata or ()
         client = self.get_template_client(location)
-        parent = client.region_path(project_id, location)
+        parent = f'projects/{project_id}/regions/{location}'
         return client.create_workflow_template(
-            parent=parent, template=template, retry=retry, timeout=timeout, metadata=metadata
+            request={'parent': parent, 'template': template}, retry=retry, timeout=timeout, metadata=metadata
         )
 
     @GoogleBaseHook.fallback_to_default_project_id
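The template methods also stop calling path helpers such as `region_path` and build the resource names directly. The formats used above, shown with placeholder values:

```python
project_id, location, template_name = "my-project", "us-central1", "my-template"

# Parent of region-scoped resources (used for create / instantiate-inline):
parent = f"projects/{project_id}/regions/{location}"

# Fully qualified workflow template name (used for instantiate):
name = f"projects/{project_id}/regions/{location}/workflowTemplates/{template_name}"
```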
@@ -643,13 +641,11 @@ def instantiate_workflow_template(
         :param metadata: Additional metadata that is provided to the method.
         :type metadata: Sequence[Tuple[str, str]]
         """
+        metadata = metadata or ()
         client = self.get_template_client(location)
-        name = client.workflow_template_path(project_id, location, template_name)
+        name = f'projects/{project_id}/regions/{location}/workflowTemplates/{template_name}'
         operation = client.instantiate_workflow_template(
-            name=name,
-            version=version,
-            parameters=parameters,
-            request_id=request_id,
+            request={'name': name, 'version': version, 'request_id': request_id, 'parameters': parameters},
             retry=retry,
             timeout=timeout,
             metadata=metadata,
@@ -690,12 +686,11 @@ def instantiate_inline_workflow_template(
         :param metadata: Additional metadata that is provided to the method.
         :type metadata: Sequence[Tuple[str, str]]
         """
+        metadata = metadata or ()
         client = self.get_template_client(location)
-        parent = client.region_path(project_id, location)
+        parent = f'projects/{project_id}/regions/{location}'
         operation = client.instantiate_inline_workflow_template(
-            parent=parent,
-            template=template,
-            request_id=request_id,
+            request={'parent': parent, 'template': template, 'request_id': request_id},
             retry=retry,
             timeout=timeout,
             metadata=metadata,
@@ -722,19 +717,19 @@ def wait_for_job(
         """
         state = None
         start = time.monotonic()
-        while state not in (JobStatus.ERROR, JobStatus.DONE, JobStatus.CANCELLED):
+        while state not in (JobStatus.State.ERROR, JobStatus.State.DONE, JobStatus.State.CANCELLED):
             if timeout and start + timeout < time.monotonic():
                 raise AirflowException(f"Timeout: dataproc job {job_id} is not ready after {timeout}s")
             time.sleep(wait_time)
             try:
-                job = self.get_job(location=location, job_id=job_id, project_id=project_id)
+                job = self.get_job(project_id=project_id, location=location, job_id=job_id)
                 state = job.status.state
             except ServerError as err:
                 self.log.info("Retrying. Dataproc API returned server error when waiting for job: %s", err)
 
-        if state == JobStatus.ERROR:
+        if state == JobStatus.State.ERROR:
             raise AirflowException(f'Job failed:\n{job}')
-        if state == JobStatus.CANCELLED:
+        if state == JobStatus.State.CANCELLED:
             raise AirflowException(f'Job was cancelled:\n{job}')
 
     @GoogleBaseHook.fallback_to_default_project_id
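In 2.x the job states are members of the nested `JobStatus.State` enum rather than attributes of `JobStatus` itself. A brief sketch of the terminal-state check that `wait_for_job` performs, assuming a job message returned by `get_job`:

```python
from google.cloud.dataproc_v1beta2 import JobStatus

TERMINAL_STATES = (JobStatus.State.ERROR, JobStatus.State.DONE, JobStatus.State.CANCELLED)

def is_finished(job) -> bool:
    # job.status.state is a JobStatus.State value under the new library
    return job.status.state in TERMINAL_STATES
```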
@@ -767,9 +762,7 @@ def get_job(
         """
         client = self.get_job_client(location=location)
         job = client.get_job(
-            project_id=project_id,
-            region=location,
-            job_id=job_id,
+            request={'project_id': project_id, 'region': location, 'job_id': job_id},
             retry=retry,
             timeout=timeout,
             metadata=metadata,
@@ -812,10 +805,7 @@ def submit_job(
         """
         client = self.get_job_client(location=location)
         return client.submit_job(
-            project_id=project_id,
-            region=location,
-            job=job,
-            request_id=request_id,
+            request={'project_id': project_id, 'region': location, 'job': job, 'request_id': request_id},
             retry=retry,
             timeout=timeout,
             metadata=metadata,
@@ -884,9 +874,7 @@ def cancel_job(
         client = self.get_job_client(location=location)
 
         job = client.cancel_job(
-            project_id=project_id,
-            region=location,
-            job_id=job_id,
+            request={'project_id': project_id, 'region': location, 'job_id': job_id},
             retry=retry,
             timeout=timeout,
             metadata=metadata,

airflow/providers/google/cloud/operators/dataproc.py

Lines changed: 13 additions & 17 deletions
@@ -17,7 +17,6 @@
 # under the License.
 #
 """This module contains Google Dataproc operators."""
-# pylint: disable=C0302
 
 import inspect
 import ntpath
@@ -31,12 +30,9 @@
 
 from google.api_core.exceptions import AlreadyExists, NotFound
 from google.api_core.retry import Retry, exponential_sleep_generator
-from google.cloud.dataproc_v1beta2.types import ( # pylint: disable=no-name-in-module
-    Cluster,
-    Duration,
-    FieldMask,
-)
-from google.protobuf.json_format import MessageToDict
+from google.cloud.dataproc_v1beta2 import Cluster  # pylint: disable=no-name-in-module
+from google.protobuf.duration_pb2 import Duration
+from google.protobuf.field_mask_pb2 import FieldMask
 
 from airflow.exceptions import AirflowException
 from airflow.models import BaseOperator
@@ -562,7 +558,7 @@ def _get_cluster(self, hook: DataprocHook) -> Cluster:
         )
 
     def _handle_error_state(self, hook: DataprocHook, cluster: Cluster) -> None:
-        if cluster.status.state != cluster.status.ERROR:
+        if cluster.status.state != cluster.status.State.ERROR:
             return
         self.log.info("Cluster is in ERROR state")
         gcs_uri = hook.diagnose_cluster(
@@ -590,7 +586,7 @@ def _wait_for_cluster_in_creating_state(self, hook: DataprocHook) -> Cluster:
         time_left = self.timeout
         cluster = self._get_cluster(hook)
         for time_to_sleep in exponential_sleep_generator(initial=10, maximum=120):
-            if cluster.status.state != cluster.status.CREATING:
+            if cluster.status.state != cluster.status.State.CREATING:
                 break
             if time_left < 0:
                 raise AirflowException(f"Cluster {self.cluster_name} is still CREATING state, aborting")
@@ -613,18 +609,18 @@ def execute(self, context) -> dict:
 
         # Check if cluster is not in ERROR state
         self._handle_error_state(hook, cluster)
-        if cluster.status.state == cluster.status.CREATING:
+        if cluster.status.state == cluster.status.State.CREATING:
             # Wait for cluster to be be created
             cluster = self._wait_for_cluster_in_creating_state(hook)
             self._handle_error_state(hook, cluster)
-        elif cluster.status.state == cluster.status.DELETING:
+        elif cluster.status.state == cluster.status.State.DELETING:
             # Wait for cluster to be deleted
             self._wait_for_cluster_in_deleting_state(hook)
             # Create new cluster
             cluster = self._create_cluster(hook)
             self._handle_error_state(hook, cluster)
 
-        return MessageToDict(cluster)
+        return Cluster.to_dict(cluster)
 
 
 class DataprocScaleClusterOperator(BaseOperator):
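Because the 2.x clients return proto-plus wrappers, the operator stops using `google.protobuf.json_format.MessageToDict` and serializes with the message class's own `to_dict`, which also produces the snake_case keys mentioned in ADDITIONAL_INFO.md. A minimal sketch (the cluster object is assumed to come from the hook):

```python
from google.cloud.dataproc_v1beta2 import Cluster

def cluster_to_xcom(cluster: Cluster) -> dict:
    # Old: MessageToDict(cluster) -- camel-cased keys, plain protobuf messages only
    # New: the proto-plus message is serialized by its own class
    return Cluster.to_dict(cluster)
```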
@@ -1855,7 +1851,7 @@ class DataprocSubmitJobOperator(BaseOperator):
     :type wait_timeout: int
     """
 
-    template_fields = ('project_id', 'location', 'job', 'impersonation_chain')
+    template_fields = ('project_id', 'location', 'job', 'impersonation_chain', 'request_id')
     template_fields_renderers = {"job": "json"}
 
     @apply_defaults
@@ -1941,14 +1937,14 @@ class DataprocUpdateClusterOperator(BaseOperator):
         example, to change the number of workers in a cluster to 5, the ``update_mask`` parameter would be
         specified as ``config.worker_config.num_instances``, and the ``PATCH`` request body would specify the
         new value. If a dict is provided, it must be of the same form as the protobuf message
-        :class:`~google.cloud.dataproc_v1beta2.types.FieldMask`
-    :type update_mask: Union[Dict, google.cloud.dataproc_v1beta2.types.FieldMask]
+        :class:`~google.protobuf.field_mask_pb2.FieldMask`
+    :type update_mask: Union[Dict, google.protobuf.field_mask_pb2.FieldMask]
     :param graceful_decommission_timeout: Optional. Timeout for graceful YARN decommissioning. Graceful
         decommissioning allows removing nodes from the cluster without interrupting jobs in progress. Timeout
         specifies how long to wait for jobs in progress to finish before forcefully removing nodes (and
         potentially interrupting jobs). Default timeout is 0 (for forceful decommission), and the maximum
         allowed timeout is 1 day.
-    :type graceful_decommission_timeout: Union[Dict, google.cloud.dataproc_v1beta2.types.Duration]
+    :type graceful_decommission_timeout: Union[Dict, google.protobuf.duration_pb2.Duration]
     :param request_id: Optional. A unique id used to identify the request. If the server receives two
         ``UpdateClusterRequest`` requests with the same id, then the second request will be ignored and the
         first ``google.longrunning.Operation`` created and stored in the backend is returned.
@@ -1974,7 +1970,7 @@ class DataprocUpdateClusterOperator(BaseOperator):
     :type impersonation_chain: Union[str, Sequence[str]]
     """
 
-    template_fields = ('impersonation_chain',)
+    template_fields = ('impersonation_chain', 'cluster_name')
 
     @apply_defaults
     def __init__( # pylint: disable=too-many-arguments
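With `cluster_name` added to `template_fields`, DataprocUpdateClusterOperator can take a Jinja-templated cluster name. A hypothetical usage sketch; the project, region, and sizing values are illustrative and not part of this commit:

```python
from airflow.providers.google.cloud.operators.dataproc import DataprocUpdateClusterOperator

scale_workers = DataprocUpdateClusterOperator(
    task_id="scale_workers",
    project_id="my-project",
    location="us-central1",
    cluster_name="analytics-{{ ds_nodash }}",  # now rendered by Jinja at runtime
    cluster={"config": {"worker_config": {"num_instances": 5}}},
    update_mask={"paths": ["config.worker_config.num_instances"]},
    graceful_decommission_timeout={"seconds": 600},
)
```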

airflow/providers/google/cloud/sensors/dataproc.py

Lines changed: 8 additions & 4 deletions
@@ -65,14 +65,18 @@ def poke(self, context: dict) -> bool:
         job = hook.get_job(job_id=self.dataproc_job_id, location=self.location, project_id=self.project_id)
         state = job.status.state
 
-        if state == JobStatus.ERROR:
+        if state == JobStatus.State.ERROR:
             raise AirflowException(f'Job failed:\n{job}')
-        elif state in {JobStatus.CANCELLED, JobStatus.CANCEL_PENDING, JobStatus.CANCEL_STARTED}:
+        elif state in {
+            JobStatus.State.CANCELLED,
+            JobStatus.State.CANCEL_PENDING,
+            JobStatus.State.CANCEL_STARTED,
+        }:
             raise AirflowException(f'Job was cancelled:\n{job}')
-        elif JobStatus.DONE == state:
+        elif JobStatus.State.DONE == state:
             self.log.debug("Job %s completed successfully.", self.dataproc_job_id)
             return True
-        elif JobStatus.ATTEMPT_FAILURE == state:
+        elif JobStatus.State.ATTEMPT_FAILURE == state:
             self.log.debug("Job %s attempt has failed.", self.dataproc_job_id)
 
         self.log.info("Waiting for job %s to complete.", self.dataproc_job_id)

setup.py

Lines changed: 1 addition & 1 deletion
@@ -286,7 +286,7 @@ def get_sphinx_theme_version() -> str:
     'google-cloud-bigtable>=1.0.0,<2.0.0',
     'google-cloud-container>=0.1.1,<2.0.0',
     'google-cloud-datacatalog>=3.0.0,<4.0.0',
-    'google-cloud-dataproc>=1.0.1,<2.0.0',
+    'google-cloud-dataproc>=2.2.0,<3.0.0',
     'google-cloud-dlp>=0.11.0,<2.0.0',
     'google-cloud-kms>=2.0.0,<3.0.0',
     'google-cloud-language>=1.1.1,<2.0.0',
