Skip to content

Commit 6b3010e

Browse files
committed
feat(sftkit): improve pg introspection, prevent deletion of pg_ functions
1 parent 4d2f7ff commit 6b3010e

File tree

12 files changed

+267
-56
lines changed

12 files changed

+267
-56
lines changed

sftkit/pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ source = ["sftkit"]
3434

3535
[tool.pytest.ini_options]
3636
asyncio_mode = "auto"
37+
asyncio_default_fixture_loop_scope = "session"
3738
minversion = "6.0"
3839
testpaths = ["tests"]
3940

sftkit/sftkit/database/_migrations.py

Lines changed: 34 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,17 @@
66

77
import asyncpg
88

9+
from sftkit.database import Connection
10+
from sftkit.database.introspection import list_views, list_triggers, list_functions, list_constraints
11+
912
logger = logging.getLogger(__name__)
1013

1114
MIGRATION_VERSION_RE = re.compile(r"^-- migration: (?P<version>\w+)$")
1215
MIGRATION_REQURES_RE = re.compile(r"^-- requires: (?P<version>\w+)$")
1316
MIGRATION_TABLE = "schema_revision"
1417

1518

16-
async def _run_postgres_code(conn: asyncpg.Connection, code: str, file_name: Path):
19+
async def _run_postgres_code(conn: Connection, code: str, file_name: Path):
1720
if all(line.startswith("--") for line in code.splitlines()):
1821
return
1922
try:
@@ -32,33 +35,23 @@ async def _run_postgres_code(conn: asyncpg.Connection, code: str, file_name: Pat
3235
raise ValueError(f"Syntax or Access error when executing SQL code ({file_name!s}): {message!r}") from exc
3336

3437

35-
async def _drop_all_views(conn: asyncpg.Connection, schema: str):
38+
async def _drop_all_views(conn: Connection, schema: str):
3639
# TODO: we might have to find out the dependency order of the views if drop cascade does not work
37-
result = await conn.fetch(
38-
"select table_name from information_schema.views where table_schema = $1 and table_name !~ '^pg_';",
39-
schema,
40-
)
41-
views = [row["table_name"] for row in result]
40+
views = await list_views(conn, schema)
4241
if len(views) == 0:
4342
return
4443

4544
# we use drop if exists here as the cascade dropping might lead the view to being already dropped
4645
# due to being a dependency of another view
47-
drop_statements = "\n".join([f"drop view if exists {view} cascade;" for view in views])
46+
drop_statements = "\n".join([f"drop view if exists {view.table_name} cascade;" for view in views])
4847
await conn.execute(drop_statements)
4948

5049

51-
async def _drop_all_triggers(conn: asyncpg.Connection, schema: str):
52-
result = await conn.fetch(
53-
"select distinct on (trigger_name, event_object_table) trigger_name, event_object_table "
54-
"from information_schema.triggers where trigger_schema = $1",
55-
schema,
56-
)
50+
async def _drop_all_triggers(conn: Connection, schema: str):
51+
triggers = await list_triggers(conn, schema)
5752
statements = []
58-
for row in result:
59-
trigger_name = row["trigger_name"]
60-
table = row["event_object_table"]
61-
statements.append(f"drop trigger {trigger_name} on {table};")
53+
for trigger in triggers:
54+
statements.append(f'drop trigger "{trigger.trigger_name}" on "{trigger.event_object_table}";')
6255

6356
if len(statements) == 0:
6457
return
@@ -67,27 +60,20 @@ async def _drop_all_triggers(conn: asyncpg.Connection, schema: str):
6760
await conn.execute(drop_statements)
6861

6962

70-
async def _drop_all_functions(conn: asyncpg.Connection, schema: str):
71-
result = await conn.fetch(
72-
"select proname, pg_get_function_identity_arguments(oid) as signature, prokind from pg_proc "
73-
"where pronamespace = $1::regnamespace;",
74-
schema,
75-
)
63+
async def _drop_all_functions(conn: Connection, schema: str):
64+
funcs = await list_functions(conn, schema)
7665
drop_statements = []
77-
for row in result:
78-
kind = row["prokind"].decode("utf-8")
79-
name = row["proname"]
80-
signature = row["signature"]
81-
if kind in ("f", "w"):
66+
for func in funcs:
67+
if func.prokind in ("f", "w"):
8268
drop_type = "function"
83-
elif kind == "a":
69+
elif func.prokind == "a":
8470
drop_type = "aggregate"
85-
elif kind == "p":
71+
elif func.prokind == "p":
8672
drop_type = "procedure"
8773
else:
88-
raise RuntimeError(f'Unknown postgres function type "{kind}"')
74+
raise RuntimeError(f'Unknown postgres function type "{func.prokind}"')
8975

90-
drop_statements.append(f"drop {drop_type} {name}({signature}) cascade;")
76+
drop_statements.append(f'drop {drop_type} "{func.proname}"({func.signature}) cascade;')
9177

9278
if len(drop_statements) == 0:
9379
return
@@ -96,37 +82,31 @@ async def _drop_all_functions(conn: asyncpg.Connection, schema: str):
9682
await conn.execute(drop_code)
9783

9884

99-
async def _drop_all_constraints(conn: asyncpg.Connection, schema: str):
85+
async def _drop_all_constraints(conn: Connection, schema: str):
10086
"""drop all constraints in the given schema which are not unique, primary or foreign key constraints"""
101-
result = await conn.fetch(
102-
"select con.conname as constraint_name, rel.relname as table_name, con.contype as constraint_type "
103-
"from pg_catalog.pg_constraint con "
104-
" join pg_catalog.pg_namespace nsp on nsp.oid = con.connamespace "
105-
" left join pg_catalog.pg_class rel on rel.oid = con.conrelid "
106-
"where nsp.nspname = $1 and con.conname !~ '^pg_' "
107-
" and con.contype != 'p' and con.contype != 'f' and con.contype != 'u';",
108-
schema,
109-
)
110-
constraints = []
111-
for row in result:
112-
constraint_name = row["constraint_name"]
113-
constraint_type = row["constraint_type"].decode("utf-8")
114-
table_name = row["table_name"]
87+
constraints = await list_constraints(conn, schema)
88+
drop_statements = []
89+
for constraint in constraints:
90+
constraint_name = constraint.conname
91+
constraint_type = constraint.contype
92+
table_name = constraint.relname
93+
if constraint_type in ("p", "f", "u"):
94+
continue
11595
if constraint_type == "c":
116-
constraints.append(f"alter table {table_name} drop constraint {constraint_name};")
96+
drop_statements.append(f'alter table "{table_name}" drop constraint "{constraint_name}";')
11797
elif constraint_type == "t":
118-
constraints.append(f"drop constraint trigger {constraint_name};")
98+
drop_statements.append(f"drop constraint trigger {constraint_name};")
11999
else:
120100
raise RuntimeError(f'Unknown constraint type "{constraint_type}" for constraint "{constraint_name}"')
121101

122-
if len(constraints) == 0:
102+
if len(drop_statements) == 0:
123103
return
124104

125-
drop_statements = "\n".join(constraints)
126-
await conn.execute(drop_statements)
105+
drop_cmd = "\n".join(drop_statements)
106+
await conn.execute(drop_cmd)
127107

128108

129-
async def _drop_db_code(conn: asyncpg.Connection, schema: str):
109+
async def _drop_db_code(conn: Connection, schema: str):
130110
await _drop_all_triggers(conn, schema=schema)
131111
await _drop_all_functions(conn, schema=schema)
132112
await _drop_all_views(conn, schema=schema)
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
from pydantic import BaseModel
2+
3+
from sftkit.database import Connection
4+
5+
6+
class PgFunctionDef(BaseModel):
7+
proname: str
8+
pronamespace: int # oid
9+
proowner: int # oid
10+
prolang: int # oid
11+
procost: float
12+
prorows: int
13+
provariadic: int # oid
14+
prosupport: str
15+
prokind: str
16+
prosecdef: bool
17+
proleakproof: bool
18+
proisstrict: bool
19+
proretset: bool
20+
provolatile: str
21+
proparallel: str
22+
pronargs: int
23+
pronargdefaults: int
24+
prorettype: int # oid
25+
proargtypes: list[int] # oid
26+
proallargtypes: list[int] | None # oid
27+
proargmodes: list[str] | None
28+
proargnames: list[str] | None
29+
# proargdefaults: pg_node_tree | None
30+
protrftypes: list[str] | None
31+
prosrc: str
32+
probin: str | None
33+
# prosqlbody: pg_node_tree | None
34+
proconfig: list[str] | None
35+
proacl: list[str] | None
36+
signature: str
37+
38+
39+
async def list_functions(conn: Connection, schema: str) -> list[PgFunctionDef]:
40+
return await conn.fetch_many(
41+
PgFunctionDef,
42+
"select pg_proc.*, pg_get_function_identity_arguments(oid) as signature from pg_proc "
43+
"where pronamespace = $1::regnamespace and pg_proc.proname !~ '^pg_';",
44+
schema,
45+
)
46+
47+
48+
class PgViewDef(BaseModel):
49+
table_name: str
50+
51+
52+
async def list_views(conn: Connection, schema: str) -> list[PgViewDef]:
53+
return await conn.fetch_many(
54+
PgViewDef,
55+
"select table_name from information_schema.views where table_schema = $1 and table_name !~ '^pg_';",
56+
schema,
57+
)
58+
59+
60+
class PgTriggerDef(BaseModel):
61+
trigger_name: str
62+
event_object_table: str
63+
64+
65+
async def list_triggers(conn: Connection, schema: str) -> list[PgTriggerDef]:
66+
return await conn.fetch_many(
67+
PgTriggerDef,
68+
"select distinct on (trigger_name, event_object_table) trigger_name, event_object_table "
69+
"from information_schema.triggers where trigger_schema = $1",
70+
schema,
71+
)
72+
73+
74+
class PgConstraintDef(BaseModel):
75+
conname: str
76+
relname: str
77+
contype: str
78+
79+
80+
async def list_constraints(conn: Connection, schema: str) -> list[PgConstraintDef]:
81+
return await conn.fetch_many(
82+
PgConstraintDef,
83+
"select con.conname, rel.relname, con.contype "
84+
"from pg_catalog.pg_constraint con "
85+
" join pg_catalog.pg_namespace nsp on nsp.oid = con.connamespace "
86+
" left join pg_catalog.pg_class rel on rel.oid = con.conrelid "
87+
"where nsp.nspname = $1 and con.conname !~ '^pg_';",
88+
schema,
89+
)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
alter table "user" add constraint username_allowlist check (name != 'exclusion');
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
create or replace function test_func(
2+
arg1 bigint,
3+
arg2 text
4+
) returns boolean as
5+
$$
6+
<<locals>> declare
7+
tmp_var double precision;
8+
begin
9+
tmp_var = arg1 > 0 and arg2 != 'bla';
10+
return tmp_var;
11+
end;
12+
$$ language plpgsql
13+
set search_path = "$user", public;
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
create or replace function user_trigger() returns trigger as
2+
$$
3+
begin
4+
return NEW;
5+
end
6+
$$ language plpgsql
7+
stable
8+
set search_path = "$user", public;
9+
10+
create trigger create_user_trigger
11+
before insert
12+
on "user"
13+
for each row
14+
execute function user_trigger();
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
create view user_with_post_count as
2+
select u.*, author_counts.count
3+
from "user" as u
4+
join (
5+
select p.author_id, count(*) as count from post as p group by p.author_id
6+
) as author_counts on u.id = author_counts.author_id;
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
-- migration: YbcL1B3z3yKj0TrH
2+
-- requires: null
3+
4+
create table "user" (
5+
id bigint primary key generated always as identity,
6+
name text not null unique,
7+
is_registered bool not null default false,
8+
comment text
9+
);
10+
11+
create table post (
12+
id bigint primary key generated always as identity,
13+
title text not null unique,
14+
author_id bigint not null references "user"(id)
15+
);

sftkit/tests/conftest.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
import os
2+
import random
3+
import string
4+
from pathlib import Path
5+
6+
import pytest
7+
import pytest_asyncio
8+
from pytest_asyncio import is_async_test
9+
10+
from sftkit.database import DatabaseConfig, Pool, create_db_pool, Connection, Database
11+
12+
ASSETS_DIR = Path(__file__).parent / "assets"
13+
14+
15+
@pytest.fixture(scope="session")
16+
def db_config() -> DatabaseConfig:
17+
return DatabaseConfig(
18+
host=os.environ.get("SFTKIT_TEST_DB_HOST"),
19+
port=os.environ.get("SFTKIT_TEST_DB_PORT", 5432),
20+
user=os.environ.get("SFTKIT_TEST_DB_USER"),
21+
password=os.environ.get("SFTKIT_TEST_DB_PASSWORD"),
22+
dbname=os.environ.get("SFTKIT_TEST_DB_DBNAME", "sftkit_test"),
23+
)
24+
25+
26+
@pytest_asyncio.fixture(loop_scope="session", scope="session")
27+
async def static_test_db_pool(db_config: DatabaseConfig) -> Pool:
28+
pool = await create_db_pool(cfg=db_config, n_connections=10)
29+
yield pool
30+
await pool.close()
31+
32+
33+
@pytest_asyncio.fixture(loop_scope="session", scope="function")
34+
async def test_db(db_config: DatabaseConfig, static_test_db_pool: Pool) -> Database:
35+
dbname = "".join(random.choices(string.ascii_lowercase, k=20))
36+
cfg = db_config.model_copy()
37+
cfg.dbname = dbname
38+
await static_test_db_pool.execute(f'create database "{dbname}"')
39+
if db_config.user:
40+
await static_test_db_pool.execute(f'alter database "{dbname}" owner to "{db_config.user}"')
41+
mininal_db_assets = ASSETS_DIR / "minimal_db"
42+
database = Database(
43+
config=cfg, migrations_dir=mininal_db_assets / "migrations", code_dir=mininal_db_assets / "code"
44+
)
45+
await database.apply_migrations()
46+
yield database
47+
await static_test_db_pool.execute(f'drop database "{dbname}"')
48+
49+
50+
@pytest_asyncio.fixture(loop_scope="session", scope="function")
51+
async def test_db_pool(test_db: Database) -> Pool:
52+
pool = await test_db.create_pool(n_connections=2)
53+
yield pool
54+
await pool.close()
55+
56+
57+
@pytest_asyncio.fixture(loop_scope="session", scope="function")
58+
async def test_db_conn(test_db_pool: Pool) -> Connection:
59+
async with test_db_pool.acquire() as conn:
60+
yield conn
61+
62+
63+
def pytest_collection_modifyitems(items):
64+
pytest_asyncio_tests = (item for item in items if is_async_test(item))
65+
session_scope_marker = pytest.mark.asyncio(loop_scope="session")
66+
for async_test in pytest_asyncio_tests:
67+
async_test.add_marker(session_scope_marker, append=False)

sftkit/tests/test_dummy.py

Lines changed: 0 additions & 2 deletions
This file was deleted.

sftkit/tests/test_introspection.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
from sftkit.database import Connection
2+
from sftkit.database.introspection import list_functions, list_views, list_triggers, list_constraints
3+
4+
5+
async def test_introspection_functions(test_db_conn: Connection):
6+
funcs = await list_functions(test_db_conn, "public")
7+
assert len([x for x in funcs if x.proname == "test_func"]) > 0
8+
9+
10+
async def test_introspection_view(test_db_conn: Connection):
11+
views = await list_views(test_db_conn, "public")
12+
assert len([x for x in views if x.table_name == "user_with_post_count"]) > 0
13+
14+
15+
async def test_introspection_triggers(test_db_conn: Connection):
16+
triggers = await list_triggers(test_db_conn, "public")
17+
assert len([x for x in triggers if x.trigger_name == "create_user_trigger"]) > 0
18+
19+
20+
async def test_introspection_constraints(test_db_conn: Connection):
21+
constraints = await list_constraints(test_db_conn, "public")
22+
assert len([x for x in constraints if x.conname == "username_allowlist"]) > 0

0 commit comments

Comments
 (0)