|
24 | 24 | from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union
|
25 | 25 |
|
26 | 26 | from google.api_core.exceptions import ServerError
|
| 27 | +from google.api_core.operation import Operation |
27 | 28 | from google.api_core.retry import Retry
|
28 | 29 | from google.cloud.dataproc_v1 import (
|
| 30 | + Batch, |
| 31 | + BatchControllerClient, |
29 | 32 | Cluster,
|
30 | 33 | ClusterControllerClient,
|
31 | 34 | Job,
|
@@ -267,6 +270,34 @@ def get_job_client(
|
267 | 270 | credentials=self._get_credentials(), client_info=self.client_info, client_options=client_options
|
268 | 271 | )
|
269 | 272 |
|
def get_batch_client(
    self, region: Optional[str] = None, location: Optional[str] = None
) -> BatchControllerClient:
    """
    Build a ``BatchControllerClient`` for the requested region.

    :param region: Dataproc region. When it is set and is not ``'global'``, the client is
        pointed at ``<region>-dataproc.googleapis.com:443``; otherwise no client options
        are passed.
    :type region: Optional[str]
    :param location: Older name for ``region``. When given, a ``DeprecationWarning`` is
        emitted and its value replaces ``region``.
    :type location: Optional[str]
    :return: A client created with the hook's credentials and client info.
    """
    if location is not None:
        deprecation_message = (
            "Parameter `location` will be deprecated. "
            "Please provide value through `region` parameter instead."
        )
        # stacklevel=2 attributes the warning to the caller of this method.
        warnings.warn(deprecation_message, DeprecationWarning, stacklevel=2)
        region = location

    # Only a concrete, non-global region gets an explicit regional endpoint.
    needs_regional_endpoint = bool(region) and region != 'global'
    endpoint_options = (
        {'api_endpoint': f'{region}-dataproc.googleapis.com:443'} if needs_regional_endpoint else None
    )

    return BatchControllerClient(
        credentials=self._get_credentials(),
        client_info=self.client_info,
        client_options=endpoint_options,
    )
| 292 | + |
def wait_for_operation(self, timeout: float, operation: Operation):
    """
    Block until a long-running operation finishes and hand back its result.

    :param timeout: Value forwarded as ``timeout`` to ``operation.result`` and, on failure,
        to ``operation.exception``.
    :type timeout: float
    :param operation: The operation to wait on.
    :type operation: google.api_core.operation.Operation
    :return: Whatever ``operation.result`` returns.
    :raises AirflowException: If ``operation.result`` raises; the exception wraps the value
        returned by ``operation.exception``.
    """
    try:
        outcome = operation.result(timeout=timeout)
    except Exception:
        # Ask the operation for its own error so the Airflow exception carries it.
        failure = operation.exception(timeout=timeout)
        raise AirflowException(failure)
    return outcome
| 300 | + |
270 | 301 | @GoogleBaseHook.fallback_to_default_project_id
|
271 | 302 | def create_cluster(
|
272 | 303 | self,
|
@@ -1030,3 +1061,191 @@ def cancel_job(
|
1030 | 1061 | metadata=metadata,
|
1031 | 1062 | )
|
1032 | 1063 | return job
|
| 1064 | + |
@GoogleBaseHook.fallback_to_default_project_id
def create_batch(
    self,
    region: str,
    project_id: str,
    batch: Union[Dict, Batch],
    batch_id: Optional[str] = None,
    request_id: Optional[str] = None,
    retry: Optional[Retry] = None,
    timeout: Optional[float] = None,
    metadata: Optional[Sequence[Tuple[str, str]]] = None,
):
    """
    Creates a batch workload.

    :param project_id: Required. The ID of the Google Cloud project that the batch belongs to.
    :type project_id: str
    :param region: Required. The Cloud Dataproc region in which to handle the request.
    :type region: str
    :param batch: Required. The batch to create.
    :type batch: google.cloud.dataproc_v1.types.Batch
    :param batch_id: Optional. The ID to use for the batch, which will become the final component
        of the batch's resource name.
        This value must be 4-63 characters. Valid characters are /[a-z][0-9]-/.
    :type batch_id: str
    :param request_id: Optional. A unique id used to identify the request. If the server receives two
        ``CreateBatchRequest`` requests with the same id, then the second request will be ignored and
        the first ``google.longrunning.Operation`` created and stored in the backend is returned.
    :type request_id: str
    :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be
        retried.
    :type retry: google.api_core.retry.Retry
    :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
        ``retry`` is specified, the timeout applies to each individual attempt.
    :type timeout: float
    :param metadata: Additional metadata that is provided to the method. A falsy value
        (``None`` or an empty sequence/string) is sent as an empty tuple.
    :type metadata: Sequence[Tuple[str, str]]
    :return: The object returned by ``BatchControllerClient.create_batch``.
    """
    client = self.get_batch_client(region)
    parent = f'projects/{project_id}/regions/{region}'

    return client.create_batch(
        request={
            'parent': parent,
            'batch': batch,
            'batch_id': batch_id,
            'request_id': request_id,
        },
        retry=retry,
        timeout=timeout,
        # The client expects an iterable of pairs; never forward None or "".
        metadata=metadata or (),
    )
| 1118 | + |
@GoogleBaseHook.fallback_to_default_project_id
def delete_batch(
    self,
    batch_id: str,
    region: str,
    project_id: str,
    retry: Optional[Retry] = None,
    timeout: Optional[float] = None,
    metadata: Optional[Sequence[Tuple[str, str]]] = None,
):
    """
    Deletes the batch workload resource.

    :param batch_id: Required. The ID of the batch to delete, i.e. the final component
        of the batch's resource name.
    :type batch_id: str
    :param project_id: Required. The ID of the Google Cloud project that the batch belongs to.
    :type project_id: str
    :param region: Required. The Cloud Dataproc region in which to handle the request.
    :type region: str
    :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be
        retried.
    :type retry: google.api_core.retry.Retry
    :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
        ``retry`` is specified, the timeout applies to each individual attempt.
    :type timeout: float
    :param metadata: Additional metadata that is provided to the method. ``None`` (the
        default) is sent as an empty tuple.
    :type metadata: Sequence[Tuple[str, str]]
    :return: The object returned by ``BatchControllerClient.delete_batch``.
    """
    client = self.get_batch_client(region)
    name = f"projects/{project_id}/regions/{region}/batches/{batch_id}"

    return client.delete_batch(
        request={
            'name': name,
        },
        retry=retry,
        timeout=timeout,
        # The client expects an iterable of pairs; never forward None.
        metadata=metadata or (),
    )
| 1161 | + |
@GoogleBaseHook.fallback_to_default_project_id
def get_batch(
    self,
    batch_id: str,
    region: str,
    project_id: str,
    retry: Optional[Retry] = None,
    timeout: Optional[float] = None,
    metadata: Optional[Sequence[Tuple[str, str]]] = None,
):
    """
    Gets the batch workload resource representation.

    :param batch_id: Required. The ID of the batch to fetch, i.e. the final component
        of the batch's resource name.
    :type batch_id: str
    :param project_id: Required. The ID of the Google Cloud project that the batch belongs to.
    :type project_id: str
    :param region: Required. The Cloud Dataproc region in which to handle the request.
    :type region: str
    :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be
        retried.
    :type retry: google.api_core.retry.Retry
    :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
        ``retry`` is specified, the timeout applies to each individual attempt.
    :type timeout: float
    :param metadata: Additional metadata that is provided to the method. ``None`` (the
        default) is sent as an empty tuple.
    :type metadata: Sequence[Tuple[str, str]]
    :return: The object returned by ``BatchControllerClient.get_batch``.
    """
    client = self.get_batch_client(region)
    name = f"projects/{project_id}/regions/{region}/batches/{batch_id}"

    return client.get_batch(
        request={
            'name': name,
        },
        retry=retry,
        timeout=timeout,
        # The client expects an iterable of pairs; never forward None.
        metadata=metadata or (),
    )
| 1204 | + |
@GoogleBaseHook.fallback_to_default_project_id
def list_batches(
    self,
    region: str,
    project_id: str,
    page_size: Optional[int] = None,
    page_token: Optional[str] = None,
    retry: Optional[Retry] = None,
    timeout: Optional[float] = None,
    metadata: Optional[Sequence[Tuple[str, str]]] = None,
):
    """
    Lists batch workloads.

    :param project_id: Required. The ID of the Google Cloud project that the batches belong to.
    :type project_id: str
    :param region: Required. The Cloud Dataproc region in which to handle the request.
    :type region: str
    :param page_size: Optional. The maximum number of batches to return in each response. The service may
        return fewer than this value. The default page size is 20; the maximum page size is 1000.
    :type page_size: int
    :param page_token: Optional. A page token received from a previous ``ListBatches`` call.
        Provide this token to retrieve the subsequent page.
    :type page_token: str
    :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be
        retried.
    :type retry: google.api_core.retry.Retry
    :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
        ``retry`` is specified, the timeout applies to each individual attempt.
    :type timeout: float
    :param metadata: Additional metadata that is provided to the method. ``None`` (the
        default) is sent as an empty tuple.
    :type metadata: Sequence[Tuple[str, str]]
    :return: The object returned by ``BatchControllerClient.list_batches``.
    """
    client = self.get_batch_client(region)
    parent = f'projects/{project_id}/regions/{region}'

    return client.list_batches(
        request={
            'parent': parent,
            'page_size': page_size,
            'page_token': page_token,
        },
        retry=retry,
        timeout=timeout,
        # The client expects an iterable of pairs; never forward None.
        metadata=metadata or (),
    )
0 commit comments