Commit bf68b9a

Create dataproc serverless spark batches operator (#19248)
1 parent c97a2e8 commit bf68b9a

File tree

9 files changed: +898 -3 lines changed


airflow/providers/google/cloud/example_dags/example_dataproc.py

Lines changed: 49 additions & 0 deletions
@@ -26,10 +26,14 @@
 from airflow import models
 from airflow.providers.google.cloud.operators.dataproc import (
     ClusterGenerator,
+    DataprocCreateBatchOperator,
     DataprocCreateClusterOperator,
     DataprocCreateWorkflowTemplateOperator,
+    DataprocDeleteBatchOperator,
     DataprocDeleteClusterOperator,
+    DataprocGetBatchOperator,
     DataprocInstantiateWorkflowTemplateOperator,
+    DataprocListBatchesOperator,
     DataprocSubmitJobOperator,
     DataprocUpdateClusterOperator,
 )
@@ -174,6 +178,13 @@
     },
     "jobs": [{"step_id": "pig_job_1", "pig_job": PIG_JOB["pig_job"]}],
 }
+BATCH_ID = "test-batch-id"
+BATCH_CONFIG = {
+    "spark_batch": {
+        "jar_file_uris": ["file:///usr/lib/spark/examples/jars/spark-examples.jar"],
+        "main_class": "org.apache.spark.examples.SparkPi",
+    },
+}
 
 
 with models.DAG(
@@ -282,3 +293,41 @@
 
     # Task dependency created via `XComArgs`:
     # spark_task_async >> spark_task_async_sensor
+
+with models.DAG(
+    "example_gcp_batch_dataproc",
+    schedule_interval='@once',
+    start_date=datetime(2021, 1, 1),
+    catchup=False,
+) as dag_batch:
+    # [START how_to_cloud_dataproc_create_batch_operator]
+    create_batch = DataprocCreateBatchOperator(
+        task_id="create_batch",
+        project_id=PROJECT_ID,
+        region=REGION,
+        batch=BATCH_CONFIG,
+        batch_id=BATCH_ID,
+    )
+    # [END how_to_cloud_dataproc_create_batch_operator]
+
+    # [START how_to_cloud_dataproc_get_batch_operator]
+    get_batch = DataprocGetBatchOperator(
+        task_id="get_batch", project_id=PROJECT_ID, region=REGION, batch_id=BATCH_ID
+    )
+    # [END how_to_cloud_dataproc_get_batch_operator]
+
+    # [START how_to_cloud_dataproc_list_batches_operator]
+    list_batches = DataprocListBatchesOperator(
+        task_id="list_batches",
+        project_id=PROJECT_ID,
+        region=REGION,
+    )
+    # [END how_to_cloud_dataproc_list_batches_operator]
+
+    # [START how_to_cloud_dataproc_delete_batch_operator]
+    delete_batch = DataprocDeleteBatchOperator(
+        task_id="delete_batch", project_id=PROJECT_ID, region=REGION, batch_id=BATCH_ID
+    )
+    # [END how_to_cloud_dataproc_delete_batch_operator]
+
+    create_batch >> get_batch >> list_batches >> delete_batch
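
The example DAG above submits a Spark jar workload through the `spark_batch` field of `BATCH_CONFIG`. As a rough sketch that is not part of this commit, the same `batch` argument can carry other workload types defined on the Dataproc `Batch` message; a hypothetical PySpark variant (the GCS URI below is a placeholder) might look like:

PYSPARK_BATCH_CONFIG = {
    "pyspark_batch": {
        # Hypothetical script location; replace with a readable GCS or local URI.
        "main_python_file_uri": "gs://my-bucket/scripts/spark_pi.py",
    },
}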

airflow/providers/google/cloud/hooks/dataproc.py

Lines changed: 219 additions & 0 deletions
@@ -24,8 +24,11 @@
 from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union
 
 from google.api_core.exceptions import ServerError
+from google.api_core.operation import Operation
 from google.api_core.retry import Retry
 from google.cloud.dataproc_v1 import (
+    Batch,
+    BatchControllerClient,
     Cluster,
     ClusterControllerClient,
     Job,
@@ -267,6 +270,34 @@ def get_job_client(
             credentials=self._get_credentials(), client_info=self.client_info, client_options=client_options
         )
 
+    def get_batch_client(
+        self, region: Optional[str] = None, location: Optional[str] = None
+    ) -> BatchControllerClient:
+        """Returns BatchControllerClient"""
+        if location is not None:
+            warnings.warn(
+                "Parameter `location` will be deprecated. "
+                "Please provide value through `region` parameter instead.",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+            region = location
+        client_options = None
+        if region and region != 'global':
+            client_options = {'api_endpoint': f'{region}-dataproc.googleapis.com:443'}
+
+        return BatchControllerClient(
+            credentials=self._get_credentials(), client_info=self.client_info, client_options=client_options
+        )
+
+    def wait_for_operation(self, timeout: float, operation: Operation):
+        """Waits for long-lasting operation to complete."""
+        try:
+            return operation.result(timeout=timeout)
+        except Exception:
+            error = operation.exception(timeout=timeout)
+            raise AirflowException(error)
+
     @GoogleBaseHook.fallback_to_default_project_id
     def create_cluster(
         self,
@@ -1030,3 +1061,191 @@ def cancel_job(
             metadata=metadata,
         )
         return job
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    def create_batch(
+        self,
+        region: str,
+        project_id: str,
+        batch: Union[Dict, Batch],
+        batch_id: Optional[str] = None,
+        request_id: Optional[str] = None,
+        retry: Optional[Retry] = None,
+        timeout: Optional[float] = None,
+        metadata: Optional[Sequence[Tuple[str, str]]] = "",
+    ):
+        """
+        Creates a batch workload.
+
+        :param project_id: Required. The ID of the Google Cloud project that the cluster belongs to.
+        :type project_id: str
+        :param region: Required. The Cloud Dataproc region in which to handle the request.
+        :type region: str
+        :param batch: Required. The batch to create.
+        :type batch: google.cloud.dataproc_v1.types.Batch
+        :param batch_id: Optional. The ID to use for the batch, which will become the final component
+            of the batch's resource name.
+            This value must be 4-63 characters. Valid characters are /[a-z][0-9]-/.
+        :type batch_id: str
+        :param request_id: Optional. A unique id used to identify the request. If the server receives two
+            ``CreateBatchRequest`` requests with the same id, then the second request will be ignored and
+            the first ``google.longrunning.Operation`` created and stored in the backend is returned.
+        :type request_id: str
+        :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be
+            retried.
+        :type retry: google.api_core.retry.Retry
+        :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+            ``retry`` is specified, the timeout applies to each individual attempt.
+        :type timeout: float
+        :param metadata: Additional metadata that is provided to the method.
+        :type metadata: Sequence[Tuple[str, str]]
+        """
+        client = self.get_batch_client(region)
+        parent = f'projects/{project_id}/regions/{region}'
+
+        result = client.create_batch(
+            request={
+                'parent': parent,
+                'batch': batch,
+                'batch_id': batch_id,
+                'request_id': request_id,
+            },
+            retry=retry,
+            timeout=timeout,
+            metadata=metadata,
+        )
+        return result
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    def delete_batch(
+        self,
+        batch_id: str,
+        region: str,
+        project_id: str,
+        retry: Optional[Retry] = None,
+        timeout: Optional[float] = None,
+        metadata: Optional[Sequence[Tuple[str, str]]] = None,
+    ):
+        """
+        Deletes the batch workload resource.
+
+        :param batch_id: Required. The ID to use for the batch, which will become the final component
+            of the batch's resource name.
+            This value must be 4-63 characters. Valid characters are /[a-z][0-9]-/.
+        :type batch_id: str
+        :param project_id: Required. The ID of the Google Cloud project that the cluster belongs to.
+        :type project_id: str
+        :param region: Required. The Cloud Dataproc region in which to handle the request.
+        :type region: str
+        :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be
+            retried.
+        :type retry: google.api_core.retry.Retry
+        :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+            ``retry`` is specified, the timeout applies to each individual attempt.
+        :type timeout: float
+        :param metadata: Additional metadata that is provided to the method.
+        :type metadata: Sequence[Tuple[str, str]]
+        """
+        client = self.get_batch_client(region)
+        name = f"projects/{project_id}/regions/{region}/batches/{batch_id}"
+
+        result = client.delete_batch(
+            request={
+                'name': name,
+            },
+            retry=retry,
+            timeout=timeout,
+            metadata=metadata,
+        )
+        return result
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    def get_batch(
+        self,
+        batch_id: str,
+        region: str,
+        project_id: str,
+        retry: Optional[Retry] = None,
+        timeout: Optional[float] = None,
+        metadata: Optional[Sequence[Tuple[str, str]]] = None,
+    ):
+        """
+        Gets the batch workload resource representation.
+
+        :param batch_id: Required. The ID to use for the batch, which will become the final component
+            of the batch's resource name.
+            This value must be 4-63 characters. Valid characters are /[a-z][0-9]-/.
+        :type batch_id: str
+        :param project_id: Required. The ID of the Google Cloud project that the cluster belongs to.
+        :type project_id: str
+        :param region: Required. The Cloud Dataproc region in which to handle the request.
+        :type region: str
+        :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be
+            retried.
+        :type retry: google.api_core.retry.Retry
+        :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+            ``retry`` is specified, the timeout applies to each individual attempt.
+        :type timeout: float
+        :param metadata: Additional metadata that is provided to the method.
+        :type metadata: Sequence[Tuple[str, str]]
+        """
+        client = self.get_batch_client(region)
+        name = f"projects/{project_id}/regions/{region}/batches/{batch_id}"
+
+        result = client.get_batch(
+            request={
+                'name': name,
+            },
+            retry=retry,
+            timeout=timeout,
+            metadata=metadata,
+        )
+        return result
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    def list_batches(
+        self,
+        region: str,
+        project_id: str,
+        page_size: Optional[int] = None,
+        page_token: Optional[str] = None,
+        retry: Optional[Retry] = None,
+        timeout: Optional[float] = None,
+        metadata: Optional[Sequence[Tuple[str, str]]] = None,
+    ):
+        """
+        Lists batch workloads.
+
+        :param project_id: Required. The ID of the Google Cloud project that the cluster belongs to.
+        :type project_id: str
+        :param region: Required. The Cloud Dataproc region in which to handle the request.
+        :type region: str
+        :param page_size: Optional. The maximum number of batches to return in each response. The service may
+            return fewer than this value. The default page size is 20; the maximum page size is 1000.
+        :type page_size: int
+        :param page_token: Optional. A page token received from a previous ``ListBatches`` call.
+            Provide this token to retrieve the subsequent page.
+        :type page_token: str
+        :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be
+            retried.
+        :type retry: google.api_core.retry.Retry
+        :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+            ``retry`` is specified, the timeout applies to each individual attempt.
+        :type timeout: float
+        :param metadata: Additional metadata that is provided to the method.
+        :type metadata: Sequence[Tuple[str, str]]
+        """
+        client = self.get_batch_client(region)
+        parent = f'projects/{project_id}/regions/{region}'
+
+        result = client.list_batches(
+            request={
+                'parent': parent,
+                'page_size': page_size,
+                'page_token': page_token,
+            },
+            retry=retry,
+            timeout=timeout,
+            metadata=metadata,
+        )
+        return result
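
For orientation, here is a minimal sketch, not part of this commit, of calling the new hook methods directly rather than through the operators. It assumes the default `google_cloud_default` connection and a hypothetical project and region; `create_batch` returns the long-running operation, and `wait_for_operation` resolves it to the finished `Batch`:

from airflow.providers.google.cloud.hooks.dataproc import DataprocHook

hook = DataprocHook(gcp_conn_id="google_cloud_default")
# Submit a serverless Spark batch; the dict mirrors BATCH_CONFIG from the example DAG.
operation = hook.create_batch(
    region="europe-west1",  # hypothetical region
    project_id="my-project",  # hypothetical project id
    batch={
        "spark_batch": {
            "jar_file_uris": ["file:///usr/lib/spark/examples/jars/spark-examples.jar"],
            "main_class": "org.apache.spark.examples.SparkPi",
        },
    },
    batch_id="test-batch-id",
)
# Block until the batch reaches a terminal state, then inspect it.
batch = hook.wait_for_operation(timeout=600, operation=operation)
print(batch.state)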
