Skip to content

Commit b118916

Browse files
authored
[AIRFLOW-7069] Fix cloudsql system tests (#7770)
1 parent c24f841 commit b118916

File tree

9 files changed

+383
-360
lines changed

9 files changed

+383
-360
lines changed

airflow/providers/google/cloud/example_dags/example_cloud_sql.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@
202202
sql_instance_read_replica_create = CloudSQLCreateInstanceOperator(
203203
project_id=GCP_PROJECT_ID,
204204
body=read_replica_body,
205-
instance=INSTANCE_NAME2,
205+
instance=READ_REPLICA_NAME,
206206
task_id='sql_instance_read_replica_create'
207207
)
208208

@@ -217,13 +217,14 @@
217217
instance=INSTANCE_NAME,
218218
task_id='sql_instance_patch_task'
219219
)
220+
# [END howto_operator_cloudsql_patch]
220221

221222
sql_instance_patch_task2 = CloudSQLInstancePatchOperator(
223+
project_id=GCP_PROJECT_ID,
222224
body=patch_body,
223225
instance=INSTANCE_NAME,
224226
task_id='sql_instance_patch_task2'
225227
)
226-
# [END howto_operator_cloudsql_patch]
227228

228229
# [START howto_operator_cloudsql_db_create]
229230
sql_db_create_task = CloudSQLCreateInstanceDatabaseOperator(

airflow/providers/google/cloud/hooks/base.py

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
"""
2020
This module contains a Google Cloud API base hook.
2121
"""
22-
2322
import functools
2423
import json
2524
import logging
@@ -39,6 +38,7 @@
3938
from google.api_core.gapic_v1.client_info import ClientInfo
4039
from google.auth import _cloud_sdk
4140
from google.auth.environment_vars import CREDENTIALS
41+
from googleapiclient.errors import HttpError
4242
from googleapiclient.http import set_user_agent
4343

4444
from airflow import version
@@ -92,13 +92,32 @@ def is_soft_quota_exception(exception: Exception):
9292
return False
9393

9494

95+
def is_operation_in_progress_exception(exception: Exception):
96+
"""
97+
Some of the calls return 429 (too many requests!) or 409 errors (Conflict)
98+
in case of operation in progress.
99+
100+
* Google Cloud SQL
101+
"""
102+
if isinstance(exception, HttpError):
103+
return exception.resp.status == 429 or exception.resp.status == 409
104+
return False
105+
106+
95107
class retry_if_temporary_quota(tenacity.retry_if_exception): # pylint: disable=invalid-name
96108
"""Retries if there was an exception for exceeding the temporary quota limit."""
97109

98110
def __init__(self):
99111
super().__init__(is_soft_quota_exception)
100112

101113

114+
class retry_if_operation_in_progress(tenacity.retry_if_exception): # pylint: disable=invalid-name
115+
"""Retries if there was an exception caused by another operation still being in progress."""
116+
117+
def __init__(self):
118+
super().__init__(is_operation_in_progress_exception)
119+
120+
102121
RT = TypeVar('RT') # pylint: disable=invalid-name
103122

104123

@@ -295,7 +314,7 @@ def scopes(self) -> Sequence[str]:
295314
@staticmethod
296315
def quota_retry(*args, **kwargs) -> Callable:
297316
"""
298-
A decorator who provides a mechanism to repeat requests in response to exceeding a temporary quote
317+
A decorator that provides a mechanism to repeat requests in response to exceeding a temporary quota
299318
limit.
300319
"""
301320
def decorator(fun: Callable):
@@ -311,6 +330,26 @@ def decorator(fun: Callable):
311330
)(fun)
312331
return decorator
313332

