Skip to content

Commit da8f133

Browse files
author
Łukasz Wyszomirski
authored
Use AsyncClient for Composer Operators in deferrable mode (#25951)
1 parent 57fc3e9 commit da8f133

File tree

4 files changed

+212
-8
lines changed

4 files changed

+212
-8
lines changed

β€Žairflow/providers/google/cloud/hooks/cloud_composer.py

Lines changed: 126 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,13 @@
2121
from google.api_core.client_options import ClientOptions
2222
from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
2323
from google.api_core.operation import Operation
24+
from google.api_core.operation_async import AsyncOperation
2425
from google.api_core.retry import Retry
25-
from google.cloud.orchestration.airflow.service_v1 import EnvironmentsClient, ImageVersionsClient
26+
from google.cloud.orchestration.airflow.service_v1 import (
27+
EnvironmentsAsyncClient,
28+
EnvironmentsClient,
29+
ImageVersionsClient,
30+
)
2631
from google.cloud.orchestration.airflow.service_v1.services.environments.pagers import ListEnvironmentsPager
2732
from google.cloud.orchestration.airflow.service_v1.services.image_versions.pagers import (
2833
ListImageVersionsPager,
@@ -275,3 +280,123 @@ def list_image_versions(
275280
metadata=metadata,
276281
)
277282
return result
283+
284+
285+
class CloudComposerAsyncHook(GoogleBaseHook):
286+
"""Hook for Google Cloud Composer async APIs."""
287+
288+
client_options = ClientOptions(api_endpoint='composer.googleapis.com:443')
289+
290+
def get_environment_client(self) -> EnvironmentsAsyncClient:
291+
"""Retrieves client library object that allow access Environments service."""
292+
return EnvironmentsAsyncClient(
293+
credentials=self.get_credentials(),
294+
client_info=CLIENT_INFO,
295+
client_options=self.client_options,
296+
)
297+
298+
def get_environment_name(self, project_id, region, environment_id):
299+
return f'projects/{project_id}/locations/{region}/environments/{environment_id}'
300+
301+
def get_parent(self, project_id, region):
302+
return f'projects/{project_id}/locations/{region}'
303+
304+
async def get_operation(self, operation_name):
305+
return await self.get_environment_client().transport.operations_client.get_operation(
306+
name=operation_name
307+
)
308+
309+
@GoogleBaseHook.fallback_to_default_project_id
310+
async def create_environment(
311+
self,
312+
project_id: str,
313+
region: str,
314+
environment: Union[Environment, Dict],
315+
retry: Union[Retry, _MethodDefault] = DEFAULT,
316+
timeout: Optional[float] = None,
317+
metadata: Sequence[Tuple[str, str]] = (),
318+
) -> AsyncOperation:
319+
"""
320+
Create a new environment.
321+
322+
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
323+
:param region: Required. The ID of the Google Cloud region that the service belongs to.
324+
:param environment: The environment to create. This corresponds to the ``environment`` field on the
325+
``request`` instance; if ``request`` is provided, this should not be set.
326+
:param retry: Designation of what errors, if any, should be retried.
327+
:param timeout: The timeout for this request.
328+
:param metadata: Strings which should be sent along with the request as metadata.
329+
"""
330+
client = self.get_environment_client()
331+
return await client.create_environment(
332+
request={'parent': self.get_parent(project_id, region), 'environment': environment},
333+
retry=retry,
334+
timeout=timeout,
335+
metadata=metadata,
336+
)
337+
338+
@GoogleBaseHook.fallback_to_default_project_id
339+
async def delete_environment(
340+
self,
341+
project_id: str,
342+
region: str,
343+
environment_id: str,
344+
retry: Union[Retry, _MethodDefault] = DEFAULT,
345+
timeout: Optional[float] = None,
346+
metadata: Sequence[Tuple[str, str]] = (),
347+
) -> AsyncOperation:
348+
"""
349+
Delete an environment.
350+
351+
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
352+
:param region: Required. The ID of the Google Cloud region that the service belongs to.
353+
:param environment_id: Required. The ID of the Google Cloud environment that the service belongs to.
354+
:param retry: Designation of what errors, if any, should be retried.
355+
:param timeout: The timeout for this request.
356+
:param metadata: Strings which should be sent along with the request as metadata.
357+
"""
358+
client = self.get_environment_client()
359+
name = self.get_environment_name(project_id, region, environment_id)
360+
return await client.delete_environment(
361+
request={"name": name}, retry=retry, timeout=timeout, metadata=metadata
362+
)
363+
364+
@GoogleBaseHook.fallback_to_default_project_id
365+
async def update_environment(
366+
self,
367+
project_id: str,
368+
region: str,
369+
environment_id: str,
370+
environment: Union[Environment, Dict],
371+
update_mask: Union[Dict, FieldMask],
372+
retry: Union[Retry, _MethodDefault] = DEFAULT,
373+
timeout: Optional[float] = None,
374+
metadata: Sequence[Tuple[str, str]] = (),
375+
) -> AsyncOperation:
376+
r"""
377+
Update an environment.
378+
379+
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
380+
:param region: Required. The ID of the Google Cloud region that the service belongs to.
381+
:param environment_id: Required. The ID of the Google Cloud environment that the service belongs to.
382+
:param environment: A patch environment. Fields specified by the ``updateMask`` will be copied from
383+
the patch environment into the environment under update.
384+
385+
This corresponds to the ``environment`` field on the ``request`` instance; if ``request`` is
386+
provided, this should not be set.
387+
:param update_mask: Required. A comma-separated list of paths, relative to ``Environment``, of fields
388+
to update. If a dict is provided, it must be of the same form as the protobuf message
389+
:class:`~google.protobuf.field_mask_pb2.FieldMask`
390+
:param retry: Designation of what errors, if any, should be retried.
391+
:param timeout: The timeout for this request.
392+
:param metadata: Strings which should be sent along with the request as metadata.
393+
"""
394+
client = self.get_environment_client()
395+
name = self.get_environment_name(project_id, region, environment_id)
396+
397+
return await client.update_environment(
398+
request={"name": name, "environment": environment, "update_mask": update_mask},
399+
retry=retry,
400+
timeout=timeout,
401+
metadata=metadata,
402+
)

β€Žairflow/providers/google/cloud/triggers/cloud_composer.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from typing import Any, Dict, Optional, Sequence, Tuple, Union
2222

2323
from airflow import AirflowException
24-
from airflow.providers.google.cloud.hooks.cloud_composer import CloudComposerHook
24+
from airflow.providers.google.cloud.hooks.cloud_composer import CloudComposerAsyncHook
2525

2626
try:
2727
from airflow.triggers.base import BaseTrigger, TriggerEvent
@@ -58,7 +58,7 @@ def __init__(
5858

5959
self.pooling_period_seconds = pooling_period_seconds
6060

61-
self.gcp_hook = CloudComposerHook(
61+
self.gcp_hook = CloudComposerAsyncHook(
6262
gcp_conn_id=self.gcp_conn_id,
6363
impersonation_chain=self.impersonation_chain,
6464
delegate_to=self.delegate_to,
@@ -80,7 +80,7 @@ def serialize(self) -> Tuple[str, Dict[str, Any]]:
8080

8181
async def run(self):
8282
while True:
83-
operation = self.gcp_hook.get_operation(operation_name=self.operation_name)
83+
operation = await self.gcp_hook.get_operation(operation_name=self.operation_name)
8484
if operation.done:
8585
break
8686
elif operation.error.message:

β€Žtests/providers/google/cloud/hooks/test_cloud_composer.py

Lines changed: 80 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,10 @@
2020
import unittest
2121
from unittest import mock
2222

23+
import pytest
2324
from google.api_core.gapic_v1.method import DEFAULT
2425

25-
from airflow.providers.google.cloud.hooks.cloud_composer import CloudComposerHook
26+
from airflow.providers.google.cloud.hooks.cloud_composer import CloudComposerAsyncHook, CloudComposerHook
2627

2728
TEST_GCP_REGION = "global"
2829
TEST_GCP_PROJECT = "test-project"
@@ -193,3 +194,81 @@ def test_list_image_versions(self, mock_client) -> None:
193194
timeout=TEST_TIMEOUT,
194195
metadata=TEST_METADATA,
195196
)
197+
198+
199+
class TestCloudComposerAsyncHook(unittest.TestCase):
200+
def setUp(
201+
self,
202+
) -> None:
203+
with mock.patch(BASE_STRING.format("GoogleBaseHook.__init__"), new=mock_init):
204+
self.hook = CloudComposerAsyncHook(gcp_conn_id="test")
205+
206+
@pytest.mark.asyncio
207+
@mock.patch(COMPOSER_STRING.format("CloudComposerAsyncHook.get_environment_client"))
208+
async def test_create_environment(self, mock_client) -> None:
209+
await self.hook.create_environment(
210+
project_id=TEST_GCP_PROJECT,
211+
region=TEST_GCP_REGION,
212+
environment=TEST_ENVIRONMENT,
213+
retry=TEST_RETRY,
214+
timeout=TEST_TIMEOUT,
215+
metadata=TEST_METADATA,
216+
)
217+
mock_client.assert_called_once()
218+
mock_client.return_value.create_environment.assert_called_once_with(
219+
request={
220+
'parent': self.hook.get_parent(TEST_GCP_PROJECT, TEST_GCP_REGION),
221+
'environment': TEST_ENVIRONMENT,
222+
},
223+
retry=TEST_RETRY,
224+
timeout=TEST_TIMEOUT,
225+
metadata=TEST_METADATA,
226+
)
227+
228+
@pytest.mark.asyncio
229+
@mock.patch(COMPOSER_STRING.format("CloudComposerAsyncHook.get_environment_client"))
230+
async def test_delete_environment(self, mock_client) -> None:
231+
await self.hook.delete_environment(
232+
project_id=TEST_GCP_PROJECT,
233+
region=TEST_GCP_REGION,
234+
environment_id=TEST_ENVIRONMENT_ID,
235+
retry=TEST_RETRY,
236+
timeout=TEST_TIMEOUT,
237+
metadata=TEST_METADATA,
238+
)
239+
mock_client.assert_called_once()
240+
mock_client.return_value.delete_environment.assert_called_once_with(
241+
request={
242+
"name": self.hook.get_environment_name(TEST_GCP_PROJECT, TEST_GCP_REGION, TEST_ENVIRONMENT_ID)
243+
},
244+
retry=TEST_RETRY,
245+
timeout=TEST_TIMEOUT,
246+
metadata=TEST_METADATA,
247+
)
248+
249+
@pytest.mark.asyncio
250+
@mock.patch(COMPOSER_STRING.format("CloudComposerAsyncHook.get_environment_client"))
251+
async def test_update_environment(self, mock_client) -> None:
252+
await self.hook.update_environment(
253+
project_id=TEST_GCP_PROJECT,
254+
region=TEST_GCP_REGION,
255+
environment_id=TEST_ENVIRONMENT_ID,
256+
environment=TEST_UPDATED_ENVIRONMENT,
257+
update_mask=TEST_UPDATE_MASK,
258+
retry=TEST_RETRY,
259+
timeout=TEST_TIMEOUT,
260+
metadata=TEST_METADATA,
261+
)
262+
mock_client.assert_called_once()
263+
mock_client.return_value.update_environment.assert_called_once_with(
264+
request={
265+
"name": self.hook.get_environment_name(
266+
TEST_GCP_PROJECT, TEST_GCP_REGION, TEST_ENVIRONMENT_ID
267+
),
268+
"environment": TEST_UPDATED_ENVIRONMENT,
269+
"update_mask": TEST_UPDATE_MASK,
270+
},
271+
retry=TEST_RETRY,
272+
timeout=TEST_TIMEOUT,
273+
metadata=TEST_METADATA,
274+
)

β€Žtests/providers/google/cloud/operators/test_cloud_composer.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def test_execute(self, mock_hook, to_dict_mode) -> None:
9494

9595
@mock.patch(COMPOSER_STRING.format("Environment.to_dict"))
9696
@mock.patch(COMPOSER_STRING.format("CloudComposerHook"))
97-
@mock.patch(COMPOSER_TRIGGERS_STRING.format("CloudComposerHook"))
97+
@mock.patch(COMPOSER_TRIGGERS_STRING.format("CloudComposerAsyncHook"))
9898
def test_execute_deferrable(self, mock_trigger_hook, mock_hook, to_dict_mode):
9999
op = CloudComposerCreateEnvironmentOperator(
100100
task_id=TASK_ID,
@@ -145,7 +145,7 @@ def test_execute(self, mock_hook) -> None:
145145
)
146146

147147
@mock.patch(COMPOSER_STRING.format("CloudComposerHook"))
148-
@mock.patch(COMPOSER_TRIGGERS_STRING.format("CloudComposerHook"))
148+
@mock.patch(COMPOSER_TRIGGERS_STRING.format("CloudComposerAsyncHook"))
149149
def test_execute_deferrable(self, mock_trigger_hook, mock_hook):
150150
op = CloudComposerDeleteEnvironmentOperator(
151151
task_id=TASK_ID,
@@ -200,7 +200,7 @@ def test_execute(self, mock_hook, to_dict_mode) -> None:
200200

201201
@mock.patch(COMPOSER_STRING.format("Environment.to_dict"))
202202
@mock.patch(COMPOSER_STRING.format("CloudComposerHook"))
203-
@mock.patch(COMPOSER_TRIGGERS_STRING.format("CloudComposerHook"))
203+
@mock.patch(COMPOSER_TRIGGERS_STRING.format("CloudComposerAsyncHook"))
204204
def test_execute_deferrable(self, mock_trigger_hook, mock_hook, to_dict_mode):
205205
op = CloudComposerUpdateEnvironmentOperator(
206206
task_id=TASK_ID,

0 commit comments

Comments
 (0)