Skip to content

Commit 0f5619c

Browse files
authored
Add bulk operations utilities (#224)
1 parent cbcc8d8 commit 0f5619c

File tree

4 files changed

+258
-0
lines changed

4 files changed

+258
-0
lines changed

app/src/adapters/db/clients/postgres_client.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,11 @@ def check_db_connection(self) -> None:
7474
# if check_migrations_current:
7575
# have_all_migrations_run(engine)
7676

77+
def get_raw_connection(self) -> sqlalchemy.PoolProxiedConnection:
    """Return a raw DBAPI-level connection from the engine's pool.

    For low-level operations not supported by SQLAlchemy.
    Unless you specifically need this, you should use get_connection().
    """
    return self._engine.raw_connection()
81+
7782

7883
def get_connection_parameters(db_config: PostgresDBConfig) -> dict[str, Any]:
7984
connect_args: dict[str, Any] = {}

app/src/db/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
__all__ = ["bulk_ops"]

app/src/db/bulk_ops.py

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
"""Bulk database operations for performance.
2+
3+
Provides a bulk_upsert function for use with
4+
Postgres and the psycopg library.
5+
"""
6+
from typing import Any, Sequence
7+
8+
import psycopg
9+
from psycopg import rows, sql
10+
11+
# Re-export the psycopg primitives that callers of this module need, so they
# can depend on src.db.bulk_ops alone instead of importing psycopg directly.
Connection = psycopg.Connection
Cursor = psycopg.Cursor
kwargs_row = rows.kwargs_row
14+
15+
16+
def bulk_upsert(
    cur: psycopg.Cursor,
    table: str,
    attributes: Sequence[str],
    objects: Sequence[Any],
    constraint: str,
    update_condition: sql.SQL | None = None,
) -> None:
    """Bulk insert or update a sequence of objects.

    Insert a sequence of objects into ``table``; where a row conflicts with
    the given unique constraint, overwrite the existing data instead.

    The objects are first streamed into a transaction-scoped temp table via
    the fast COPY path, then merged into the destination table with a single
    INSERT ... ON CONFLICT statement.

    Args:
        cur: the Cursor object from the psycopg library
        table: the name of the table to insert into or update
        attributes: a sequence of attribute names to copy from each object
        objects: a sequence of objects to upsert
        constraint: the table unique constraint to use to determine conflicts
        update_condition: optional WHERE clause to limit updates for a
            conflicting row
    """
    # Explicit identity check: only substitute the default for a missing
    # argument, never for a caller-supplied (if falsy) SQL fragment.
    if update_condition is None:
        update_condition = sql.SQL("")

    temp_table = f"temp_{table}"
    _create_temp_table(cur, temp_table=temp_table, src_table=table)
    _bulk_insert(cur, table=temp_table, columns=attributes, objects=objects)
    _write_from_table_to_table(
        cur,
        src_table=temp_table,
        dest_table=table,
        columns=attributes,
        constraint=constraint,
        update_condition=update_condition,
    )
53+
54+
55+
def _create_temp_table(cur: psycopg.Cursor, temp_table: str, src_table: str) -> None:
    """Create a table that lives only for the current transaction.

    The structure is cloned from an existing table (LIKE); the temp table is
    dropped automatically when the transaction commits (ON COMMIT DROP).

    Args:
        cur: the Cursor object from the psycopg library
        temp_table: the name of the temporary table to create
        src_table: the name of the existing table
    """
    # A single-line statement replaces the original backslash continuations,
    # which embedded long runs of source indentation into the SQL text.
    cur.execute(
        sql.SQL(
            "CREATE TEMP TABLE {temp_table} (LIKE {src_table}) ON COMMIT DROP"
        ).format(
            temp_table=sql.Identifier(temp_table),
            src_table=sql.Identifier(src_table),
        )
    )
74+
75+
76+
def _bulk_insert(
    cur: psycopg.Cursor,
    table: str,
    columns: Sequence[str],
    objects: Sequence[Any],
) -> None:
    """Stream a sequence of objects into a table using PostgreSQL COPY.

    COPY is highly performant for loading many rows at once. Each object
    contributes one row whose values are read from the attributes named in
    ``columns``.

    Args:
        cur: the Cursor object from the psycopg library
        table: the name of the temporary table
        columns: a sequence of column names that are attributes of each object
        objects: a sequence of objects with attributes defined by columns
    """
    copy_statement = sql.SQL("COPY {table}({columns}) FROM STDIN").format(
        table=sql.Identifier(table),
        columns=sql.SQL(",").join(sql.Identifier(name) for name in columns),
    )
    with cur.copy(copy_statement) as copy:
        for record in objects:
            copy.write_row([getattr(record, name) for name in columns])
100+
101+
102+
def _write_from_table_to_table(
    cur: psycopg.Cursor,
    src_table: str,
    dest_table: str,
    columns: Sequence[str],
    constraint: str,
    update_condition: sql.SQL | None = None,
) -> None:
    """Write data from one table to another.

    If there are conflicts due to unique constraints, overwrite existing data.

    Args:
        cur: the Cursor object from the psycopg library
        src_table: the name of the table that will be copied from
        dest_table: the name of the table that will be written to
        columns: a sequence of column names to copy over
        constraint: the arbiter constraint to use to determine conflicts
        update_condition: optional WHERE clause to limit updates for a
            conflicting row
    """
    if update_condition is None:
        update_condition = sql.SQL("")

    columns_sql = sql.SQL(",").join(map(sql.Identifier, columns))
    # NOTE(review): "id" and "number" are excluded from the conflict-update
    # assignments — presumably identifier columns that must never be
    # overwritten. Confirm this list against the tables using bulk_upsert.
    update_sql = sql.SQL(",").join(
        [
            sql.SQL("{column} = EXCLUDED.{column}").format(
                column=sql.Identifier(column),
            )
            for column in columns
            if column not in ["id", "number"]
        ]
    )
    # Implicit string concatenation with explicit spaces replaces the original
    # backslash continuations, which embedded source indentation into the SQL.
    query = sql.SQL(
        "INSERT INTO {dest_table}({columns}) "
        "SELECT {columns} FROM {src_table} "
        "ON CONFLICT ON CONSTRAINT {constraint} DO UPDATE SET {update_sql} "
        "{update_condition}"
    ).format(
        dest_table=sql.Identifier(dest_table),
        columns=columns_sql,
        src_table=sql.Identifier(src_table),
        constraint=sql.Identifier(constraint),
        update_sql=update_sql,
        update_condition=update_condition,
    )
    cur.execute(query)
149+
150+
151+
__all__ = ["bulk_upsert"]

app/tests/src/db/test_bulk_ops.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
"""Tests for bulk_ops module"""
2+
import operator
3+
import random
4+
from dataclasses import dataclass
5+
6+
from psycopg import rows, sql
7+
8+
import src.adapters.db as db
9+
from src.db import bulk_ops
10+
11+
12+
@dataclass
class Number:
    """Row type for the two-column test table used to exercise bulk_upsert."""

    # Primary key column (created as TEXT NOT NULL in the test table).
    id: str
    # Arbitrary integer payload column (INT in the test table).
    num: int
16+
17+
18+
def get_random_number_object() -> Number:
    """Build a Number with a random 7-digit id and a random num in [1, 10000]."""
    random_id = random.randint(1000000, 9999999)
    return Number(id=str(random_id), num=random.randint(1, 10000))
23+
24+
25+
def test_bulk_upsert(db_session: db.Session):
    """Insert 100 random rows, then upsert a mix of updated and new rows and
    verify the final table contents match expectations."""
    db_client = db.PostgresDBClient()
    conn = db_client.get_raw_connection()

    # Override mypy, because SQLAlchemy's DBAPICursor type doesn't specify the
    # row_factory attribute, or that it functions as a context manager
    with conn.cursor(row_factory=rows.class_row(Number)) as cur:  # type: ignore
        table = "temp_table"
        attributes = ["id", "num"]
        objects = [get_random_number_object() for _ in range(100)]
        constraint = "temp_table_pkey"

        # Create a table for testing bulk upsert
        cur.execute(
            sql.SQL(
                "CREATE TEMP TABLE {table}"
                "("
                "id TEXT NOT NULL,"
                "num INT,"
                "CONSTRAINT {constraint} PRIMARY KEY (id)"
                ")"
            ).format(
                table=sql.Identifier(table),
                constraint=sql.Identifier(constraint),
            )
        )

        bulk_ops.bulk_upsert(
            cur,
            table,
            attributes,
            objects,
            constraint,
        )
        conn.commit()

        # Check that all the objects were inserted
        cur.execute(
            sql.SQL("SELECT id, num FROM {table} ORDER BY id ASC").format(
                table=sql.Identifier(table)
            )
        )
        records = cur.fetchall()
        objects.sort(key=operator.attrgetter("id"))
        assert records == objects

        # Now modify half of the objects
        updated_indexes = random.sample(range(100), 50)
        original_objects = [objects[i] for i in range(100) if i not in updated_indexes]
        updated_objects = [objects[i] for i in updated_indexes]
        for obj in updated_objects:
            obj.num = random.randint(1, 10000)

        # And insert additional objects
        inserted_objects = [get_random_number_object() for _ in range(50)]

        updated_and_inserted_objects = updated_objects + inserted_objects
        # Bug fix: the original called
        # random.shuffle(updated_objects + inserted_objects), which shuffles a
        # throwaway concatenation and leaves the actual upsert input in order.
        # Shuffle the list that is actually passed to bulk_upsert.
        random.shuffle(updated_and_inserted_objects)

        bulk_ops.bulk_upsert(
            cur,
            table,
            attributes,
            updated_and_inserted_objects,
            constraint,
        )
        conn.commit()

        # Check that the existing objects were updated and new objects were inserted
        cur.execute(
            sql.SQL("SELECT id, num FROM {table} ORDER BY id ASC").format(
                table=sql.Identifier(table)
            )
        )
        records = cur.fetchall()
        expected_objects = original_objects + updated_and_inserted_objects
        expected_objects.sort(key=operator.attrgetter("id"))
        assert records == expected_objects

0 commit comments

Comments
 (0)