|
57 | 57 |
|
58 | 58 | from airflow.exceptions import AirflowException
|
59 | 59 | from airflow.providers.common.sql.hooks.sql import DbApiHook
|
| 60 | +from airflow.providers.google.cloud.utils.bigquery import bq_cast |
60 | 61 | from airflow.providers.google.common.consts import CLIENT_INFO
|
61 | 62 | from airflow.providers.google.common.hooks.base_google import GoogleBaseAsyncHook, GoogleBaseHook, get_field
|
62 | 63 | from airflow.utils.helpers import convert_camel_to_snake
|
@@ -2740,7 +2741,7 @@ def next(self) -> list | None:
|
2740 | 2741 | rows = query_results["rows"]
|
2741 | 2742 |
|
2742 | 2743 | for dict_row in rows:
|
2743 |
| - typed_row = [_bq_cast(vs["v"], col_types[idx]) for idx, vs in enumerate(dict_row["f"])] |
| 2744 | + typed_row = [bq_cast(vs["v"], col_types[idx]) for idx, vs in enumerate(dict_row["f"])] |
2744 | 2745 | self.buffer.append(typed_row)
|
2745 | 2746 |
|
2746 | 2747 | if not self.page_token:
|
@@ -2845,25 +2846,6 @@ def _escape(s: str) -> str:
|
2845 | 2846 | return e
|
2846 | 2847 |
|
2847 | 2848 |
|
2848 |
| -def _bq_cast(string_field: str, bq_type: str) -> None | int | float | bool | str: |
2849 |
| - """ |
2850 |
| - Helper method that casts a BigQuery row to the appropriate data types. |
2851 |
| - This is useful because BigQuery returns all fields as strings. |
2852 |
| - """ |
2853 |
| - if string_field is None: |
2854 |
| - return None |
2855 |
| - elif bq_type == "INTEGER": |
2856 |
| - return int(string_field) |
2857 |
| - elif bq_type in ("FLOAT", "TIMESTAMP"): |
2858 |
| - return float(string_field) |
2859 |
| - elif bq_type == "BOOLEAN": |
2860 |
| - if string_field not in ["true", "false"]: |
2861 |
| - raise ValueError(f"{string_field} must have value 'true' or 'false'") |
2862 |
| - return string_field == "true" |
2863 |
| - else: |
2864 |
| - return string_field |
2865 |
| - |
2866 |
| - |
2867 | 2849 | def split_tablename(
|
2868 | 2850 | table_input: str, default_project_id: str, var_name: str | None = None
|
2869 | 2851 | ) -> tuple[str, str, str]:
|
@@ -3070,7 +3052,7 @@ def get_records(self, query_results: dict[str, Any]) -> list[Any]:
|
3070 | 3052 | fields = query_results["schema"]["fields"]
|
3071 | 3053 | col_types = [field["type"] for field in fields]
|
3072 | 3054 | for dict_row in rows:
|
3073 |
| - typed_row = [_bq_cast(vs["v"], col_types[idx]) for idx, vs in enumerate(dict_row["f"])] |
| 3055 | + typed_row = [bq_cast(vs["v"], col_types[idx]) for idx, vs in enumerate(dict_row["f"])] |
3074 | 3056 | buffer.append(typed_row)
|
3075 | 3057 | return buffer
|
3076 | 3058 |
|
|
0 commit comments