333+
@staticmethod
334+
def operation_in_progress_retry(*args, **kwargs) -> Callable:
335+
"""
336+
A decorator that provides a mechanism to repeat requests in response to
337+
another operation being in progress on the same resource
338+
(HTTP 409 Conflict or HTTP 429 Too Many Requests).
339+
"""
340+
def decorator(fun: Callable):
341+
default_kwargs = {
342+
'wait': tenacity.wait_exponential(multiplier=1, max=300),
343+
'retry': retry_if_operation_in_progress(),
344+
'before': tenacity.before_log(log, logging.DEBUG),
345+
'after': tenacity.after_log(log, logging.DEBUG),
346+
}
347+
default_kwargs.update(**kwargs)
348+
return tenacity.retry(
349+
*args, **default_kwargs
350+
)(fun)
351+
return decorator
352+
314353
@staticmethod
315354
def fallback_to_default_project_id(func: Callable[..., RT]) -> Callable[..., RT]:
316355
"""

airflow/providers/google/cloud/hooks/cloud_sql.py

Lines changed: 33 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@
5656
UNIX_PATH_MAX = 108
5757

5858
# Time to sleep between active checks of the operation results
59-
TIME_TO_SLEEP_IN_SECONDS = 1
59+
TIME_TO_SLEEP_IN_SECONDS = 20
6060

6161

6262
class CloudSqlOperationStatus:
@@ -113,14 +113,13 @@ def get_instance(self, instance: str, project_id: Optional[str] = None) -> Dict:
113113
:return: A Cloud SQL instance resource.
114114
:rtype: dict
115115
"""
116-
if not project_id:
117-
raise ValueError("The project_id should be set")
118-
return self.get_conn().instances().get( # pylint: disable=no-member
116+
return self.get_conn().instances().get( # noqa # pylint: disable=no-member
119117
project=project_id,
120118
instance=instance
121119
).execute(num_retries=self.num_retries)
122120

123121
@CloudBaseHook.fallback_to_default_project_id
122+
@CloudBaseHook.operation_in_progress_retry()
124123
def create_instance(self, body: Dict, project_id: Optional[str] = None) -> None:
125124
"""
126125
Creates a new Cloud SQL instance.
@@ -133,17 +132,16 @@ def create_instance(self, body: Dict, project_id: Optional[str] = None) -> None:
133132
:type project_id: str
134133
:return: None
135134
"""
136-
if not project_id:
137-
raise ValueError("The project_id should be set")
138-
response = self.get_conn().instances().insert( # pylint: disable=no-member
135+
response = self.get_conn().instances().insert( # noqa # pylint: disable=no-member
139136
project=project_id,
140137
body=body
141138
).execute(num_retries=self.num_retries)
142139
operation_name = response["name"]
143-
self._wait_for_operation_to_complete(project_id=project_id,
140+
self._wait_for_operation_to_complete(project_id=project_id, # type:ignore
144141
operation_name=operation_name)
145142

146143
@CloudBaseHook.fallback_to_default_project_id
144+
@CloudBaseHook.operation_in_progress_retry()
147145
def patch_instance(self, body: Dict, instance: str, project_id: Optional[str] = None) -> None:
148146
"""
149147
Updates settings of a Cloud SQL instance.
@@ -161,18 +159,17 @@ def patch_instance(self, body: Dict, instance: str, project_id: Optional[str] =
161159
:type project_id: str
162160
:return: None
163161
"""
164-
if not project_id:
165-
raise ValueError("The project_id should be set")
166-
response = self.get_conn().instances().patch( # pylint: disable=no-member
162+
response = self.get_conn().instances().patch( # noqa # pylint: disable=no-member
167163
project=project_id,
168164
instance=instance,
169165
body=body
170166
).execute(num_retries=self.num_retries)
171167
operation_name = response["name"]
172-
self._wait_for_operation_to_complete(project_id=project_id,
168+
self._wait_for_operation_to_complete(project_id=project_id, # type:ignore
173169
operation_name=operation_name)
174170

175171
@CloudBaseHook.fallback_to_default_project_id
172+
@CloudBaseHook.operation_in_progress_retry()
176173
def delete_instance(self, instance: str, project_id: Optional[str] = None) -> None:
177174
"""
178175
Deletes a Cloud SQL instance.
@@ -184,14 +181,12 @@ def delete_instance(self, instance: str, project_id: Optional[str] = None) -> No
184181
:type instance: str
185182
:return: None
186183
"""
187-
if not project_id:
188-
raise ValueError("The project_id should be set")
189-
response = self.get_conn().instances().delete( # pylint: disable=no-member
184+
response = self.get_conn().instances().delete( # noqa # pylint: disable=no-member
190185
project=project_id,
191186
instance=instance,
192187
).execute(num_retries=self.num_retries)
193188
operation_name = response["name"]
194-
self._wait_for_operation_to_complete(project_id=project_id,
189+
self._wait_for_operation_to_complete(project_id=project_id, # type:ignore
195190
operation_name=operation_name)
196191

197192
@CloudBaseHook.fallback_to_default_project_id
@@ -210,15 +205,14 @@ def get_database(self, instance: str, database: str, project_id: Optional[str] =
210205
https://cloud.google.com/sql/docs/mysql/admin-api/v1beta4/databases#resource.
211206
:rtype: dict
212207
"""
213-
if not project_id:
214-
raise ValueError("The project_id should be set")
215-
return self.get_conn().databases().get( # pylint: disable=no-member
208+
return self.get_conn().databases().get( # noqa # pylint: disable=no-member
216209
project=project_id,
217210
instance=instance,
218211
database=database
219212
).execute(num_retries=self.num_retries)
220213

221214
@CloudBaseHook.fallback_to_default_project_id
215+
@CloudBaseHook.operation_in_progress_retry()
222216
def create_database(self, instance: str, body: Dict, project_id: Optional[str] = None) -> None:
223217
"""
224218
Creates a new database inside a Cloud SQL instance.
@@ -233,18 +227,17 @@ def create_database(self, instance: str, body: Dict, project_id: Optional[str] =
233227
:type project_id: str
234228
:return: None
235229
"""
236-
if not project_id:
237-
raise ValueError("The project_id should be set")
238-
response = self.get_conn().databases().insert( # pylint: disable=no-member
230+
response = self.get_conn().databases().insert( # noqa # pylint: disable=no-member
239231
project=project_id,
240232
instance=instance,
241233
body=body
242234
).execute(num_retries=self.num_retries)
243235
operation_name = response["name"]
244-
self._wait_for_operation_to_complete(project_id=project_id,
236+
self._wait_for_operation_to_complete(project_id=project_id, # type:ignore
245237
operation_name=operation_name)
246238

247239
@CloudBaseHook.fallback_to_default_project_id
240+
@CloudBaseHook.operation_in_progress_retry()
248241
def patch_database(
249242
self,
250243
instance: str,
@@ -270,19 +263,18 @@ def patch_database(
270263
:type project_id: str
271264
:return: None
272265
"""
273-
if not project_id:
274-
raise ValueError("The project_id should be set")
275-
response = self.get_conn().databases().patch( # pylint: disable=no-member
266+
response = self.get_conn().databases().patch( # noqa # pylint: disable=no-member
276267
project=project_id,
277268
instance=instance,
278269
database=database,
279270
body=body
280271
).execute(num_retries=self.num_retries)
281272
operation_name = response["name"]
282-
self._wait_for_operation_to_complete(project_id=project_id,
273+
self._wait_for_operation_to_complete(project_id=project_id, # type:ignore
283274
operation_name=operation_name)
284275

285276
@CloudBaseHook.fallback_to_default_project_id
277+
@CloudBaseHook.operation_in_progress_retry()
286278
def delete_database(self, instance: str, database: str, project_id: Optional[str] = None) -> None:
287279
"""
288280
Deletes a database from a Cloud SQL instance.
@@ -296,18 +288,17 @@ def delete_database(self, instance: str, database: str, project_id: Optional[str
296288
:type project_id: str
297289
:return: None
298290
"""
299-
if not project_id:
300-
raise ValueError("The project_id should be set")
301-
response = self.get_conn().databases().delete( # pylint: disable=no-member
291+
response = self.get_conn().databases().delete( # noqa # pylint: disable=no-member
302292
project=project_id,
303293
instance=instance,
304294
database=database
305295
).execute(num_retries=self.num_retries)
306296
operation_name = response["name"]
307-
self._wait_for_operation_to_complete(project_id=project_id,
297+
self._wait_for_operation_to_complete(project_id=project_id, # type:ignore
308298
operation_name=operation_name)
309299

310300
@CloudBaseHook.fallback_to_default_project_id
301+
@CloudBaseHook.operation_in_progress_retry()
311302
def export_instance(self, instance: str, body: Dict, project_id: Optional[str] = None) -> None:
312303
"""
313304
Exports data from a Cloud SQL instance to a Cloud Storage bucket as a SQL dump
@@ -324,21 +315,14 @@ def export_instance(self, instance: str, body: Dict, project_id: Optional[str] =
324315
:type project_id: str
325316
:return: None
326317
"""
327-
if not project_id:
328-
raise ValueError("The project_id should be set")
329-
try:
330-
response = self.get_conn().instances().export( # pylint: disable=no-member
331-
project=project_id,
332-
instance=instance,
333-
body=body
334-
).execute(num_retries=self.num_retries)
335-
operation_name = response["name"]
336-
self._wait_for_operation_to_complete(project_id=project_id,
337-
operation_name=operation_name)
338-
except HttpError as ex:
339-
raise AirflowException(
340-
'Exporting instance {} failed: {}'.format(instance, ex.content)
341-
)
318+
response = self.get_conn().instances().export( # noqa # pylint: disable=no-member
319+
project=project_id,
320+
instance=instance,
321+
body=body
322+
).execute(num_retries=self.num_retries)
323+
operation_name = response["name"]
324+
self._wait_for_operation_to_complete(project_id=project_id, # type:ignore
325+
operation_name=operation_name)
342326

343327
@CloudBaseHook.fallback_to_default_project_id
344328
def import_instance(self, instance: str, body: Dict, project_id: Optional[str] = None) -> None:
@@ -357,16 +341,14 @@ def import_instance(self, instance: str, body: Dict, project_id: Optional[str] =
357341
:type project_id: str
358342
:return: None
359343
"""
360-
if not project_id:
361-
raise ValueError("The project_id should be set")
362344
try:
363-
response = self.get_conn().instances().import_( # pylint: disable=no-member
345+
response = self.get_conn().instances().import_( # noqa # pylint: disable=no-member
364346
project=project_id,
365347
instance=instance,
366348
body=body
367349
).execute(num_retries=self.num_retries)
368350
operation_name = response["name"]
369-
self._wait_for_operation_to_complete(project_id=project_id,
351+
self._wait_for_operation_to_complete(project_id=project_id, # type: ignore
370352
operation_name=operation_name)
371353
except HttpError as ex:
372354
raise AirflowException(
@@ -388,7 +370,7 @@ def _wait_for_operation_to_complete(self, project_id: str, operation_name: str)
388370
raise ValueError("The project_id should be set")
389371
service = self.get_conn()
390372
while True:
391-
operation_response = service.operations().get( # pylint: disable=no-member
373+
operation_response = service.operations().get( # noqa # pylint: disable=no-member
392374
project=project_id,
393375
operation=operation_name,
394376
).execute(num_retries=self.num_retries)

0 commit comments

Comments
 (0)