Skip to content

Commit c6a014a

Browse files
authored
Add CloudBatchHook and operators (#32606)
1 parent 2c0fa0c commit c6a014a

File tree

11 files changed

+1694
-0
lines changed

11 files changed

+1694
-0
lines changed
Lines changed: 204 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,204 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one
3+
# or more contributor license agreements. See the NOTICE file
4+
# distributed with this work for additional information
5+
# regarding copyright ownership. The ASF licenses this file
6+
# to you under the Apache License, Version 2.0 (the
7+
# "License"); you may not use this file except in compliance
8+
# with the License. You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing,
13+
# software distributed under the License is distributed on an
14+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
# KIND, either express or implied. See the License for the
16+
# specific language governing permissions and limitations
17+
# under the License.
18+
from __future__ import annotations
19+
20+
import itertools
21+
import json
22+
from time import sleep
23+
from typing import Iterable, Sequence
24+
25+
from google.api_core import operation # type: ignore
26+
from google.cloud.batch import ListJobsRequest, ListTasksRequest
27+
from google.cloud.batch_v1 import (
28+
BatchServiceAsyncClient,
29+
BatchServiceClient,
30+
CreateJobRequest,
31+
Job,
32+
JobStatus,
33+
Task,
34+
)
35+
from google.cloud.batch_v1.services.batch_service import pagers
36+
37+
from airflow.exceptions import AirflowException
38+
from airflow.providers.google.common.consts import CLIENT_INFO
39+
from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID, GoogleBaseHook
40+
41+
42+
class CloudBatchHook(GoogleBaseHook):
43+
"""
44+
Hook for the Google Cloud Batch service.
45+
46+
:param gcp_conn_id: The connection ID to use when fetching connection info.
47+
:param impersonation_chain: Optional service account to impersonate using short-term
48+
credentials, or chained list of accounts required to get the access_token
49+
of the last account in the list, which will be impersonated in the request.
50+
If set as a string, the account must grant the originating account
51+
the Service Account Token Creator IAM role.
52+
If set as a sequence, the identities from the list must grant
53+
Service Account Token Creator IAM role to the directly preceding identity, with first
54+
account from the list granting this role to the originating account.
55+
"""
56+
57+
def __init__(
58+
self,
59+
gcp_conn_id: str = "google_cloud_default",
60+
impersonation_chain: str | Sequence[str] | None = None,
61+
) -> None:
62+
super().__init__(gcp_conn_id=gcp_conn_id, impersonation_chain=impersonation_chain)
63+
self._client: BatchServiceClient | None = None
64+
65+
def get_conn(self):
66+
"""
67+
Retrieves connection to GCE Batch.
68+
69+
:return: Google Batch Service client object.
70+
"""
71+
if self._client is None:
72+
self._client = BatchServiceClient(credentials=self.get_credentials(), client_info=CLIENT_INFO)
73+
return self._client
74+
75+
@GoogleBaseHook.fallback_to_default_project_id
76+
def submit_batch_job(
77+
self, job_name: str, job: Job, region: str, project_id: str = PROVIDE_PROJECT_ID
78+
) -> Job:
79+
if isinstance(job, dict):
80+
job = Job.from_json(json.dumps(job))
81+
82+
create_request = CreateJobRequest()
83+
create_request.job = job
84+
create_request.job_id = job_name
85+
create_request.parent = f"projects/{project_id}/locations/{region}"
86+
87+
return self.get_conn().create_job(create_request)
88+
89+
@GoogleBaseHook.fallback_to_default_project_id
90+
def delete_job(
91+
self, job_name: str, region: str, project_id: str = PROVIDE_PROJECT_ID
92+
) -> operation.Operation:
93+
return self.get_conn().delete_job(name=f"projects/{project_id}/locations/{region}/jobs/{job_name}")
94+
95+
@GoogleBaseHook.fallback_to_default_project_id
96+
def list_jobs(
97+
self,
98+
region: str,
99+
project_id: str = PROVIDE_PROJECT_ID,
100+
filter: str | None = None,
101+
limit: int | None = None,
102+
) -> Iterable[Job]:
103+
104+
if limit is not None and limit < 0:
105+
raise AirflowException("The limit for the list jobs request should be greater or equal to zero")
106+
107+
list_jobs_request: ListJobsRequest = ListJobsRequest(
108+
parent=f"projects/{project_id}/locations/{region}", filter=filter
109+
)
110+
111+
jobs: pagers.ListJobsPager = self.get_conn().list_jobs(request=list_jobs_request)
112+
113+
return list(itertools.islice(jobs, limit))
114+
115+
@GoogleBaseHook.fallback_to_default_project_id
116+
def list_tasks(
117+
self,
118+
region: str,
119+
job_name: str,
120+
project_id: str = PROVIDE_PROJECT_ID,
121+
group_name: str = "group0",
122+
filter: str | None = None,
123+
limit: int | None = None,
124+
) -> Iterable[Task]:
125+
126+
if limit is not None and limit < 0:
127+
raise AirflowException("The limit for the list tasks request should be greater or equal to zero")
128+
129+
list_tasks_request: ListTasksRequest = ListTasksRequest(
130+
parent=f"projects/{project_id}/locations/{region}/jobs/{job_name}/taskGroups/{group_name}",
131+
filter=filter,
132+
)
133+
134+
tasks: pagers.ListTasksPager = self.get_conn().list_tasks(request=list_tasks_request)
135+
136+
return list(itertools.islice(tasks, limit))
137+
138+
def wait_for_job(
139+
self, job_name: str, polling_period_seconds: float = 10, timeout: float | None = None
140+
) -> Job:
141+
client = self.get_conn()
142+
while timeout is None or timeout > 0:
143+
try:
144+
job = client.get_job(name=f"{job_name}")
145+
status: JobStatus.State = job.status.state
146+
if (
147+
status == JobStatus.State.SUCCEEDED
148+
or status == JobStatus.State.FAILED
149+
or status == JobStatus.State.DELETION_IN_PROGRESS
150+
):
151+
return job
152+
else:
153+
sleep(polling_period_seconds)
154+
except Exception as e:
155+
self.log.exception("Exception occurred while checking for job completion.")
156+
raise e
157+
158+
if timeout is not None:
159+
timeout -= polling_period_seconds
160+
161+
raise AirflowException(f"Job with name [{job_name}] timed out")
162+
163+
def get_job(self, job_name) -> Job:
164+
return self.get_conn().get_job(name=job_name)
165+
166+
167+
class CloudBatchAsyncHook(GoogleBaseHook):
168+
"""
169+
Async hook for the Google Cloud Batch service.
170+
171+
:param gcp_conn_id: The connection ID to use when fetching connection info.
172+
:param impersonation_chain: Optional service account to impersonate using short-term
173+
credentials, or chained list of accounts required to get the access_token
174+
of the last account in the list, which will be impersonated in the request.
175+
If set as a string, the account must grant the originating account
176+
the Service Account Token Creator IAM role.
177+
If set as a sequence, the identities from the list must grant
178+
Service Account Token Creator IAM role to the directly preceding identity, with first
179+
account from the list granting this role to the originating account.
180+
"""
181+
182+
def __init__(
183+
self,
184+
gcp_conn_id: str = "google_cloud_default",
185+
impersonation_chain: str | Sequence[str] | None = None,
186+
):
187+
188+
self._client: BatchServiceAsyncClient | None = None
189+
super().__init__(gcp_conn_id=gcp_conn_id, impersonation_chain=impersonation_chain)
190+
191+
def get_conn(self):
192+
if self._client is None:
193+
self._client = BatchServiceAsyncClient(
194+
credentials=self.get_credentials(), client_info=CLIENT_INFO
195+
)
196+
197+
return self._client
198+
199+
async def get_batch_job(
200+
self,
201+
job_name: str,
202+
) -> Job:
203+
client = self.get_conn()
204+
return await client.get_job(name=f"{job_name}")

0 commit comments

Comments
 (0)