Skip to content

Commit 7d2c2ee

Browse files
authored
add description method in BigQueryCursor class (#25366)
]
1 parent e84d753 commit 7d2c2ee

File tree

2 files changed

+74
-17
lines changed

2 files changed

+74
-17
lines changed

β€Žairflow/providers/google/cloud/hooks/bigquery.py

Lines changed: 48 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2663,11 +2663,16 @@ def __init__(
26632663
self.job_id = None # type: Optional[str]
26642664
self.buffer = [] # type: list
26652665
self.all_pages_loaded = False # type: bool
2666+
self._description = [] # type: List
26662667

26672668
@property
2668-
def description(self) -> None:
2669-
"""The schema description method is not currently implemented"""
2670-
raise NotImplementedError
2669+
def description(self) -> List:
2670+
"""Return the cursor description"""
2671+
return self._description
2672+
2673+
@description.setter
2674+
def description(self, value):
2675+
self._description = value
26712676

26722677
def close(self) -> None:
26732678
"""By default, do nothing"""
@@ -2688,6 +2693,10 @@ def execute(self, operation: str, parameters: Optional[dict] = None) -> None:
26882693
self.flush_results()
26892694
self.job_id = self.hook.run_query(sql)
26902695

2696+
query_results = self._get_query_result()
2697+
description = _format_schema_for_description(query_results["schema"])
2698+
self.description = description
2699+
26912700
def executemany(self, operation: str, seq_of_parameters: list) -> None:
26922701
"""
26932702
Execute a BigQuery query multiple times with different parameters.
@@ -2723,17 +2732,7 @@ def next(self) -> Union[List, None]:
27232732
if self.all_pages_loaded:
27242733
return None
27252734

2726-
query_results = (
2727-
self.service.jobs()
2728-
.getQueryResults(
2729-
projectId=self.project_id,
2730-
jobId=self.job_id,
2731-
location=self.location,
2732-
pageToken=self.page_token,
2733-
)
2734-
.execute(num_retries=self.num_retries)
2735-
)
2736-
2735+
query_results = self._get_query_result()
27372736
if 'rows' in query_results and query_results['rows']:
27382737
self.page_token = query_results.get('pageToken')
27392738
fields = query_results['schema']['fields']
@@ -2805,6 +2804,21 @@ def setinputsizes(self, sizes: Any) -> None:
28052804
def setoutputsize(self, size: Any, column: Any = None) -> None:
28062805
"""Does nothing by default"""
28072806

2807+
def _get_query_result(self) -> Dict:
2808+
"""Get job query results like data, schema, job type..."""
2809+
query_results = (
2810+
self.service.jobs()
2811+
.getQueryResults(
2812+
projectId=self.project_id,
2813+
jobId=self.job_id,
2814+
location=self.location,
2815+
pageToken=self.page_token,
2816+
)
2817+
.execute(num_retries=self.num_retries)
2818+
)
2819+
2820+
return query_results
2821+
28082822

28092823
def _bind_parameters(operation: str, parameters: dict) -> str:
28102824
"""Helper method that binds parameters to a SQL query"""
@@ -2973,3 +2987,23 @@ def _validate_src_fmt_configs(
29732987
raise ValueError(f"{k} is not a valid src_fmt_configs for type {source_format}.")
29742988

29752989
return src_fmt_configs
2990+
2991+
2992+
def _format_schema_for_description(schema: Dict) -> List:
2993+
"""
2994+
Reformat the schema to match cursor description standard which is a tuple
2995+
of 7 elemenbts (name, type, display_size, internal_size, precision, scale, null_ok)
2996+
"""
2997+
description = []
2998+
for field in schema["fields"]:
2999+
field_description = (
3000+
field["name"],
3001+
field["type"],
3002+
None,
3003+
None,
3004+
None,
3005+
None,
3006+
field["mode"] == "NULLABLE",
3007+
)
3008+
description.append(field_description)
3009+
return description

β€Žtests/providers/google/cloud/hooks/test_bigquery.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
BigQueryHook,
3535
_api_resource_configs_duplication_check,
3636
_cleanse_time_partitioning,
37+
_format_schema_for_description,
3738
_validate_src_fmt_configs,
3839
_validate_value,
3940
split_tablename,
@@ -1239,11 +1240,33 @@ def test_execute_many(self, mock_insert, _):
12391240
]
12401241
)
12411242

1243+
def test_format_schema_for_description(self):
1244+
test_query_result = {
1245+
"schema": {
1246+
"fields": [
1247+
{"name": "field_1", "type": "STRING", "mode": "NULLABLE"},
1248+
]
1249+
},
1250+
}
1251+
description = _format_schema_for_description(test_query_result["schema"])
1252+
assert description == [('field_1', 'STRING', None, None, None, None, True)]
1253+
12421254
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_service")
1243-
def test_description(self, mock_get_service):
1255+
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.insert_job")
1256+
def test_description(self, mock_insert, mock_get_service):
1257+
mock_get_query_results = mock_get_service.return_value.jobs.return_value.getQueryResults
1258+
mock_execute = mock_get_query_results.return_value.execute
1259+
mock_execute.return_value = {
1260+
"schema": {
1261+
"fields": [
1262+
{"name": "ts", "type": "TIMESTAMP", "mode": "NULLABLE"},
1263+
]
1264+
},
1265+
}
1266+
12441267
bq_cursor = self.hook.get_cursor()
1245-
with pytest.raises(NotImplementedError):
1246-
bq_cursor.description
1268+
bq_cursor.execute("SELECT CURRENT_TIMESTAMP() as ts")
1269+
assert bq_cursor.description == [("ts", "TIMESTAMP", None, None, None, None, True)]
12471270

12481271
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_service")
12491272
def test_close(self, mock_get_service):

0 commit comments

Comments
 (0)