@@ -29,8 +29,10 @@
 import warnings
 from copy import deepcopy
 from datetime import datetime, timedelta
-from typing import Any, Dict, Iterable, List, Mapping, NoReturn, Optional, Sequence, Tuple, Type, Union
+from typing import Any, Dict, Iterable, List, Mapping, NoReturn, Optional, Sequence, Tuple, Type, Union, cast
 
+from aiohttp import ClientSession as ClientSession
+from gcloud.aio.bigquery import Job, Table as Table_async
 from google.api_core.retry import Retry
 from google.cloud.bigquery import (
     DEFAULT_RETRY,
@@ -49,12 +51,13 @@
 from pandas import DataFrame
 from pandas_gbq import read_gbq
 from pandas_gbq.gbq import GbqConnector  # noqa
+from requests import Session
 from sqlalchemy import create_engine
 
 from airflow.exceptions import AirflowException
 from airflow.providers.common.sql.hooks.sql import DbApiHook
 from airflow.providers.google.common.consts import CLIENT_INFO
-from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
+from airflow.providers.google.common.hooks.base_google import GoogleBaseAsyncHook, GoogleBaseHook
 from airflow.utils.helpers import convert_camel_to_snake
 from airflow.utils.log.logging_mixin import LoggingMixin
 
@@ -2305,7 +2308,6 @@ def __init__(
         num_retries: int = 5,
         labels: Optional[Dict] = None,
     ) -> None:
-
         super().__init__()
         self.service = service
         self.project_id = project_id
@@ -2870,7 +2872,6 @@ def _bq_cast(string_field: str, bq_type: str) -> Union[None, int, float, bool, str]:
 def split_tablename(
     table_input: str, default_project_id: str, var_name: Optional[str] = None
 ) -> Tuple[str, str, str]:
-
     if '.' not in table_input:
         raise ValueError(f'Expected table name in the format of <dataset>.<table>. Got: {table_input}')
 
@@ -3010,3 +3011,253 @@ def _format_schema_for_description(schema: Dict) -> List:
         )
         description.append(field_description)
     return description
+
+
+class BigQueryAsyncHook(GoogleBaseAsyncHook):
+    """Uses the gcloud-aio library to retrieve job details asynchronously."""
+
+    sync_hook_class = BigQueryHook
+
+    async def get_job_instance(
+        self, project_id: Optional[str], job_id: Optional[str], session: ClientSession
+    ) -> Job:
+        """Get the specified job resource by job ID and project ID."""
+        with await self.service_file_as_context() as f:
+            return Job(job_id=job_id, project=project_id, service_file=f, session=cast(Session, session))
+
+    async def get_job_status(
+        self,
+        job_id: Optional[str],
+        project_id: Optional[str] = None,
+    ) -> Optional[str]:
+        """
+        Poll for the job status asynchronously using gcloud-aio.
+
+        Note that gcloud-aio raises an OSError while the job results are still
+        pending; any other exception means the job finished with errors.
+        """
+        async with ClientSession() as s:
+            try:
+                self.log.info("Executing get_job_status...")
+                job_client = await self.get_job_instance(project_id, job_id, s)
+                job_status_response = await job_client.result(cast(Session, s))
+                # result() returns the job resource once the job is done, so a
+                # truthy response means the job succeeded.
+                job_status = "success" if job_status_response else "pending"
+            except OSError:
+                job_status = "pending"
+            except Exception as e:
+                self.log.info("Query execution finished with errors...")
+                job_status = str(e)
+            return job_status
+
+    async def get_job_output(
+        self,
+        job_id: Optional[str],
+        project_id: Optional[str] = None,
+    ) -> Dict[str, Any]:
+        """Get the BigQuery job output for the given job ID asynchronously using gcloud-aio."""
+        async with ClientSession() as session:
+            self.log.info("Executing get_job_output...")
+            job_client = await self.get_job_instance(project_id, job_id, session)
+            job_query_response = await job_client.get_query_results(cast(Session, session))
+            return job_query_response
+
+    def get_records(self, query_results: Dict[str, Any]) -> List[Any]:
+        """
+        Convert the output of a gcloud-aio BigQuery query response into records.
+
+        :param query_results: the results from a SQL query
+        """
+        buffer = []
+        if "rows" in query_results and query_results["rows"]:
+            rows = query_results["rows"]
+            for dict_row in rows:
+                # Each row is a dict with its fields under "f"; each field holds
+                # its value under "v".
+                typed_row = [vs["v"] for vs in dict_row["f"]]
+                buffer.append(typed_row)
+        return buffer
+
+    def value_check(
+        self,
+        sql: str,
+        pass_value: Any,
+        records: List[Any],
+        tolerance: Optional[float] = None,
+    ) -> None:
+        """
+        Match a single query result row against pass_value, within the given tolerance.
+
+        :raises AirflowException: if the match fails.
+        """
+        if not records:
+            raise AirflowException("The query returned None")
+        pass_value_conv = self._convert_to_float_if_possible(pass_value)
+        is_numeric_value_check = isinstance(pass_value_conv, float)
+        tolerance_pct_str = str(tolerance * 100) + "%" if tolerance else None
+
+        error_msg = (
+            "Test failed.\nPass value:{pass_value_conv}\n"
+            "Tolerance:{tolerance_pct_str}\n"
+            "Query:\n{sql}\nResults:\n{records!s}"
+        ).format(
+            pass_value_conv=pass_value_conv,
+            tolerance_pct_str=tolerance_pct_str,
+            sql=sql,
+            records=records,
+        )
+
+        if not is_numeric_value_check:
+            tests = [str(record) == pass_value_conv for record in records]
+        else:
+            try:
+                numeric_records = [float(record) for record in records]
+            except (ValueError, TypeError):
+                raise AirflowException(f"Converting a result to float failed.\n{error_msg}")
+            tests = self._get_numeric_matches(numeric_records, pass_value_conv, tolerance)
+
+        if not all(tests):
+            raise AirflowException(error_msg)
+
+    @staticmethod
+    def _get_numeric_matches(
+        records: List[float], pass_value: Any, tolerance: Optional[float] = None
+    ) -> List[bool]:
+        """
+        Match each numeric record against pass_value, within the given tolerance.
+
+        :param records: list of values to match against
+        :param pass_value: expected value
+        :param tolerance: allowed tolerance for the match to succeed
+        """
+        if tolerance:
+            return [
+                pass_value * (1 - tolerance) <= record <= pass_value * (1 + tolerance) for record in records
+            ]
+
+        return [record == pass_value for record in records]
+
+    @staticmethod
+    def _convert_to_float_if_possible(s: Any) -> Any:
+        """
+        Convert a value to a float if possible; otherwise return it unchanged.
+
+        :param s: the value to be converted
+        """
+        try:
+            return float(s)
+        except (ValueError, TypeError):
+            return s
+
+    def interval_check(
+        self,
+        row1: Optional[str],
+        row2: Optional[str],
+        metrics_thresholds: Dict[str, Any],
+        ignore_zero: bool,
+        ratio_formula: str,
+    ) -> None:
+        """
+        Check that the values of metrics given as SQL expressions are within a certain tolerance.
+
+        :param row1: first resulting row of a query execution job for the first SQL query
+        :param row2: first resulting row of a query execution job for the second SQL query
+        :param metrics_thresholds: a dictionary of ratios indexed by metrics, for
+            example 'COUNT(*)': 1.5 would require a 50 percent or less difference
+            between the current day and the prior days_back.
+        :param ignore_zero: whether we should ignore zero metrics
+        :param ratio_formula: which formula to use to compute the ratio between
+            the two metrics. Assuming cur is the metric of today and ref is
+            the metric from today minus days_back:
+            max_over_min: computes max(cur, ref) / min(cur, ref)
+            relative_diff: computes abs(cur - ref) / ref
+        """
+        if not row2:
+            raise AirflowException("The second SQL query returned None")
+        if not row1:
+            raise AirflowException("The first SQL query returned None")
+
+        ratio_formulas = {
+            "max_over_min": lambda cur, ref: float(max(cur, ref)) / min(cur, ref),
+            "relative_diff": lambda cur, ref: float(abs(cur - ref)) / ref,
+        }
+
+        metrics_sorted = sorted(metrics_thresholds.keys())
+
+        current = dict(zip(metrics_sorted, row1))
+        reference = dict(zip(metrics_sorted, row2))
+        ratios: Dict[str, Any] = {}
+        test_results: Dict[str, Any] = {}
+
+        for metric in metrics_sorted:
+            cur = float(current[metric])
+            ref = float(reference[metric])
+            threshold = float(metrics_thresholds[metric])
+            if cur == 0 or ref == 0:
+                # A zero metric makes the ratio undefined; pass or fail
+                # according to ignore_zero.
+                ratios[metric] = None
+                test_results[metric] = ignore_zero
+            else:
+                ratios[metric] = ratio_formulas[ratio_formula](
+                    float(current[metric]), float(reference[metric])
+                )
+                test_results[metric] = float(ratios[metric]) < threshold
+
+            self.log.info(
+                (
+                    "Current metric for %s: %s\n"
+                    "Past metric for %s: %s\n"
+                    "Ratio for %s: %s\n"
+                    "Threshold: %s\n"
+                ),
+                metric,
+                cur,
+                metric,
+                ref,
+                metric,
+                ratios[metric],
+                threshold,
+            )
+
+        if not all(test_results.values()):
+            failed_tests = [metric for metric, value in test_results.items() if not value]
+            self.log.warning(
+                "The following %s tests out of %s failed:",
+                len(failed_tests),
+                len(metrics_sorted),
+            )
+            for k in failed_tests:
+                self.log.warning(
+                    "'%s' check failed. %s is above %s",
+                    k,
+                    ratios[k],
+                    metrics_thresholds[k],
+                )
+            raise AirflowException(f"The following tests have failed:\n {', '.join(sorted(failed_tests))}")
+
+        self.log.info("All tests have passed")
+
+
+class BigQueryTableAsyncHook(GoogleBaseAsyncHook):
+    """Async hook for fetching BigQuery table clients via gcloud-aio."""
+
+    sync_hook_class = BigQueryHook
+
+    async def get_table_client(
+        self, dataset: str, table_id: str, project_id: str, session: ClientSession
+    ) -> Table_async:
+        """
+        Return a Google BigQuery Table object.
+
+        :param dataset: The name of the dataset in which to look for the table.
+        :param table_id: The name of the table to check the existence of.
+        :param project_id: The Google Cloud project in which to look for the table.
+            The connection supplied to the hook must provide
+            access to the specified project.
+        :param session: aiohttp ClientSession
+        """
+        with await self.service_file_as_context() as file:
+            return Table_async(
+                dataset_name=dataset,
+                table_name=table_id,
+                project=project_id,
+                service_file=file,
+                session=cast(Session, session),
+            )
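
For context, a minimal sketch (not part of this diff) of how these hooks might be exercised from an async context such as a deferrable operator's trigger. It assumes a job already submitted via the sync BigQueryHook (its id passed as job_id), a configured default Google Cloud connection, and that a missing table surfaces from gcloud-aio's Table.get as an aiohttp ClientResponseError; wait_for_job, table_exists, and the 4-second poll interval are illustrative names and values, not part of the provider API.

import asyncio

from aiohttp import ClientResponseError, ClientSession


async def wait_for_job(job_id: str, project_id: str) -> List[Any]:
    """Poll a submitted BigQuery job and return its rows once it succeeds."""
    hook = BigQueryAsyncHook()
    while True:
        status = await hook.get_job_status(job_id=job_id, project_id=project_id)
        if status == "success":
            break
        if status != "pending":
            # get_job_status stringifies any error raised by gcloud-aio.
            raise AirflowException(f"BigQuery job {job_id} failed: {status}")
        await asyncio.sleep(4.0)  # back off between polls
    response = await hook.get_job_output(job_id=job_id, project_id=project_id)
    return hook.get_records(response)


async def table_exists(dataset: str, table_id: str, project_id: str) -> bool:
    """Check table existence via the gcloud-aio Table client."""
    hook = BigQueryTableAsyncHook()
    async with ClientSession() as session:
        try:
            client = await hook.get_table_client(dataset, table_id, project_id, session)
            await client.get()  # fetches table metadata; raises if the table is missing
            return True
        except ClientResponseError as err:
            if err.status == 404:
                return False
            raise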