Commit 6d182be

Use a single statement with multiple contexts instead of nested statements in providers (#33768)
1 parent 4bae275 · commit 6d182be

11 files changed: 213 additions & 228 deletions
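
The change applies one pattern across all of the touched providers: two context managers that were previously nested are listed on a single with statement, which enters them left to right and exits them in reverse order, so the behaviour is unchanged while one level of indentation disappears. A minimal illustrative sketch of the before/after (not taken from the diff itself):

    from tempfile import NamedTemporaryFile, TemporaryDirectory

    # Before: nested statements, one manager per line.
    with TemporaryDirectory(prefix="example_") as tmp_dir:
        with NamedTemporaryFile(dir=tmp_dir) as f:
            f.write(b"hello")
            f.flush()

    # After: a single statement with multiple contexts -- semantically
    # equivalent, and the body loses one level of indentation.
    with TemporaryDirectory(prefix="example_") as tmp_dir, NamedTemporaryFile(dir=tmp_dir) as f:
        f.write(b"hello")
        f.flush()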

airflow/providers/apache/hive/hooks/hive.py

Lines changed: 61 additions & 63 deletions

@@ -236,58 +236,55 @@ def run_cli(
         if schema:
             hql = f"USE {schema};\n{hql}"

-        with TemporaryDirectory(prefix="airflow_hiveop_") as tmp_dir:
-            with NamedTemporaryFile(dir=tmp_dir) as f:
-                hql += "\n"
-                f.write(hql.encode("UTF-8"))
-                f.flush()
-                hive_cmd = self._prepare_cli_cmd()
-                env_context = get_context_from_env_var()
-                # Only extend the hive_conf if it is defined.
-                if hive_conf:
-                    env_context.update(hive_conf)
-                hive_conf_params = self._prepare_hiveconf(env_context)
-                if self.mapred_queue:
-                    hive_conf_params.extend(
-                        [
-                            "-hiveconf",
-                            f"mapreduce.job.queuename={self.mapred_queue}",
-                            "-hiveconf",
-                            f"mapred.job.queue.name={self.mapred_queue}",
-                            "-hiveconf",
-                            f"tez.queue.name={self.mapred_queue}",
-                        ]
-                    )
+        with TemporaryDirectory(prefix="airflow_hiveop_") as tmp_dir, NamedTemporaryFile(dir=tmp_dir) as f:
+            hql += "\n"
+            f.write(hql.encode("UTF-8"))
+            f.flush()
+            hive_cmd = self._prepare_cli_cmd()
+            env_context = get_context_from_env_var()
+            # Only extend the hive_conf if it is defined.
+            if hive_conf:
+                env_context.update(hive_conf)
+            hive_conf_params = self._prepare_hiveconf(env_context)
+            if self.mapred_queue:
+                hive_conf_params.extend(
+                    [
+                        "-hiveconf",
+                        f"mapreduce.job.queuename={self.mapred_queue}",
+                        "-hiveconf",
+                        f"mapred.job.queue.name={self.mapred_queue}",
+                        "-hiveconf",
+                        f"tez.queue.name={self.mapred_queue}",
+                    ]
+                )

-                if self.mapred_queue_priority:
-                    hive_conf_params.extend(
-                        ["-hiveconf", f"mapreduce.job.priority={self.mapred_queue_priority}"]
-                    )
+            if self.mapred_queue_priority:
+                hive_conf_params.extend(["-hiveconf", f"mapreduce.job.priority={self.mapred_queue_priority}"])

-                if self.mapred_job_name:
-                    hive_conf_params.extend(["-hiveconf", f"mapred.job.name={self.mapred_job_name}"])
+            if self.mapred_job_name:
+                hive_conf_params.extend(["-hiveconf", f"mapred.job.name={self.mapred_job_name}"])

-                hive_cmd.extend(hive_conf_params)
-                hive_cmd.extend(["-f", f.name])
+            hive_cmd.extend(hive_conf_params)
+            hive_cmd.extend(["-f", f.name])

+            if verbose:
+                self.log.info("%s", " ".join(hive_cmd))
+            sub_process: Any = subprocess.Popen(
+                hive_cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, cwd=tmp_dir, close_fds=True
+            )
+            self.sub_process = sub_process
+            stdout = ""
+            for line in iter(sub_process.stdout.readline, b""):
+                line = line.decode()
+                stdout += line
                 if verbose:
-                    self.log.info("%s", " ".join(hive_cmd))
-                sub_process: Any = subprocess.Popen(
-                    hive_cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, cwd=tmp_dir, close_fds=True
-                )
-                self.sub_process = sub_process
-                stdout = ""
-                for line in iter(sub_process.stdout.readline, b""):
-                    line = line.decode()
-                    stdout += line
-                    if verbose:
-                        self.log.info(line.strip())
-                sub_process.wait()
+                    self.log.info(line.strip())
+            sub_process.wait()

-                if sub_process.returncode:
-                    raise AirflowException(stdout)
+            if sub_process.returncode:
+                raise AirflowException(stdout)

-                return stdout
+            return stdout

     def test_hql(self, hql: str) -> None:
         """Test an hql statement using the hive cli and EXPLAIN."""
@@ -376,25 +373,26 @@ def _infer_field_types_from_df(df: pd.DataFrame) -> dict[Any, Any]:
         if pandas_kwargs is None:
             pandas_kwargs = {}

-        with TemporaryDirectory(prefix="airflow_hiveop_") as tmp_dir:
-            with NamedTemporaryFile(dir=tmp_dir, mode="w") as f:
-                if field_dict is None:
-                    field_dict = _infer_field_types_from_df(df)
-
-                df.to_csv(
-                    path_or_buf=f,
-                    sep=delimiter,
-                    header=False,
-                    index=False,
-                    encoding=encoding,
-                    date_format="%Y-%m-%d %H:%M:%S",
-                    **pandas_kwargs,
-                )
-                f.flush()
+        with TemporaryDirectory(prefix="airflow_hiveop_") as tmp_dir, NamedTemporaryFile(
+            dir=tmp_dir, mode="w"
+        ) as f:
+            if field_dict is None:
+                field_dict = _infer_field_types_from_df(df)
+
+            df.to_csv(
+                path_or_buf=f,
+                sep=delimiter,
+                header=False,
+                index=False,
+                encoding=encoding,
+                date_format="%Y-%m-%d %H:%M:%S",
+                **pandas_kwargs,
+            )
+            f.flush()

-                return self.load_file(
-                    filepath=f.name, table=table, delimiter=delimiter, field_dict=field_dict, **kwargs
-                )
+            return self.load_file(
+                filepath=f.name, table=table, delimiter=delimiter, field_dict=field_dict, **kwargs
+            )

     def load_file(
         self,
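
A note on the wrapped form in the second hunk: the comma-separated managers still form one logical line, so a long combination gets wrapped inside the call parentheses. Python 3.10+ also allows parenthesizing the manager list itself, which reads a little more naturally; a minimal sketch of that variant (not used in this commit, presumably because older Python versions were still supported at the time):

    from tempfile import NamedTemporaryFile, TemporaryDirectory

    # Python 3.10+ only: the whole manager list may be wrapped in parentheses.
    with (
        TemporaryDirectory(prefix="airflow_hiveop_") as tmp_dir,
        NamedTemporaryFile(dir=tmp_dir, mode="w") as f,
    ):
        f.write("SELECT 1;\n")
        f.flush()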

airflow/providers/apache/hive/transfers/mysql_to_hive.py

Lines changed: 14 additions & 15 deletions

@@ -136,21 +136,20 @@ def execute(self, context: Context):
         mysql = MySqlHook(mysql_conn_id=self.mysql_conn_id)
         self.log.info("Dumping MySQL query results to local file")
         with NamedTemporaryFile(mode="w", encoding="utf-8") as f:
-            with closing(mysql.get_conn()) as conn:
-                with closing(conn.cursor()) as cursor:
-                    cursor.execute(self.sql)
-                    csv_writer = csv.writer(
-                        f,
-                        delimiter=self.delimiter,
-                        quoting=self.quoting,
-                        quotechar=self.quotechar if self.quoting != csv.QUOTE_NONE else None,
-                        escapechar=self.escapechar,
-                    )
-                    field_dict = {}
-                    if cursor.description is not None:
-                        for field in cursor.description:
-                            field_dict[field[0]] = self.type_map(field[1])
-                    csv_writer.writerows(cursor)  # type: ignore[arg-type]
+            with closing(mysql.get_conn()) as conn, closing(conn.cursor()) as cursor:
+                cursor.execute(self.sql)
+                csv_writer = csv.writer(
+                    f,
+                    delimiter=self.delimiter,
+                    quoting=self.quoting,
+                    quotechar=self.quotechar if self.quoting != csv.QUOTE_NONE else None,
+                    escapechar=self.escapechar,
+                )
+                field_dict = {}
+                if cursor.description is not None:
+                    for field in cursor.description:
+                        field_dict[field[0]] = self.type_map(field[1])
+                csv_writer.writerows(cursor)  # type: ignore[arg-type]
             f.flush()
             self.log.info("Loading file into Hive")
             hive.load_file(
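
This transfer (and the Exasol hook below) flattens contextlib.closing wrappers the same way; because the managers are entered left to right, the cursor expression can refer to the connection bound just before it. A minimal standalone sketch of that behaviour, using sqlite3 rather than the providers' own hooks:

    import sqlite3
    from contextlib import closing

    # closing() turns any object with a .close() method into a context manager,
    # so the connection and its cursor can share a single with statement.
    with closing(sqlite3.connect(":memory:")) as conn, closing(conn.cursor()) as cursor:
        cursor.execute("SELECT 1")
        print(cursor.fetchall())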

airflow/providers/apache/pig/hooks/pig.py

Lines changed: 32 additions & 33 deletions

@@ -64,41 +64,40 @@ def run_cli(self, pig: str, pig_opts: str | None = None, verbose: bool = True) -
         >>> ("hdfs://" in result)
         True
         """
-        with TemporaryDirectory(prefix="airflow_pigop_") as tmp_dir:
-            with NamedTemporaryFile(dir=tmp_dir) as f:
-                f.write(pig.encode("utf-8"))
-                f.flush()
-                fname = f.name
-                pig_bin = "pig"
-                cmd_extra: list[str] = []
-
-                pig_cmd = [pig_bin]
-
-                if self.pig_properties:
-                    pig_cmd.extend(self.pig_properties)
-                if pig_opts:
-                    pig_opts_list = pig_opts.split()
-                    pig_cmd.extend(pig_opts_list)
+        with TemporaryDirectory(prefix="airflow_pigop_") as tmp_dir, NamedTemporaryFile(dir=tmp_dir) as f:
+            f.write(pig.encode("utf-8"))
+            f.flush()
+            fname = f.name
+            pig_bin = "pig"
+            cmd_extra: list[str] = []
+
+            pig_cmd = [pig_bin]
+
+            if self.pig_properties:
+                pig_cmd.extend(self.pig_properties)
+            if pig_opts:
+                pig_opts_list = pig_opts.split()
+                pig_cmd.extend(pig_opts_list)
+
+            pig_cmd.extend(["-f", fname] + cmd_extra)
+
+            if verbose:
+                self.log.info("%s", " ".join(pig_cmd))
+            sub_process: Any = subprocess.Popen(
+                pig_cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, cwd=tmp_dir, close_fds=True
+            )
+            self.sub_process = sub_process
+            stdout = ""
+            for line in iter(sub_process.stdout.readline, b""):
+                stdout += line.decode("utf-8")
+                if verbose:
+                    self.log.info(line.strip())
+            sub_process.wait()

-                pig_cmd.extend(["-f", fname] + cmd_extra)
+            if sub_process.returncode:
+                raise AirflowException(stdout)

-                if verbose:
-                    self.log.info("%s", " ".join(pig_cmd))
-                sub_process: Any = subprocess.Popen(
-                    pig_cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, cwd=tmp_dir, close_fds=True
-                )
-                self.sub_process = sub_process
-                stdout = ""
-                for line in iter(sub_process.stdout.readline, b""):
-                    stdout += line.decode("utf-8")
-                    if verbose:
-                        self.log.info(line.strip())
-                sub_process.wait()
-
-                if sub_process.returncode:
-                    raise AirflowException(stdout)
-
-                return stdout
+            return stdout

     def kill(self) -> None:
         """Kill Pig job."""

airflow/providers/dbt/cloud/hooks/dbt.py

Lines changed: 8 additions & 7 deletions

@@ -234,13 +234,14 @@ async def get_job_details(
         endpoint = f"{account_id}/runs/{run_id}/"
         headers, tenant = await self.get_headers_tenants_from_connection()
         url, params = self.get_request_url_params(tenant, endpoint, include_related)
-        async with aiohttp.ClientSession(headers=headers) as session:
-            async with session.get(url, params=params) as response:
-                try:
-                    response.raise_for_status()
-                    return await response.json()
-                except ClientResponseError as e:
-                    raise AirflowException(str(e.status) + ":" + e.message)
+        async with aiohttp.ClientSession(headers=headers) as session, session.get(
+            url, params=params
+        ) as response:
+            try:
+                response.raise_for_status()
+                return await response.json()
+            except ClientResponseError as e:
+                raise AirflowException(f"{e.status}:{e.message}")

     async def get_job_status(
         self, run_id: int, account_id: int | None = None, include_related: list[str] | None = None
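
async with accepts the same comma-separated form, and, as in the hunk above, the second manager expression may use the name bound by the first (session). A minimal sketch of the idea outside the hook, assuming aiohttp is installed and the URL is a reachable, hypothetical endpoint:

    import aiohttp


    async def fetch_json(url: str) -> dict:
        # Both async context managers are entered on one statement and exited
        # in reverse order when the block ends.
        async with aiohttp.ClientSession() as session, session.get(url) as response:
            response.raise_for_status()
            return await response.json()

    # Usage (hypothetical endpoint), e.g.:
    # asyncio.run(fetch_json("https://api.example.com/runs/42/"))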

airflow/providers/exasol/hooks/exasol.py

Lines changed: 4 additions & 6 deletions

@@ -97,9 +97,8 @@ def get_records(
             sql statements to execute
         :param parameters: The parameters to render the SQL query with.
         """
-        with closing(self.get_conn()) as conn:
-            with closing(conn.execute(sql, parameters)) as cur:
-                return cur.fetchall()
+        with closing(self.get_conn()) as conn, closing(conn.execute(sql, parameters)) as cur:
+            return cur.fetchall()

     def get_first(self, sql: str | list[str], parameters: Iterable | Mapping[str, Any] | None = None) -> Any:
         """Execute the SQL and return the first resulting row.

@@ -108,9 +107,8 @@ def get_first(self, sql: str | list[str], parameters: Iterable | Mapping[str, An
             sql statements to execute
         :param parameters: The parameters to render the SQL query with.
         """
-        with closing(self.get_conn()) as conn:
-            with closing(conn.execute(sql, parameters)) as cur:
-                return cur.fetchone()
+        with closing(self.get_conn()) as conn, closing(conn.execute(sql, parameters)) as cur:
+            return cur.fetchone()

     def export_to_file(
         self,

airflow/providers/google/cloud/hooks/gcs.py

Lines changed: 3 additions & 4 deletions

@@ -550,10 +550,9 @@ def _call_with_retry(f: Callable[[], None]) -> None:
             if gzip:
                 filename_gz = filename + ".gz"

-                with open(filename, "rb") as f_in:
-                    with gz.open(filename_gz, "wb") as f_out:
-                        shutil.copyfileobj(f_in, f_out)
-                        filename = filename_gz
+                with open(filename, "rb") as f_in, gz.open(filename_gz, "wb") as f_out:
+                    shutil.copyfileobj(f_in, f_out)
+                    filename = filename_gz

             _call_with_retry(
                 partial(blob.upload_from_filename, filename=filename, content_type=mime_type, timeout=timeout)
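
The same two-manager flattening covers the read handle and the gzip write handle used for optional compression before upload. A self-contained sketch of just that step, with hypothetical file names and the gzip module imported under its usual name rather than the hook's gz alias:

    import gzip
    import shutil

    source = "data.csv"          # hypothetical input file
    target = source + ".gz"

    # One with statement opens the plain file for reading and the gzip file for
    # writing; both are closed, in reverse order, when the block exits.
    with open(source, "rb") as f_in, gzip.open(target, "wb") as f_out:
        shutil.copyfileobj(f_in, f_out)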

airflow/providers/google/cloud/hooks/kubernetes_engine.py

Lines changed: 28 additions & 30 deletions

@@ -493,19 +493,18 @@ async def delete_pod(self, name: str, namespace: str):
         :param name: Name of the pod.
         :param namespace: Name of the pod's namespace.
         """
-        async with Token(scopes=self.scopes) as token:
-            async with self.get_conn(token) as connection:
-                try:
-                    v1_api = async_client.CoreV1Api(connection)
-                    await v1_api.delete_namespaced_pod(
-                        name=name,
-                        namespace=namespace,
-                        body=client.V1DeleteOptions(),
-                    )
-                except async_client.ApiException as e:
-                    # If the pod is already deleted
-                    if e.status != 404:
-                        raise
+        async with Token(scopes=self.scopes) as token, self.get_conn(token) as connection:
+            try:
+                v1_api = async_client.CoreV1Api(connection)
+                await v1_api.delete_namespaced_pod(
+                    name=name,
+                    namespace=namespace,
+                    body=client.V1DeleteOptions(),
+                )
+            except async_client.ApiException as e:
+                # If the pod is already deleted
+                if e.status != 404:
+                    raise

     async def read_logs(self, name: str, namespace: str):
         """Read logs inside the pod while starting containers inside.

@@ -518,20 +517,19 @@ async def read_logs(self, name: str, namespace: str):
         :param name: Name of the pod.
         :param namespace: Name of the pod's namespace.
         """
-        async with Token(scopes=self.scopes) as token:
-            async with self.get_conn(token) as connection:
-                try:
-                    v1_api = async_client.CoreV1Api(connection)
-                    logs = await v1_api.read_namespaced_pod_log(
-                        name=name,
-                        namespace=namespace,
-                        follow=False,
-                        timestamps=True,
-                    )
-                    logs = logs.splitlines()
-                    for line in logs:
-                        self.log.info("Container logs from %s", line)
-                    return logs
-                except HTTPError:
-                    self.log.exception("There was an error reading the kubernetes API.")
-                    raise
+        async with Token(scopes=self.scopes) as token, self.get_conn(token) as connection:
+            try:
+                v1_api = async_client.CoreV1Api(connection)
+                logs = await v1_api.read_namespaced_pod_log(
+                    name=name,
+                    namespace=namespace,
+                    follow=False,
+                    timestamps=True,
+                )
+                logs = logs.splitlines()
+                for line in logs:
+                    self.log.info("Container logs from %s", line)
+                return logs
+            except HTTPError:
+                self.log.exception("There was an error reading the kubernetes API.")
+                raise
