
Commit 8e88eb8

Consolidate import and usage of pandas (#33480)
1 parent: 47187ce

File tree

8 files changed (+27, -27 lines)

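All eight diffs below make the same mechanical change: instead of a mix of bare import pandas and from pandas import DataFrame, every module now uses the alias convention the pandas documentation itself recommends. As a trivial illustration of the target style:

import pandas as pd

df = pd.DataFrame({"col1": [1, 2], "col2": [3, 4]})
print(df.dtypes)  # every pandas name is reached through the pd alias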

β€Žairflow/providers/amazon/aws/transfers/sql_to_s3.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from airflow.providers.common.sql.hooks.sql import DbApiHook
3030

3131
if TYPE_CHECKING:
32-
from pandas import DataFrame
32+
import pandas as pd
3333

3434
from airflow.utils.context import Context
3535

@@ -134,15 +134,15 @@ def __init__(
134134
raise AirflowException(f"The argument file_format doesn't support {file_format} value.")
135135

136136
@staticmethod
137-
def _fix_dtypes(df: DataFrame, file_format: FILE_FORMAT) -> None:
137+
def _fix_dtypes(df: pd.DataFrame, file_format: FILE_FORMAT) -> None:
138138
"""
139139
Mutate DataFrame to set dtypes for float columns containing NaN values.
140140
141141
Set dtype of object to str to allow for downstream transformations.
142142
"""
143143
try:
144144
import numpy as np
145-
from pandas import Float64Dtype, Int64Dtype
145+
import pandas as pd
146146
except ImportError as e:
147147
from airflow.exceptions import AirflowOptionalProviderFeatureException
148148

@@ -163,13 +163,13 @@ def _fix_dtypes(df: DataFrame, file_format: FILE_FORMAT) -> None:
163163
# The type ignore can be removed here if https://github.com/numpy/numpy/pull/23690
164164
# is merged and released as currently NumPy does not consider None as valid for x/y.
165165
df[col] = np.where(df[col].isnull(), None, df[col]) # type: ignore[call-overload]
166-
df[col] = df[col].astype(Int64Dtype())
166+
df[col] = df[col].astype(pd.Int64Dtype())
167167
elif np.isclose(notna_series, notna_series.astype(int)).all():
168168
# set to float dtype that retains floats and supports NaNs
169169
# The type ignore can be removed here if https://github.com/numpy/numpy/pull/23690
170170
# is merged and released
171171
df[col] = np.where(df[col].isnull(), None, df[col]) # type: ignore[call-overload]
172-
df[col] = df[col].astype(Float64Dtype())
172+
df[col] = df[col].astype(pd.Float64Dtype())
173173

174174
def execute(self, context: Context) -> None:
175175
sql_hook = self._get_hook()
@@ -192,7 +192,7 @@ def execute(self, context: Context) -> None:
192192
filename=tmp_file.name, key=object_key, bucket_name=self.s3_bucket, replace=self.replace
193193
)
194194

195-
def _partition_dataframe(self, df: DataFrame) -> Iterable[tuple[str, DataFrame]]:
195+
def _partition_dataframe(self, df: pd.DataFrame) -> Iterable[tuple[str, pd.DataFrame]]:
196196
"""Partition dataframe using pandas groupby() method."""
197197
if not self.groupby_kwargs:
198198
yield "", df

β€Žairflow/providers/apache/hive/hooks/hive.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
from airflow.exceptions import AirflowProviderDeprecationWarning
3232

3333
try:
34-
import pandas
34+
import pandas as pd
3535
except ImportError as e:
3636
from airflow.exceptions import AirflowOptionalProviderFeatureException
3737

@@ -336,7 +336,7 @@ def test_hql(self, hql: str) -> None:
336336

337337
def load_df(
338338
self,
339-
df: pandas.DataFrame,
339+
df: pd.DataFrame,
340340
table: str,
341341
field_dict: dict[Any, Any] | None = None,
342342
delimiter: str = ",",
@@ -361,7 +361,7 @@ def load_df(
361361
:param kwargs: passed to self.load_file
362362
"""
363363

364-
def _infer_field_types_from_df(df: pandas.DataFrame) -> dict[Any, Any]:
364+
def _infer_field_types_from_df(df: pd.DataFrame) -> dict[Any, Any]:
365365
dtype_kind_hive_type = {
366366
"b": "BOOLEAN", # boolean
367367
"i": "BIGINT", # signed integer
@@ -1037,7 +1037,7 @@ def get_pandas_df( # type: ignore
10371037
schema: str = "default",
10381038
hive_conf: dict[Any, Any] | None = None,
10391039
**kwargs,
1040-
) -> pandas.DataFrame:
1040+
) -> pd.DataFrame:
10411041
"""
10421042
Get a pandas dataframe from a Hive query.
10431043
@@ -1056,5 +1056,5 @@ def get_pandas_df( # type: ignore
10561056
:return: pandas.DateFrame
10571057
"""
10581058
res = self.get_results(sql, schema=schema, hive_conf=hive_conf)
1059-
df = pandas.DataFrame(res["data"], columns=[c[0] for c in res["header"]], **kwargs)
1059+
df = pd.DataFrame(res["data"], columns=[c[0] for c in res["header"]], **kwargs)
10601060
return df
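The try/except around the import in the first hunk is the provider's optional-dependency guard. The raise statement falls outside the lines shown, but the pattern presumably completes as sketched here, so a missing pandas is reported as an optional provider feature problem rather than a bare ImportError:

try:
    import pandas as pd
except ImportError as e:
    from airflow.exceptions import AirflowOptionalProviderFeatureException

    raise AirflowOptionalProviderFeatureException(e)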

β€Žairflow/providers/google/cloud/hooks/bigquery.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from datetime import datetime, timedelta
3131
from typing import Any, Iterable, Mapping, NoReturn, Sequence, Union, cast
3232

33+
import pandas as pd
3334
from aiohttp import ClientSession as ClientSession
3435
from gcloud.aio.bigquery import Job, Table as Table_async
3536
from google.api_core.page_iterator import HTTPIterator
@@ -49,7 +50,6 @@
4950
from google.cloud.bigquery.table import EncryptionConfiguration, Row, RowIterator, Table, TableReference
5051
from google.cloud.exceptions import NotFound
5152
from googleapiclient.discovery import Resource, build
52-
from pandas import DataFrame
5353
from pandas_gbq import read_gbq
5454
from pandas_gbq.gbq import GbqConnector # noqa
5555
from requests import Session
@@ -244,7 +244,7 @@ def get_pandas_df(
244244
parameters: Iterable | Mapping[str, Any] | None = None,
245245
dialect: str | None = None,
246246
**kwargs,
247-
) -> DataFrame:
247+
) -> pd.DataFrame:
248248
"""Get a Pandas DataFrame for the BigQuery results.
249249
250250
The DbApiHook method must be overridden because Pandas doesn't support
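Callers are unaffected by the annotation change; a hypothetical usage sketch (hook construction and query are illustrative, assuming a default GCP connection is configured):

from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook

hook = BigQueryHook()
df = hook.get_pandas_df("SELECT 17 AS answer", dialect="standard")
print(df["answer"].iloc[0])  # df is a pd.DataFrame, per the new annotation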

β€Žairflow/providers/presto/hooks/presto.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ def get_first(
158158
raise PrestoException(e)
159159

160160
def get_pandas_df(self, sql: str = "", parameters=None, **kwargs):
161-
import pandas
161+
import pandas as pd
162162

163163
cursor = self.get_cursor()
164164
try:
@@ -168,10 +168,10 @@ def get_pandas_df(self, sql: str = "", parameters=None, **kwargs):
168168
raise PrestoException(e)
169169
column_descriptions = cursor.description
170170
if data:
171-
df = pandas.DataFrame(data, **kwargs)
171+
df = pd.DataFrame(data, **kwargs)
172172
df.columns = [c[0] for c in column_descriptions]
173173
else:
174-
df = pandas.DataFrame(**kwargs)
174+
df = pd.DataFrame(**kwargs)
175175
return df
176176

177177
def insert_rows(
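The method body above is the generic DB-API cursor-to-DataFrame pattern; a standalone sketch with a stand-in cursor object (names illustrative, not the Presto hook itself):

import pandas as pd

def df_from_cursor(cursor, **kwargs) -> pd.DataFrame:
    # cursor.description is a sequence of DB-API 7-tuples; index 0 is the column name
    data = cursor.fetchall()
    if data:
        df = pd.DataFrame(data, **kwargs)
        df.columns = [c[0] for c in cursor.description]
    else:
        df = pd.DataFrame(**kwargs)
    return df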

β€Žairflow/providers/slack/transfers/sql_to_slack.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from tempfile import NamedTemporaryFile
2020
from typing import TYPE_CHECKING, Any, Iterable, Mapping, Sequence
2121

22-
from pandas import DataFrame
22+
import pandas as pd
2323
from tabulate import tabulate
2424

2525
from airflow.exceptions import AirflowException
@@ -70,7 +70,7 @@ def _get_hook(self) -> DbApiHook:
7070
)
7171
return hook
7272

73-
def _get_query_results(self) -> DataFrame:
73+
def _get_query_results(self) -> pd.DataFrame:
7474
sql_hook = self._get_hook()
7575

7676
self.log.info("Running SQL query: %s", self.sql)

β€Žairflow/providers/trino/hooks/trino.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ def get_first(
178178
def get_pandas_df(
179179
self, sql: str = "", parameters: Iterable | Mapping[str, Any] | None = None, **kwargs
180180
): # type: ignore[override]
181-
import pandas
181+
import pandas as pd
182182

183183
cursor = self.get_cursor()
184184
try:
@@ -188,10 +188,10 @@ def get_pandas_df(
188188
raise TrinoException(e)
189189
column_descriptions = cursor.description
190190
if data:
191-
df = pandas.DataFrame(data, **kwargs)
191+
df = pd.DataFrame(data, **kwargs)
192192
df.columns = [c[0] for c in column_descriptions]
193193
else:
194-
df = pandas.DataFrame(**kwargs)
194+
df = pd.DataFrame(**kwargs)
195195
return df
196196

197197
def insert_rows(
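As in the Presto hook, pandas is imported inside get_pandas_df rather than at module top, so the module stays importable when pandas is absent and the ImportError surfaces only on first use. A small sketch of the deferred-import idiom (function and names illustrative):

def rows_to_df(rows, columns):
    import pandas as pd  # deferred: fails only if this function actually runs

    df = pd.DataFrame(rows)
    if rows:
        df.columns = columns
    return df

print(rows_to_df([(1, "a"), (2, "b")], ["id", "name"]))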

β€Žairflow/serialization/serializers/pandas.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,19 +28,19 @@
2828
deserializers = serializers
2929

3030
if TYPE_CHECKING:
31-
from pandas import DataFrame
31+
import pandas as pd
3232

3333
from airflow.serialization.serde import U
3434

3535
__version__ = 1
3636

3737

3838
def serialize(o: object) -> tuple[U, str, int, bool]:
39+
import pandas as pd
3940
import pyarrow as pa
40-
from pandas import DataFrame
4141
from pyarrow import parquet as pq
4242

43-
if not isinstance(o, DataFrame):
43+
if not isinstance(o, pd.DataFrame):
4444
return "", "", 0, False
4545

4646
# for now, we *always* serialize into in memory
@@ -53,7 +53,7 @@ def serialize(o: object) -> tuple[U, str, int, bool]:
5353
return buf.getvalue().hex().decode("utf-8"), qualname(o), __version__, True
5454

5555

56-
def deserialize(classname: str, version: int, data: object) -> DataFrame:
56+
def deserialize(classname: str, version: int, data: object) -> pd.DataFrame:
5757
if version > __version__:
5858
raise TypeError(f"serialized {version} of {classname} > {__version__}")
5959
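For orientation, the serializer's round trip is DataFrame -> Arrow table -> Parquet bytes -> hex string, reversed on deserialize. A self-contained sketch of that round trip (buffer handling condensed; the real module also records the qualified class name and a format version):

import pandas as pd
import pyarrow as pa
from pyarrow import parquet as pq

df = pd.DataFrame({"col1": [1, 2], "col2": [3, 4]})

# serialize: write the frame as Parquet into an in-memory buffer, then hex-encode
buf = pa.BufferOutputStream()
pq.write_table(pa.Table.from_pandas(df), buf)
payload = buf.getvalue().hex().decode("utf-8")

# deserialize: decode the hex and read the Parquet bytes back into a DataFrame
restored = pq.read_table(pa.BufferReader(bytes.fromhex(payload))).to_pandas()
assert df.equals(restored)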

β€Žtests/serialization/serializers/test_serializers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import decimal
2121

2222
import numpy
23-
import pandas
23+
import pandas as pd
2424
import pendulum.tz
2525
import pytest
2626
from pendulum import DateTime
@@ -94,7 +94,7 @@ def test_params(self):
9494
assert i["x"] == d["x"]
9595

9696
def test_pandas(self):
97-
i = pandas.DataFrame(data={"col1": [1, 2], "col2": [3, 4]})
97+
i = pd.DataFrame(data={"col1": [1, 2], "col2": [3, 4]})
9898
e = serialize(i)
9999
d = deserialize(e)
100100
assert i.equals(d)
