Skip to content

Commit a28c66f

Browse files
authored
[AIRFLOW-4734] Upsert functionality for PostgresHook.insert_rows() (#8625)
PostgresHook's parent class, DbApiHook, implements upsert in its insert_rows() method via the replace=True flag. However, the SQL it generates uses MySQL's "REPLACE INTO" syntax, which is not valid in PostgreSQL. This commit moves the SQL generation code for insert/upsert into a separate method, which is then overridden in the PostgreSQL subclass to generate the "INSERT ... ON CONFLICT DO UPDATE" syntax ("new" since Postgres 9.5).
1 parent 249e80b commit a28c66f

File tree

4 files changed

+154
-16
lines changed

4 files changed

+154
-16
lines changed

airflow/hooks/dbapi_hook.py

Lines changed: 39 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -226,8 +226,43 @@ def get_cursor(self):
226226
"""
227227
return self.get_conn().cursor()
228228

229+
@staticmethod
def _generate_insert_sql(table, values, target_fields, replace, **kwargs):
    """
    Build the parameterized INSERT (or REPLACE) statement for a single row.

    The REPLACE form relies on MySQL's ``REPLACE INTO`` syntax; extra keyword
    arguments are accepted but not used by this implementation.

    :param table: Name of the target table
    :type table: str
    :param values: The row to insert into the table
    :type values: tuple of cell values
    :param target_fields: The names of the columns to fill in the table
    :type target_fields: iterable of strings
    :param replace: Whether to replace instead of insert
    :type replace: bool
    :return: The generated INSERT or REPLACE SQL statement
    :rtype: str
    """
    # One "%s" slot per cell; the actual values are bound by the DB driver.
    value_slots = ",".join(["%s"] * len(values))

    # Without explicit columns the fragment is empty, which leaves two
    # consecutive spaces between the table name and VALUES.
    columns_clause = "({})".format(", ".join(target_fields)) if target_fields else ''

    verb = "REPLACE INTO " if replace else "INSERT INTO "
    return verb + "{0} {1} VALUES ({2})".format(table, columns_clause, value_slots)
263+
229264
def insert_rows(self, table, rows, target_fields=None, commit_every=1000,
230-
replace=False):
265+
replace=False, **kwargs):
231266
"""
232267
A generic way to insert a set of tuples into a table,
233268
a new transaction is created every commit_every rows
@@ -244,11 +279,6 @@ def insert_rows(self, table, rows, target_fields=None, commit_every=1000,
244279
:param replace: Whether to replace instead of insert
245280
:type replace: bool
246281
"""
247-
if target_fields:
248-
target_fields = ", ".join(target_fields)
249-
target_fields = "({})".format(target_fields)
250-
else:
251-
target_fields = ''
252282
i = 0
253283
with closing(self.get_conn()) as conn:
254284
if self.supports_autocommit:
@@ -262,15 +292,9 @@ def insert_rows(self, table, rows, target_fields=None, commit_every=1000,
262292
for cell in row:
263293
lst.append(self._serialize_cell(cell, conn))
264294
values = tuple(lst)
265-
placeholders = ["%s", ] * len(values)
266-
if not replace:
267-
sql = "INSERT INTO "
268-
else:
269-
sql = "REPLACE INTO "
270-
sql += "{0} {1} VALUES ({2})".format(
271-
table,
272-
target_fields,
273-
",".join(placeholders))
295+
sql = self._generate_insert_sql(
296+
table, values, target_fields, replace, **kwargs
297+
)
274298
cur.execute(sql, values)
275299
if commit_every and i % commit_every == 0:
276300
conn.commit()

airflow/providers/google/cloud/hooks/bigquery.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,8 @@ def get_client(self, project_id: Optional[str] = None, location: Optional[str] =
119119
)
120120

121121
def insert_rows(
122-
self, table: Any, rows: Any, target_fields: Any = None, commit_every: Any = 1000, replace: Any = False
122+
self, table: Any, rows: Any, target_fields: Any = None, commit_every: Any = 1000,
123+
replace: Any = False, **kwargs
123124
) -> NoReturn:
124125
"""
125126
Insertion is currently unsupported. Theoretically, you could use

airflow/providers/postgres/hooks/postgres.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,3 +180,57 @@ def get_iam_token(self, conn):
180180
else:
181181
token = aws_hook.conn.generate_db_auth_token(conn.host, port, conn.login)
182182
return login, token, port
183+
184+
@staticmethod
def _generate_insert_sql(table, values, target_fields, replace, **kwargs):
    """
    Static helper method that generates the INSERT SQL statement.
    The upsert variant uses PostgreSQL's ``INSERT ... ON CONFLICT`` syntax.

    :param table: Name of the target table
    :type table: str
    :param values: The row to insert into the table
    :type values: tuple of cell values
    :param target_fields: The names of the columns to fill in the table
    :type target_fields: iterable of strings
    :param replace: Whether to upsert instead of plain insert
    :type replace: bool
    :param replace_index: the column or list of column names to act as
        index for the ON CONFLICT clause
    :type replace_index: str or list
    :return: The generated INSERT statement, with an ON CONFLICT clause
        appended when ``replace`` is true
    :rtype: str
    :raises ValueError: if ``replace`` is true and either ``target_fields``
        or ``replace_index`` is missing or empty
    """
    placeholders = ["%s", ] * len(values)
    replace_index = kwargs.get("replace_index")

    if target_fields:
        target_fields_fragment = "({})".format(", ".join(target_fields))
    else:
        target_fields_fragment = ''

    sql = "INSERT INTO {0} {1} VALUES ({2})".format(
        table,
        target_fields_fragment,
        ",".join(placeholders))

    if not replace:
        return sql

    # Empty containers are rejected as well as None: they would otherwise
    # produce an empty SET list or an empty conflict target.
    if not target_fields:
        raise ValueError("PostgreSQL ON CONFLICT upsert syntax requires column names")
    if not replace_index:
        raise ValueError("PostgreSQL ON CONFLICT upsert syntax requires an unique index")
    if isinstance(replace_index, str):
        replace_index = [replace_index]
    replace_index_set = set(replace_index)

    # Only columns outside the conflict target are overwritten.
    replace_target = [
        "{0} = excluded.{0}".format(col)
        for col in target_fields
        if col not in replace_index_set
    ]
    conflict_target = ", ".join(replace_index)
    if replace_target:
        sql += " ON CONFLICT ({0}) DO UPDATE SET {1}".format(
            conflict_target,
            ", ".join(replace_target),
        )
    else:
        # Every inserted column belongs to the index, so there is nothing
        # to update; "DO UPDATE SET" with an empty list is invalid SQL.
        sql += " ON CONFLICT ({0}) DO NOTHING".format(conflict_target)
    return sql

tests/providers/postgres/hooks/test_postgres.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,3 +199,62 @@ def test_bulk_dump(self):
199199
results = [line.rstrip().decode("utf-8") for line in f.readlines()]
200200

201201
self.assertEqual(sorted(input_data), sorted(results))
202+
203+
@pytest.mark.backend("postgres")
def test_insert_rows(self):
    """Plain insert_rows() issues one parameterized INSERT per row."""
    table = "table"
    rows = [("hello",),
            ("world",)]

    self.db_hook.insert_rows(table, rows)

    assert self.conn.close.call_count == 1
    assert self.cur.close.call_count == 1

    commit_count = 2  # The first and last commit
    self.assertEqual(commit_count, self.conn.commit.call_count)

    # With no target_fields the column fragment is empty, so the generated
    # statement has two spaces between the table name and VALUES.
    sql = "INSERT INTO {}  VALUES (%s)".format(table)
    for row in rows:
        self.cur.execute.assert_any_call(sql, row)
220+
221+
@pytest.mark.backend("postgres")
def test_insert_rows_replace(self):
    """Upsert via replace=True emits an ON CONFLICT ... DO UPDATE statement."""
    table = "table"
    fields = ("id", "value")
    rows = [(1, "hello"), (2, "world")]

    self.db_hook.insert_rows(
        table, rows, fields, replace=True, replace_index=fields[0])

    assert self.conn.close.call_count == 1
    assert self.cur.close.call_count == 1

    # The first and last commit
    self.assertEqual(2, self.conn.commit.call_count)

    key_column, value_column = fields
    expected_sql = (
        "INSERT INTO {0} ({1}, {2}) VALUES (%s,%s) "
        "ON CONFLICT ({1}) DO UPDATE SET {2} = excluded.{2}"
    ).format(table, key_column, value_column)
    for expected_row in rows:
        self.cur.execute.assert_any_call(expected_sql, expected_row)
242+
243+
@pytest.mark.backend("postgres")
def test_insert_rows_replace_missing_target_field_arg(self):
    """Upsert without target_fields must be rejected with ValueError."""
    table = "table"
    rows = [(1, "hello",),
            (2, "world",)]
    fields = ("id", "value")
    # pytest.raises pins the expected error; a non-strict xfail marker is
    # reported as OK whether or not an exception is raised.
    with pytest.raises(ValueError, match="requires column names"):
        self.db_hook.insert_rows(
            table, rows, replace=True, replace_index=fields[0])
252+
253+
@pytest.mark.backend("postgres")
def test_insert_rows_replace_missing_replace_index_arg(self):
    """Upsert without replace_index must be rejected with ValueError."""
    table = "table"
    rows = [(1, "hello",),
            (2, "world",)]
    fields = ("id", "value")
    # pytest.raises pins the expected error; a non-strict xfail marker is
    # reported as OK whether or not an exception is raised.
    with pytest.raises(ValueError, match="requires an unique index"):
        self.db_hook.insert_rows(table, rows, fields, replace=True)

0 commit comments

Comments
 (0)