Skip to content

Commit

Permalink
[sql] Fix non-normalized queries over native protocol (#8315)
Browse files Browse the repository at this point in the history
Queries where we had to fall back to the non-normalized version were
broken under the native protocol, because we were still feeding the
arguments to postgres.

Also, make `pg_get_serial_sequence` force non-normalization, since it
ignores its arguments. It still fails if actual params are used,
though.
  • Loading branch information
msullivan authored and deepbuzin committed Feb 18, 2025
1 parent 593703f commit 09ccbe7
Show file tree
Hide file tree
Showing 8 changed files with 39 additions and 5 deletions.
6 changes: 6 additions & 0 deletions edb/edgeql/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,9 @@ def from_string(text: str) -> Source:
def __repr__(self):
return f'<edgeql.Source text={self._text!r}>'

def denormalized(self) -> Source:
return self


class NormalizedSource(Source):
def __init__(
Expand Down Expand Up @@ -140,6 +143,9 @@ def from_string(text: str) -> NormalizedSource:
normalized = _normalize(text)
return NormalizedSource(normalized, text, normalized.pack())

def denormalized(self) -> Source:
return Source.from_string(self._text)


def inflate_span(
source: str, span: Tuple[int, Optional[int]]
Expand Down
5 changes: 5 additions & 0 deletions edb/pgsql/parser/parser.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,9 @@ cdef class Source:
def from_string(cls, text: str) -> Source:
return Source(text)

def denormalized(self) -> Source:
return self


cdef class NormalizedSource(Source):
def __init__(
Expand Down Expand Up @@ -383,6 +386,8 @@ cdef class NormalizedSource(Source):
serialized,
)

def denormalized(self) -> Source:
return Source.from_string(self._orig_text)

def deserialize(serialized: bytes) -> Source:
if serialized[0] == 0:
Expand Down
1 change: 1 addition & 0 deletions edb/pgsql/resolver/static.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,7 @@ def eval_FuncCall(
)

if fn_name == "pg_get_serial_sequence":
eval_list(expr.args, ctx=ctx)
# we do not expose sequences, so any calls to this function returns NULL
return pgast.NullConstant()

Expand Down
5 changes: 3 additions & 2 deletions edb/server/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,7 +586,7 @@ def compile_sql(
disambiguate_column_names=False,
backend_runtime_params=self.state.backend_runtime_params,
protocol_version=defines.POSTGRES_PROTOCOL,
)
)[0]

def compile_serialized_request(
self,
Expand Down Expand Up @@ -2682,7 +2682,7 @@ def compile_sql_as_unit_group(
],
)

sql_units = sql.compile_sql(
sql_units, force_non_normalized = sql.compile_sql(
source,
schema=schema,
tx_state=sql_tx_state,
Expand All @@ -2702,6 +2702,7 @@ def compile_sql_as_unit_group(
qug = dbstate.QueryUnitGroup(
cardinality=sql_units[-1].cardinality,
cacheable=True,
force_non_normalized=force_non_normalized,
)

for sql_unit in sql_units:
Expand Down
2 changes: 2 additions & 0 deletions edb/server/compiler/dbstate.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,8 @@ class QueryUnitGroup:
cache_state: int = 0
tx_seq_id: int = 0

force_non_normalized: bool = False

@property
def units(self) -> List[QueryUnit]:
if self._unpacked_units is None:
Expand Down
6 changes: 3 additions & 3 deletions edb/server/compiler/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def compile_sql(
backend_runtime_params: pg_params.BackendRuntimeParams,
protocol_version: defines.ProtocolVersion,
implicit_limit: Optional[int] = None,
) -> List[dbstate.SQLQueryUnit]:
) -> tuple[list[dbstate.SQLQueryUnit], bool]:
def _try(
q: str, normalized_params: List[int]
) -> List[dbstate.SQLQueryUnit]:
Expand All @@ -108,7 +108,7 @@ def _try(
normalized_params = list(source.extra_type_oids())
try:
try:
return _try(source.text(), normalized_params)
return _try(source.text(), normalized_params), False
except DisableNormalization:
# compiler requested non-normalized query (it needs it for static
# evaluation)
Expand All @@ -120,7 +120,7 @@ def _try(
# TODO: Can we tell the server to cache using non-extracted?
for unit in units:
unit.cacheable = False
return units
return units, True
except DisableNormalization:
pass

Expand Down
5 changes: 5 additions & 0 deletions edb/server/dbview/dbview.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1436,6 +1436,8 @@ cdef class DatabaseConnectionView:
)

source = query_req.source
if query_unit_group.force_non_normalized:
source = source.denormalized()
return CompiledQuery(
query_unit_group=query_unit_group,
first_extra=source.first_extra(),
Expand Down Expand Up @@ -1464,6 +1466,9 @@ cdef class DatabaseConnectionView:

desc_map = {}
source = query_req.source
if qug.force_non_normalized:
source = source.denormalized()

first_extra = source.first_extra()
num_injected_params = 0
if qug.globals is not None:
Expand Down
14 changes: 14 additions & 0 deletions tests/test_sql_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -2957,6 +2957,20 @@ async def test_sql_native_query_27(self):
select (), asdf
''')

async def test_sql_native_query_28(self):
await self.assert_sql_query_result(
"SELECT current_setting('search_path') as path",
[{"path": 'public'}],
)

await self.assert_sql_query_result(
'''
SELECT pg_get_serial_sequence('"public"."Book"', 1)
::regclass::text as seq;
''',
[{"seq": None}],
)


class TestSQLQueryNonTransactional(tb.SQLQueryTestCase):

Expand Down

0 comments on commit 09ccbe7

Please sign in to comment.