Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@ All notable changes to this project are documented in this file.

- (INTERNAL) Renamed certain functions to better convey their meaning. [#5]

### Fixed

- Bug that resulted in query parameters not being properly handled when wrapped within
`Bulk`. [#8]


## 0.1.1 [2025-11-26]

Expand All @@ -22,4 +27,5 @@ All notable changes to this project are documented in this file.


[#5]: https://github.com/manoss96/onlymaps/pull/5
[#6]: https://github.com/manoss96/onlymaps/pull/6
[#6]: https://github.com/manoss96/onlymaps/pull/6
[#8]: https://github.com/manoss96/onlymaps/pull/8
28 changes: 27 additions & 1 deletion onlymaps/_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"""

from abc import ABC
from typing import Any, Mapping, Sequence
from typing import Any, Callable, Mapping, Sequence

from pydantic import BaseModel

Expand Down Expand Up @@ -48,6 +48,32 @@ def __init__(self, obj: Sequence[Sequence[Any] | Mapping[str, Any]], /):
"""
super().__init__(obj)

def get_mapped_value(self, arg_map_fn: Callable[[Any], Any]) -> Any:
"""
Returns this `Bulk` instance's underlying value after having
each of its items go through the provided mapping function.

:param `(Any) -> Any` arg_map_fn: An argument mapping function.

"""

bulk_type = type(self.value)

if issubclass(bulk_type, (list, tuple, set)):

def handle_seq_or_map_param(p: Any) -> Any:
match p:
case dict():
return {key: arg_map_fn(val) for key, val in p.items()}
case list() | tuple() | set():
return type(p)(arg_map_fn(val) for val in p)
case _:
return p

return bulk_type(handle_seq_or_map_param(p) for p in self.value)

return self.value


class Json(_ParamWrapper):
"""
Expand Down
6 changes: 5 additions & 1 deletion onlymaps/_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,10 +181,14 @@ def handle_param(param: Any) -> Any:
nonlocal is_bulk
match param:
case Bulk():

if not allow_bulk:
raise ValueError("Use method `exec` for bulk statements.")

is_bulk = True
return param.value

return param.get_mapped_value(self.__driver.handle_sql_param)

case _ if is_bulk:
raise ValueError(
"Cannot provide additional parameters in `_bulk` mode."
Expand Down
88 changes: 88 additions & 0 deletions tests/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,94 @@ def test_query_on_result_type_type_error(self, db: Database) -> None: # <async>
query = SQL.SELECT_SINGLE_SCALAR(db.driver)
db.fetch_one(int, query, "a") # <await>

def test_query_on_complex_bulk_with_sequence_args( # <async>
self, db: Database
) -> None:
"""
Tests whether arguments are properly handled when wrapped
within a `Bulk` argument containing sequence arguments.
"""

if db.driver in {Driver.SQL_SERVER, Driver.ORACLE_DB}:
pytest.skip(
reason="Temporary tables not supported or need different syntax."
)

tmp_table = "tmp_table"
c1, c2 = "c1", "c2"

db.exec( # <await>
f"""
CREATE TEMPORARY TABLE {tmp_table} (
{c1} VARCHAR(100),
{c2} VARCHAR(100)
)
"""
)

class PydanticModel(BaseModel):
"""
A simple pydantic model class.
"""

n: int

params = [[Json([i]), PydanticModel(n=i)] for i in range(5)]

plchld_1 = SQL.placeholder(db.driver, n=1)
plchld_2 = SQL.placeholder(db.driver, n=2)

# Asserts no exception is raised.
db.exec( # <await>
f"INSERT INTO {tmp_table}({c1}, {c2}) VALUES({plchld_1}, {plchld_2})",
Bulk(params),
)

def test_query_on_complex_bulk_with_mapping_args( # <async>
self, db: Database
) -> None:
"""
Tests whether arguments are properly handled when wrapped
within a `Bulk` argument containing mapping arguments.
"""

if db.driver in {Driver.SQL_SERVER, Driver.ORACLE_DB}:
pytest.skip(
reason="Temporary tables not supported or need different syntax."
)

tmp_table = "tmp_table"
c1, c2 = "c1", "c2"

db.exec( # <await>
f"""
CREATE TEMPORARY TABLE {tmp_table} (
{c1} VARCHAR(100),
{c2} VARCHAR(100)
)
"""
)

class PydanticModel(BaseModel):
"""
A simple pydantic model class.
"""

n: int

params = [
{"scalar1": Json([i]), "scalar2": PydanticModel(n=i)} for i in range(5)
]

plchld_1 = SQL.kw_placeholder(db.driver, n=1)
plchld_2 = SQL.kw_placeholder(db.driver, n=2)

# Asserts no exception is raised.
db.exec( # <await>
f"INSERT INTO {tmp_table}({c1}, {c2}) VALUES({plchld_1}, {plchld_2})",
Bulk(params),
)

@pytest.mark.parametrize("method", ["fetch_one", "fetch_one_or_none", "fetch_many"])
def test_query_on_bulk_param(self, db: Database, method: str) -> None: # <async>
"""
Expand Down
24 changes: 12 additions & 12 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,7 @@ class SQL:
CREATE_TEST_TABLE = f"CREATE TABLE {TEST_TABLE} (id INT PRIMARY KEY)"

@staticmethod
def _placeholder(driver: Driver, n: int | None = None) -> str:
def placeholder(driver: Driver, n: int | None = None) -> str:
"""
Returns a positional placeholder based on the provided driver.

Expand All @@ -339,22 +339,22 @@ def _placeholder(driver: Driver, n: int | None = None) -> str:
return "%s"

@staticmethod
def _kw_placeholder(driver: Driver) -> str:
def kw_placeholder(driver: Driver, n: int | None = None) -> str:
"""
Returns a keyword placeholder based on the provided driver.
"""
match driver:
case Driver.ORACLE_DB | Driver.SQL_LITE:
return ":scalar"
return f":scalar{n if n is not None else ''}"
case _:
return "%(scalar)s"
return f"%(scalar{n if n is not None else ''})s"

@classmethod
def SELECT_SINGLE_SCALAR(cls, driver: Driver) -> str:
"""
Query to select a single scalar.
"""
placeholder = cls._placeholder(driver)
placeholder = cls.placeholder(driver)
return f"SELECT {placeholder}"

@classmethod
Expand All @@ -364,15 +364,15 @@ def SELECT_SINGLE_ROW(cls, driver: Driver) -> str:
"""
query = "SELECT "
for i, field_name in enumerate(RowPydanticModel.model_fields):
query += f"{cls._placeholder(driver, i)} AS {field_name},"
query += f"{cls.placeholder(driver, i)} AS {field_name},"
return query.removesuffix(",")

@classmethod
def SELECT_MULTIPLE_SCALAR(cls, driver: Driver) -> str:
"""
Query to select multiple scalars.
"""
placeholder = cls._kw_placeholder(driver)
placeholder = cls.kw_placeholder(driver)
return f"SELECT {placeholder} UNION ALL SELECT {placeholder}"

@classmethod
Expand All @@ -387,7 +387,7 @@ def build_query() -> str:
nonlocal idx
query = "SELECT "
for field_name in RowPydanticModel.model_fields:
query += f"{cls._placeholder(driver, idx)} AS {field_name},"
query += f"{cls.placeholder(driver, idx)} AS {field_name},"
idx += 1
return query.removesuffix(",")

Expand All @@ -406,15 +406,15 @@ def SELECT_SINGLE_ROW_WITH_INT_COL_NAMES(
"""
query = "SELECT "
for i in range(num_placeholders):
query += f"{cls._placeholder(driver, i)} AS c{i},"
query += f"{cls.placeholder(driver, i)} AS c{i},"
return query.removesuffix(",")

@classmethod
def INSERT_INTO_TEST_TABLE(cls, driver: Driver, returning_id: bool = False) -> str:
"""
Query to insert row into the test table.
"""
placeholder = cls._placeholder(driver)
placeholder = cls.placeholder(driver)
query = f"INSERT INTO {cls.TEST_TABLE} VALUES ({placeholder})"
if returning_id:
assert driver == Driver.POSTGRES
Expand All @@ -426,15 +426,15 @@ def SELECT_FROM_TEST_TABLE(cls, driver: Driver) -> str:
"""
Query to select one row from the test table.
"""
placeholder = cls._placeholder(driver)
placeholder = cls.placeholder(driver)
return f"SELECT id FROM {cls.TEST_TABLE} WHERE id = {placeholder}"

@classmethod
def SELECT_MANY_FROM_TEST_TABLE(cls, driver: Driver, num_elements: int) -> str:
"""
Query to select many rows from the test table.
"""
template = ",".join(cls._placeholder(driver, i) for i in range(num_elements))
template = ",".join(cls.placeholder(driver, i) for i in range(num_elements))
return f"SELECT id FROM {cls.TEST_TABLE} WHERE id IN ({template})"

@classmethod
Expand Down