Skip to content

Commit 5d4abbd

Browse files
authored
Deprecate hql parameters and synchronize DBApiHook method APIs (#25299)
* Deprecate hql parameters and synchronize DBApiHook method APIs. Various providers deriving from DbApiHook had variations in some methods that were derived from the common DbApiHook. Mostly these involved extra parameters being added, and an ``hql`` parameter being used instead of ``sql``. This prevented a truly "common" approach in DbApiHook, as some common SQL operators rely on the signatures being the same. This introduces breaking changes in a few providers — but those breaking changes are easy to fix, and most of the affected parameters had already been deprecated.
1 parent 89af516 commit 5d4abbd

File tree

17 files changed

+132
-276
lines changed

17 files changed

+132
-276
lines changed

airflow/providers/apache/hive/CHANGELOG.rst

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,17 @@
2424
Changelog
2525
---------
2626

27+
Breaking Changes
28+
~~~~~~~~~~~~~~~~
29+
30+
* The ``hql`` parameter in ``get_records`` of ``HiveServer2Hook`` has been renamed to ``sql`` to match the
31+
``get_records`` DbApiHook signature. If you used it as a positional parameter, nothing changes for you,
32+
but if you used it as a keyword one, you need to rename it.
33+
* The ``hive_conf`` parameter has been renamed to ``parameters`` and is now the second parameter, to match the ``get_records``
34+
signature from the DbApiHook. You need to rename it if you used it.
35+
* The ``schema`` parameter in ``get_records`` is now an optional keyword argument that you can add, to match
36+
the schema of ``get_records`` from DbApiHook.
37+
2738
3.1.0
2839
.....
2940

airflow/providers/apache/hive/hooks/hive.py

Lines changed: 23 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
import warnings
2525
from collections import OrderedDict
2626
from tempfile import NamedTemporaryFile, TemporaryDirectory
27-
from typing import Any, Dict, List, Optional, Union
27+
from typing import Any, Dict, Iterable, List, Mapping, Optional, Union
2828

2929
import pandas
3030
import unicodecsv as csv
@@ -857,15 +857,15 @@ def get_conn(self, schema: Optional[str] = None) -> Any:
857857

858858
def _get_results(
859859
self,
860-
hql: Union[str, List[str]],
860+
sql: Union[str, List[str]],
861861
schema: str = 'default',
862862
fetch_size: Optional[int] = None,
863-
hive_conf: Optional[Dict[Any, Any]] = None,
863+
hive_conf: Optional[Union[Iterable, Mapping]] = None,
864864
) -> Any:
865865
from pyhive.exc import ProgrammingError
866866

867-
if isinstance(hql, str):
868-
hql = [hql]
867+
if isinstance(sql, str):
868+
sql = [sql]
869869
previous_description = None
870870
with contextlib.closing(self.get_conn(schema)) as conn, contextlib.closing(conn.cursor()) as cur:
871871

@@ -882,7 +882,7 @@ def _get_results(
882882
for k, v in env_context.items():
883883
cur.execute(f"set {k}={v}")
884884

885-
for statement in hql:
885+
for statement in sql:
886886
cur.execute(statement)
887887
# we only get results of statements that returns
888888
lowered_statement = statement.lower().strip()
@@ -911,29 +911,29 @@ def _get_results(
911911

912912
def get_results(
913913
self,
914-
hql: str,
914+
sql: Union[str, List[str]],
915915
schema: str = 'default',
916916
fetch_size: Optional[int] = None,
917-
hive_conf: Optional[Dict[Any, Any]] = None,
917+
hive_conf: Optional[Union[Iterable, Mapping]] = None,
918918
) -> Dict[str, Any]:
919919
"""
920920
Get results of the provided hql in target schema.
921921
922-
:param hql: hql to be executed.
922+
:param sql: hql to be executed.
923923
:param schema: target schema, default to 'default'.
924924
:param fetch_size: max size of result to fetch.
925925
:param hive_conf: hive_conf to execute alone with the hql.
926926
:return: results of hql execution, dict with data (list of results) and header
927927
:rtype: dict
928928
"""
929-
results_iter = self._get_results(hql, schema, fetch_size=fetch_size, hive_conf=hive_conf)
929+
results_iter = self._get_results(sql, schema, fetch_size=fetch_size, hive_conf=hive_conf)
930930
header = next(results_iter)
931931
results = {'data': list(results_iter), 'header': header}
932932
return results
933933

934934
def to_csv(
935935
self,
936-
hql: str,
936+
sql: str,
937937
csv_filepath: str,
938938
schema: str = 'default',
939939
delimiter: str = ',',
@@ -945,7 +945,7 @@ def to_csv(
945945
"""
946946
Execute hql in target schema and write results to a csv file.
947947
948-
:param hql: hql to be executed.
948+
:param sql: hql to be executed.
949949
:param csv_filepath: filepath of csv to write results into.
950950
:param schema: target schema, default to 'default'.
951951
:param delimiter: delimiter of the csv file, default to ','.
@@ -955,7 +955,7 @@ def to_csv(
955955
:param hive_conf: hive_conf to execute alone with the hql.
956956
957957
"""
958-
results_iter = self._get_results(hql, schema, fetch_size=fetch_size, hive_conf=hive_conf)
958+
results_iter = self._get_results(sql, schema, fetch_size=fetch_size, hive_conf=hive_conf)
959959
header = next(results_iter)
960960
message = None
961961

@@ -982,14 +982,14 @@ def to_csv(
982982
self.log.info("Done. Loaded a total of %s rows.", i)
983983

984984
def get_records(
985-
self, hql: str, schema: str = 'default', hive_conf: Optional[Dict[Any, Any]] = None
985+
self, sql: Union[str, List[str]], parameters: Optional[Union[Iterable, Mapping]] = None, **kwargs
986986
) -> Any:
987987
"""
988-
Get a set of records from a Hive query.
988+
Get a set of records from a Hive query. You can optionally pass 'schema' kwarg
989+
which specifies target schema and default to 'default'.
989990
990-
:param hql: hql to be executed.
991-
:param schema: target schema, default to 'default'.
992-
:param hive_conf: hive_conf to execute alone with the hql.
991+
:param sql: hql to be executed.
992+
:param parameters: optional configuration passed to get_results
993993
:return: result of hive execution
994994
:rtype: list
995995
@@ -998,19 +998,20 @@ def get_records(
998998
>>> len(hh.get_records(sql))
999999
100
10001000
"""
1001-
return self.get_results(hql, schema=schema, hive_conf=hive_conf)['data']
1001+
schema = kwargs['schema'] if 'schema' in kwargs else 'default'
1002+
return self.get_results(sql, schema=schema, hive_conf=parameters)['data']
10021003

10031004
def get_pandas_df( # type: ignore
10041005
self,
1005-
hql: str,
1006+
sql: str,
10061007
schema: str = 'default',
10071008
hive_conf: Optional[Dict[Any, Any]] = None,
10081009
**kwargs,
10091010
) -> pandas.DataFrame:
10101011
"""
10111012
Get a pandas dataframe from a Hive query
10121013
1013-
:param hql: hql to be executed.
1014+
:param sql: hql to be executed.
10141015
:param schema: target schema, default to 'default'.
10151016
:param hive_conf: hive_conf to execute alone with the hql.
10161017
:param kwargs: (optional) passed into pandas.DataFrame constructor
@@ -1025,6 +1026,6 @@ def get_pandas_df( # type: ignore
10251026
10261027
:return: pandas.DateFrame
10271028
"""
1028-
res = self.get_results(hql, schema=schema, hive_conf=hive_conf)
1029+
res = self.get_results(sql, schema=schema, hive_conf=hive_conf)
10291030
df = pandas.DataFrame(res['data'], columns=[c[0] for c in res['header']], **kwargs)
10301031
return df

airflow/providers/apache/hive/operators/hive_stats.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ def execute(self, context: "Context") -> None:
138138

139139
presto = PrestoHook(presto_conn_id=self.presto_conn_id)
140140
self.log.info('Executing SQL check: %s', sql)
141-
row = presto.get_first(hql=sql)
141+
row = presto.get_first(sql)
142142
self.log.info("Record: %s", row)
143143
if not row:
144144
raise AirflowException("The query returned None")

airflow/providers/apache/hive/transfers/hive_to_mysql.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ def execute(self, context: 'Context'):
111111
mysql = self._call_preoperator()
112112
mysql.bulk_load(table=self.mysql_table, tmp_file=tmp_file.name)
113113
else:
114-
hive_results = hive.get_records(self.sql, hive_conf=hive_conf)
114+
hive_results = hive.get_records(self.sql, parameters=hive_conf)
115115
mysql = self._call_preoperator()
116116
mysql.insert_rows(table=self.mysql_table, rows=hive_results)
117117

airflow/providers/apache/hive/transfers/hive_to_samba.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def execute(self, context: 'Context'):
6868
with NamedTemporaryFile() as tmp_file:
6969
self.log.info("Fetching file from Hive")
7070
hive = HiveServer2Hook(hiveserver2_conn_id=self.hiveserver2_conn_id)
71-
hive.to_csv(hql=self.hql, csv_filepath=tmp_file.name, hive_conf=context_to_airflow_vars(context))
71+
hive.to_csv(self.hql, csv_filepath=tmp_file.name, hive_conf=context_to_airflow_vars(context))
7272
self.log.info("Pushing to samba")
7373
samba = SambaHook(samba_conn_id=self.samba_conn_id)
7474
samba.push_from_local(self.destination_filepath, tmp_file.name)

airflow/providers/apache/pinot/hooks/pinot.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,9 @@ def get_uri(self) -> str:
275275
endpoint = conn.extra_dejson.get('endpoint', 'query/sql')
276276
return f'{conn_type}://{host}/{endpoint}'
277277

278-
def get_records(self, sql: str, parameters: Optional[Union[Iterable, Mapping]] = None) -> Any:
278+
def get_records(
279+
self, sql: Union[str, List[str]], parameters: Optional[Union[Iterable, Mapping]] = None, **kwargs
280+
) -> Any:
279281
"""
280282
Executes the sql and returns a set of records.
281283
@@ -287,7 +289,9 @@ def get_records(self, sql: str, parameters: Optional[Union[Iterable, Mapping]] =
287289
cur.execute(sql)
288290
return cur.fetchall()
289291

290-
def get_first(self, sql: str, parameters: Optional[Union[Iterable, Mapping]] = None) -> Any:
292+
def get_first(
293+
self, sql: Union[str, List[str]], parameters: Optional[Union[Iterable, Mapping]] = None
294+
) -> Any:
291295
"""
292296
Executes the sql and returns the first resulting row.
293297

airflow/providers/common/sql/CHANGELOG.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,11 @@
1515
specific language governing permissions and limitations
1616
under the License.
1717
18+
.. NOTE TO CONTRIBUTORS:
19+
Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes
20+
and you want to add an explanation to the users on how they are supposed to deal with them.
21+
The changelog is updated and maintained semi-automatically by release manager.
22+
1823
1924
Changelog
2025
---------

airflow/providers/common/sql/hooks/sql.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,12 @@ def get_pandas_df_by_chunks(self, sql, parameters=None, *, chunksize, **kwargs):
181181
with closing(self.get_conn()) as conn:
182182
yield from psql.read_sql(sql, con=conn, params=parameters, chunksize=chunksize, **kwargs)
183183

184-
def get_records(self, sql, parameters=None):
184+
def get_records(
185+
self,
186+
sql: Union[str, List[str]],
187+
parameters: Optional[Union[Iterable, Mapping]] = None,
188+
**kwargs: dict,
189+
):
185190
"""
186191
Executes the sql and returns a set of records.
187192
@@ -197,7 +202,7 @@ def get_records(self, sql, parameters=None):
197202
cur.execute(sql)
198203
return cur.fetchall()
199204

200-
def get_first(self, sql, parameters=None):
205+
def get_first(self, sql: Union[str, List[str]], parameters=None):
201206
"""
202207
Executes the sql and returns the first resulting row.
203208

airflow/providers/exasol/hooks/exasol.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,12 @@ def get_pandas_df(self, sql: str, parameters: Optional[dict] = None, **kwargs) -
7777
df = conn.export_to_pandas(sql, query_params=parameters, **kwargs)
7878
return df
7979

80-
def get_records(self, sql: str, parameters: Optional[dict] = None) -> List[Union[dict, Tuple[Any, ...]]]:
80+
def get_records(
81+
self,
82+
sql: Union[str, List[str]],
83+
parameters: Optional[Union[Iterable, Mapping]] = None,
84+
**kwargs: dict,
85+
) -> List[Union[dict, Tuple[Any, ...]]]:
8186
"""
8287
Executes the sql and returns a set of records.
8388
@@ -89,7 +94,7 @@ def get_records(self, sql: str, parameters: Optional[dict] = None) -> List[Union
8994
with closing(conn.execute(sql, parameters)) as cur:
9095
return cur.fetchall()
9196

92-
def get_first(self, sql: str, parameters: Optional[dict] = None) -> Optional[Any]:
97+
def get_first(self, sql: Union[str, List[str]], parameters: Optional[dict] = None) -> Optional[Any]:
9398
"""
9499
Executes the sql and returns the first resulting row.
95100

airflow/providers/google/cloud/hooks/cloud_sql.py

Lines changed: 32 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -426,11 +426,11 @@ def __init__(
426426
self.sql_proxy_was_downloaded = False
427427
self.sql_proxy_version = sql_proxy_version
428428
self.download_sql_proxy_dir = None
429-
self.sql_proxy_process = None # type: Optional[Popen]
429+
self.sql_proxy_process: Optional[Popen] = None
430430
self.instance_specification = instance_specification
431431
self.project_id = project_id
432432
self.gcp_conn_id = gcp_conn_id
433-
self.command_line_parameters = [] # type: List[str]
433+
self.command_line_parameters: List[str] = []
434434
self.cloud_sql_proxy_socket_directory = self.path_prefix
435435
self.sql_proxy_path = (
436436
sql_proxy_binary_path if sql_proxy_binary_path else self.path_prefix + "_cloud_sql_proxy"
@@ -705,28 +705,28 @@ def __init__(
705705
self.gcp_cloudsql_conn_id = gcp_cloudsql_conn_id
706706
self.cloudsql_connection = self.get_connection(self.gcp_cloudsql_conn_id)
707707
self.extras = self.cloudsql_connection.extra_dejson
708-
self.project_id = self.extras.get('project_id', default_gcp_project_id) # type: Optional[str]
709-
self.instance = self.extras.get('instance') # type: Optional[str]
710-
self.database = self.cloudsql_connection.schema # type: Optional[str]
711-
self.location = self.extras.get('location') # type: Optional[str]
712-
self.database_type = self.extras.get('database_type') # type: Optional[str]
713-
self.use_proxy = self._get_bool(self.extras.get('use_proxy', 'False')) # type: bool
714-
self.use_ssl = self._get_bool(self.extras.get('use_ssl', 'False')) # type: bool
715-
self.sql_proxy_use_tcp = self._get_bool(self.extras.get('sql_proxy_use_tcp', 'False')) # type: bool
716-
self.sql_proxy_version = self.extras.get('sql_proxy_version') # type: Optional[str]
717-
self.sql_proxy_binary_path = self.extras.get('sql_proxy_binary_path') # type: Optional[str]
718-
self.user = self.cloudsql_connection.login # type: Optional[str]
719-
self.password = self.cloudsql_connection.password # type: Optional[str]
720-
self.public_ip = self.cloudsql_connection.host # type: Optional[str]
721-
self.public_port = self.cloudsql_connection.port # type: Optional[int]
722-
self.sslcert = self.extras.get('sslcert') # type: Optional[str]
723-
self.sslkey = self.extras.get('sslkey') # type: Optional[str]
724-
self.sslrootcert = self.extras.get('sslrootcert') # type: Optional[str]
708+
self.project_id = self.extras.get('project_id', default_gcp_project_id)
709+
self.instance = self.extras.get('instance')
710+
self.database = self.cloudsql_connection.schema
711+
self.location = self.extras.get('location')
712+
self.database_type = self.extras.get('database_type')
713+
self.use_proxy = self._get_bool(self.extras.get('use_proxy', 'False'))
714+
self.use_ssl = self._get_bool(self.extras.get('use_ssl', 'False'))
715+
self.sql_proxy_use_tcp = self._get_bool(self.extras.get('sql_proxy_use_tcp', 'False'))
716+
self.sql_proxy_version = self.extras.get('sql_proxy_version')
717+
self.sql_proxy_binary_path = self.extras.get('sql_proxy_binary_path')
718+
self.user = self.cloudsql_connection.login
719+
self.password = self.cloudsql_connection.password
720+
self.public_ip = self.cloudsql_connection.host
721+
self.public_port = self.cloudsql_connection.port
722+
self.sslcert = self.extras.get('sslcert')
723+
self.sslkey = self.extras.get('sslkey')
724+
self.sslrootcert = self.extras.get('sslrootcert')
725725
# Port and socket path and db_hook are automatically generated
726726
self.sql_proxy_tcp_port = None
727-
self.sql_proxy_unique_path = None # type: Optional[str]
728-
self.db_hook = None # type: Optional[Union[PostgresHook, MySqlHook]]
729-
self.reserved_tcp_socket = None # type: Optional[socket.socket]
727+
self.sql_proxy_unique_path: Optional[str] = None
728+
self.db_hook: Optional[Union[PostgresHook, MySqlHook]] = None
729+
self.reserved_tcp_socket: Optional[socket.socket] = None
730730
# Generated based on clock + clock sequence. Unique per host (!).
731731
# This is important as different hosts share the database
732732
self.db_conn_id = str(uuid.uuid1())
@@ -828,18 +828,18 @@ def _generate_connection_uri(self) -> str:
828828
if not self.database_type:
829829
raise ValueError("The database_type should be set")
830830

831-
database_uris = CONNECTION_URIS[self.database_type] # type: Dict[str, Dict[str, str]]
831+
database_uris = CONNECTION_URIS[self.database_type]
832832
ssl_spec = None
833833
socket_path = None
834834
if self.use_proxy:
835-
proxy_uris = database_uris['proxy'] # type: Dict[str, str]
835+
proxy_uris = database_uris['proxy']
836836
if self.sql_proxy_use_tcp:
837837
format_string = proxy_uris['tcp']
838838
else:
839839
format_string = proxy_uris['socket']
840840
socket_path = f"{self.sql_proxy_unique_path}/{self._get_instance_socket_name()}"
841841
else:
842-
public_uris = database_uris['public'] # type: Dict[str, str]
842+
public_uris = database_uris['public']
843843
if self.use_ssl:
844844
format_string = public_uris['ssl']
845845
ssl_spec = {'cert': self.sslcert, 'key': self.sslkey, 'ca': self.sslrootcert}
@@ -876,7 +876,7 @@ def _generate_connection_uri(self) -> str:
876876
return connection_uri
877877

878878
def _get_instance_socket_name(self) -> str:
879-
return self.project_id + ":" + self.location + ":" + self.instance # type: ignore
879+
return self.project_id + ":" + self.location + ":" + self.instance
880880

881881
def _get_sqlproxy_instance_specification(self) -> str:
882882
instance_specification = self._get_instance_socket_name()
@@ -921,10 +921,13 @@ def get_database_hook(self, connection: Connection) -> Union[PostgresHook, MySql
921921
that uses proxy or connects directly to the Google Cloud SQL database.
922922
"""
923923
if self.database_type == 'postgres':
924-
self.db_hook = PostgresHook(connection=connection, schema=self.database)
924+
db_hook: Union[PostgresHook, MySqlHook] = PostgresHook(
925+
connection=connection, schema=self.database
926+
)
925927
else:
926-
self.db_hook = MySqlHook(connection=connection, schema=self.database)
927-
return self.db_hook
928+
db_hook = MySqlHook(connection=connection, schema=self.database)
929+
self.db_hook = db_hook
930+
return db_hook
928931

929932
def cleanup_database_hook(self) -> None:
930933
"""Clean up database hook after it was used."""

0 commit comments

Comments
 (0)