Skip to content

Commit f938cd4

Browse files
authored
Add deferrable big query operators and sensors (#26156)
This PR donates the following BigQuery deferrable operators and sensors, developed in the [astronomer-providers](https://github.com/astronomer/astronomer-providers) repo, to Apache Airflow:
- `BigQueryInsertJobAsyncOperator`
- `BigQueryCheckAsyncOperator`
- `BigQueryGetDataAsyncOperator`
- `BigQueryIntervalCheckAsyncOperator`
- `BigQueryValueCheckAsyncOperator`
- `BigQueryTableExistenceAsyncSensor`
1 parent 9cf6f6a commit f938cd4

File tree

16 files changed

+3582
-26
lines changed

16 files changed

+3582
-26
lines changed

airflow/providers/google/cloud/hooks/bigquery.py

Lines changed: 255 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,10 @@
2929
import warnings
3030
from copy import deepcopy
3131
from datetime import datetime, timedelta
32-
from typing import Any, Dict, Iterable, List, Mapping, NoReturn, Optional, Sequence, Tuple, Type, Union
32+
from typing import Any, Dict, Iterable, List, Mapping, NoReturn, Optional, Sequence, Tuple, Type, Union, cast
3333

34+
from aiohttp import ClientSession as ClientSession
35+
from gcloud.aio.bigquery import Job, Table as Table_async
3436
from google.api_core.retry import Retry
3537
from google.cloud.bigquery import (
3638
DEFAULT_RETRY,
@@ -49,12 +51,13 @@
4951
from pandas import DataFrame
5052
from pandas_gbq import read_gbq
5153
from pandas_gbq.gbq import GbqConnector # noqa
54+
from requests import Session
5255
from sqlalchemy import create_engine
5356

5457
from airflow.exceptions import AirflowException
5558
from airflow.providers.common.sql.hooks.sql import DbApiHook
5659
from airflow.providers.google.common.consts import CLIENT_INFO
57-
from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
60+
from airflow.providers.google.common.hooks.base_google import GoogleBaseAsyncHook, GoogleBaseHook
5861
from airflow.utils.helpers import convert_camel_to_snake
5962
from airflow.utils.log.logging_mixin import LoggingMixin
6063

@@ -2305,7 +2308,6 @@ def __init__(
23052308
num_retries: int = 5,
23062309
labels: Optional[Dict] = None,
23072310
) -> None:
2308-
23092311
super().__init__()
23102312
self.service = service
23112313
self.project_id = project_id
@@ -2870,7 +2872,6 @@ def _bq_cast(string_field: str, bq_type: str) -> Union[None, int, float, bool, s
28702872
def split_tablename(
28712873
table_input: str, default_project_id: str, var_name: Optional[str] = None
28722874
) -> Tuple[str, str, str]:
2873-
28742875
if '.' not in table_input:
28752876
raise ValueError(f'Expected table name in the format of <dataset>.<table>. Got: {table_input}')
28762877

@@ -3010,3 +3011,253 @@ def _format_schema_for_description(schema: Dict) -> List:
30103011
)
30113012
description.append(field_description)
30123013
return description
3014+
3015+
3016+
class BigQueryAsyncHook(GoogleBaseAsyncHook):
    """Uses gcloud-aio library to retrieve Job details"""

    sync_hook_class = BigQueryHook

    async def get_job_instance(
        self, project_id: Optional[str], job_id: Optional[str], session: ClientSession
    ) -> Job:
        """
        Get the specified job resource by job ID and project ID.

        :param project_id: project ID handed to the gcloud-aio ``Job`` as ``project``
        :param job_id: ID of the job to wrap
        :param session: aiohttp client session handed to the gcloud-aio ``Job``
        :return: a gcloud-aio ``Job`` built from the arguments above
        """
        # NOTE(review): the Job is returned from inside the ``with`` block, so the
        # service-file context is exited before the caller uses the Job - confirm
        # that whatever is passed as ``service_file`` is still usable afterwards.
        with await self.service_file_as_context() as f:
            return Job(job_id=job_id, project=project_id, service_file=f, session=cast(Session, session))

    async def get_job_status(
        self,
        job_id: Optional[str],
        project_id: Optional[str] = None,
    ) -> Optional[str]:
        """
        Polls for job status asynchronously using gcloud-aio.

        Note that an OSError is raised when Job results are still pending.
        Exception means that Job finished with errors

        :param job_id: ID of the job whose status is requested
        :param project_id: project ID used to look the job up
        :return: ``"success"`` when the job result is non-empty, ``"pending"`` when
            fetching the result raises ``OSError``, the stringified exception for any
            other error, or ``None`` when the call succeeds but yields an empty result
        """
        # Bound up front: an empty (falsy) result takes neither the success branch nor
        # an except branch, which previously left the name unbound at the return.
        job_status: Optional[str] = None
        async with ClientSession() as s:
            try:
                self.log.info("Executing get_job_status...")
                job_client = await self.get_job_instance(project_id, job_id, s)
                job_status_response = await job_client.result(cast(Session, s))
                if job_status_response:
                    job_status = "success"
            except OSError:
                job_status = "pending"
            except Exception as e:
                # Deliberately broad: any other failure is reported to the caller as
                # the status string instead of being propagated.
                self.log.info("Query execution finished with errors...")
                job_status = str(e)
            return job_status

    async def get_job_output(
        self,
        job_id: Optional[str],
        project_id: Optional[str] = None,
    ) -> Dict[str, Any]:
        """
        Get the big query job output for the given job id asynchronously using gcloud-aio.

        :param job_id: ID of the job whose query results are requested
        :param project_id: project ID used to look the job up
        :return: the response of the job client's ``get_query_results`` call
        """
        async with ClientSession() as session:
            self.log.info("Executing get_job_output..")
            job_client = await self.get_job_instance(project_id, job_id, session)
            job_query_response = await job_client.get_query_results(cast(Session, session))
            return job_query_response

    def get_records(self, query_results: Dict[str, Any]) -> List[Any]:
        """
        Given the output query response from gcloud-aio bigquery, convert the response to records.

        Each entry of ``query_results["rows"]`` is read as ``{"f": [{"v": value}, ...]}``
        and flattened to the list of its ``"v"`` values.

        :param query_results: the results from a SQL query
        :return: one list of cell values per row; empty when ``"rows"`` is missing or empty
        """
        rows = query_results.get("rows")
        if not rows:
            return []
        return [[cell["v"] for cell in row["f"]] for row in rows]

    def value_check(
        self,
        sql: str,
        pass_value: Any,
        records: List[Any],
        tolerance: Optional[float] = None,
    ) -> None:
        """
        Match a single query resulting row and tolerance with pass_value

        When ``pass_value`` converts to ``float`` every record is compared numerically
        (within ``tolerance`` if one is given); otherwise ``str(record)`` is compared
        to ``pass_value`` for equality.

        :param sql: the query text, used only in the failure message
        :param pass_value: expected value
        :param records: values of the row to check
        :param tolerance: allowed relative deviation for numeric checks (``0.1`` is 10%)
        :return: If Match fail, we throw an AirflowException.
        """
        if not records:
            raise AirflowException("The query returned None")
        pass_value_conv = self._convert_to_float_if_possible(pass_value)
        is_numeric_value_check = isinstance(pass_value_conv, float)
        tolerance_pct_str = (str(tolerance * 100) + "%") if tolerance else None

        error_msg = (
            f"Test failed.\nPass value:{pass_value_conv}\n"
            f"Tolerance:{tolerance_pct_str}\n"
            f"Query:\n{sql}\nResults:\n{records!s}"
        )

        if is_numeric_value_check:
            try:
                numeric_records = [float(record) for record in records]
            except (ValueError, TypeError) as err:
                # Chain the cause so the offending value stays visible in the traceback.
                raise AirflowException(f"Converting a result to float failed.\n{error_msg}") from err
            tests = self._get_numeric_matches(numeric_records, pass_value_conv, tolerance)
        else:
            tests = [str(record) == pass_value_conv for record in records]

        if not all(tests):
            raise AirflowException(error_msg)

    @staticmethod
    def _get_numeric_matches(
        records: List[float], pass_value: Any, tolerance: Optional[float] = None
    ) -> List[bool]:
        """
        A helper function to match numeric pass_value, tolerance with records value

        :param records: List of value to match against
        :param pass_value: Expected value
        :param tolerance: Allowed tolerance for match to succeed
        :return: one boolean per record, ``True`` where the record matches
        """
        if tolerance:
            bound_a = pass_value * (1 - tolerance)
            bound_b = pass_value * (1 + tolerance)
            # For a negative pass_value, ``pass_value * (1 - tolerance)`` is the larger
            # of the two products; order the bounds so the interval is never empty.
            lower, upper = min(bound_a, bound_b), max(bound_a, bound_b)
            return [lower <= record <= upper for record in records]

        return [record == pass_value for record in records]

    @staticmethod
    def _convert_to_float_if_possible(s: Any) -> Any:
        """
        A small helper function to convert a string to a numeric value if appropriate

        :param s: the string to be converted
        :return: ``float(s)`` when that conversion succeeds, otherwise ``s`` unchanged
        """
        try:
            return float(s)
        except (ValueError, TypeError):
            return s

    def interval_check(
        self,
        row1: Optional[str],
        row2: Optional[str],
        metrics_thresholds: Dict[str, Any],
        ignore_zero: bool,
        ratio_formula: str,
    ) -> None:
        """
        Checks that the values of metrics given as SQL expressions are within a certain tolerance

        Values in ``row1``/``row2`` are paired positionally with the metric names of
        ``metrics_thresholds`` taken in sorted order.

        :param row1: first resulting row of a query execution job for first SQL query
        :param row2: first resulting row of a query execution job for second SQL query
        :param metrics_thresholds: a dictionary of ratios indexed by metrics, for
            example 'COUNT(*)': 1.5 would require a 50 percent or less difference
            between the current day, and the prior days_back.
        :param ignore_zero: whether we should ignore zero metrics
        :param ratio_formula: which formula to use to compute the ratio between
            the two metrics. Assuming cur is the metric of today and ref is
            the metric to today - days_back.
            max_over_min: computes max(cur, ref) / min(cur, ref)
            relative_diff: computes abs(cur-ref) / ref
        :raises AirflowException: when either row is empty or at least one metric
            fails its check
        :raises KeyError: when ``ratio_formula`` is not one of the names above and a
            metric with non-zero values has to be evaluated
        """
        if not row2:
            raise AirflowException("The second SQL query returned None")
        if not row1:
            raise AirflowException("The first SQL query returned None")

        ratio_formulas = {
            "max_over_min": lambda cur, ref: float(max(cur, ref)) / min(cur, ref),
            "relative_diff": lambda cur, ref: float(abs(cur - ref)) / ref,
        }

        metrics_sorted = sorted(metrics_thresholds.keys())

        current = dict(zip(metrics_sorted, row1))
        reference = dict(zip(metrics_sorted, row2))
        ratios: Dict[str, Any] = {}
        test_results: Dict[str, Any] = {}

        for metric in metrics_sorted:
            cur = float(current[metric])
            ref = float(reference[metric])
            threshold = float(metrics_thresholds[metric])
            if cur == 0 or ref == 0:
                # No ratio can be computed; the outcome is decided by ``ignore_zero``.
                ratios[metric] = None
                test_results[metric] = ignore_zero
            else:
                ratio = ratio_formulas[ratio_formula](cur, ref)
                ratios[metric] = ratio
                test_results[metric] = ratio < threshold

            self.log.info(
                (
                    "Current metric for %s: %s\n"
                    "Past metric for %s: %s\n"
                    "Ratio for %s: %s\n"
                    "Threshold: %s\n"
                ),
                metric,
                cur,
                metric,
                ref,
                metric,
                ratios[metric],
                threshold,
            )

        if not all(test_results.values()):
            failed_tests = [metric for metric, value in test_results.items() if not value]
            self.log.warning(
                "The following %s tests out of %s failed:",
                len(failed_tests),
                len(metrics_sorted),
            )
            for k in failed_tests:
                self.log.warning(
                    "'%s' check failed. %s is above %s",
                    k,
                    ratios[k],
                    metrics_thresholds[k],
                )
            raise AirflowException(f"The following tests have failed:\n {', '.join(sorted(failed_tests))}")

        self.log.info("All tests have passed")
3236+
3237+
3238+
class BigQueryTableAsyncHook(GoogleBaseAsyncHook):
    """Async hook that builds gcloud-aio BigQuery ``Table`` clients."""

    sync_hook_class = BigQueryHook

    async def get_table_client(
        self, dataset: str, table_id: str, project_id: str, session: ClientSession
    ) -> Table_async:
        """
        Build a gcloud-aio BigQuery table client for the given table.

        :param dataset: dataset name, forwarded to the client as ``dataset_name``.
        :param table_id: table name, forwarded to the client as ``table_name``.
        :param project_id: Google Cloud project, forwarded to the client as ``project``.
        :param session: aiohttp ClientSession the client should use.
        :return: the constructed table client.
        """
        with await self.service_file_as_context() as service_file_path:
            table_client = Table_async(
                dataset_name=dataset,
                table_name=table_id,
                project=project_id,
                service_file=service_file_path,
                session=cast(Session, session),
            )
            return table_client

0 commit comments

Comments
 (0)