Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(sqlalchemy-repo): #2221 - Filters not available in exists() #2228

Merged
merged 5 commits into from
Aug 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion litestar/contrib/repository/abc/_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,11 @@ async def delete_many(self, item_ids: list[Any]) -> list[T]:
"""

@abstractmethod
async def exists(self, **kwargs: Any) -> bool:
async def exists(self, *filters: FilterTypes, **kwargs: Any) -> bool:
"""Return true if the object specified by ``kwargs`` exists.

Args:
*filters: Types for specific filtering operations.
**kwargs: Identifier of the instance to be retrieved.

Returns:
Expand Down
3 changes: 2 additions & 1 deletion litestar/contrib/repository/abc/_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,11 @@ def delete_many(self, item_ids: list[Any]) -> list[T]:
"""

@abstractmethod
def exists(self, **kwargs: Any) -> bool:
def exists(self, *filters: FilterTypes, **kwargs: Any) -> bool:
"""Return true if the object specified by ``kwargs`` exists.

Args:
*filters: Types for specific filtering operations.
**kwargs: Identifier of the instance to be retrieved.

Returns:
Expand Down
10 changes: 6 additions & 4 deletions litestar/contrib/repository/testing/generic_mock_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,17 +155,18 @@ async def delete_many(self, item_ids: list[Any]) -> list[ModelT]:
instances.append(obj)
return instances

async def exists(self, **kwargs: Any) -> bool:
async def exists(self, *filters: FilterTypes, **kwargs: Any) -> bool:
"""Return true if the object specified by ``kwargs`` exists.

Args:
*filters: Types for specific filtering operations.
**kwargs: Identifier of the instance to be retrieved.

Returns:
True if the instance was found. False if not found..

"""
existing = await self.get_one_or_none(**kwargs)
existing = await self.count(*filters, **kwargs)
return bool(existing)

async def get(self, item_id: Any, **kwargs: Any) -> ModelT:
Expand Down Expand Up @@ -533,17 +534,18 @@ def delete_many(self, item_ids: list[Any]) -> list[ModelT]:
instances.append(obj)
return instances

def exists(self, **kwargs: Any) -> bool:
def exists(self, *filters: FilterTypes, **kwargs: Any) -> bool:
"""Return true if the object specified by ``kwargs`` exists.

Args:
*filters: Types for specific filtering operations.
**kwargs: Identifier of the instance to be retrieved.

Returns:
True if the instance was found. False if not found..

"""
existing = self.get_one_or_none(**kwargs)
existing = self.count(*filters, **kwargs)
return bool(existing)

def get(self, item_id: Any, **kwargs: Any) -> ModelT:
Expand Down
5 changes: 3 additions & 2 deletions litestar/contrib/sqlalchemy/repository/_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,17 +247,18 @@ async def delete_many(
def _get_insertmanyvalues_max_parameters(self, chunk_size: int | None = None) -> int:
return chunk_size if chunk_size is not None else DEFAULT_INSERTMANYVALUES_MAX_PARAMETERS

async def exists(self, **kwargs: Any) -> bool:
async def exists(self, *filters: FilterTypes, **kwargs: Any) -> bool:
"""Return true if the object specified by ``kwargs`` exists.

Args:
*filters: Types for specific filtering operations.
**kwargs: Identifier of the instance to be retrieved.

Returns:
True if the instance was found. False if not found..

"""
existing = await self.count(**kwargs)
existing = await self.count(*filters, **kwargs)
return existing > 0

def _get_base_stmt(
Expand Down
5 changes: 3 additions & 2 deletions litestar/contrib/sqlalchemy/repository/_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,17 +248,18 @@ def delete_many(
def _get_insertmanyvalues_max_parameters(self, chunk_size: int | None = None) -> int:
return chunk_size if chunk_size is not None else DEFAULT_INSERTMANYVALUES_MAX_PARAMETERS

def exists(self, **kwargs: Any) -> bool:
def exists(self, *filters: FilterTypes, **kwargs: Any) -> bool:
"""Return true if the object specified by ``kwargs`` exists.

Args:
*filters: Types for specific filtering operations.
**kwargs: Identifier of the instance to be retrieved.

Returns:
True if the instance was found. False if not found..

"""
existing = self.count(**kwargs)
existing = self.count(*filters, **kwargs)
return existing > 0

def _get_base_stmt(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from sqlalchemy.orm import Mapped, mapped_column

from litestar.contrib.repository.exceptions import ConflictError, RepositoryError
from litestar.contrib.repository.filters import LimitOffset
from litestar.contrib.repository.testing.generic_mock_repository import (
GenericAsyncMockRepository,
GenericSyncMockRepository,
Expand Down Expand Up @@ -330,6 +331,21 @@ async def test_exists(
assert exists


async def test_exists_with_filter(
repository_type: type[GenericAsyncMockRepository], create_audit_model_type: CreateAuditModelFixture
) -> None:
"""Test that the repository exists returns booleans. with filter argument"""
limit_filter = LimitOffset(limit=1, offset=0)

Model = create_audit_model_type({"random_column": Mapped[str]})

instances = [Model(random_column="value 1"), Model(random_column="value 2")]
mock_repo = repository_type[Model]() # type: ignore[index]
_ = await maybe_async(mock_repo.add_many(instances))
exists = await maybe_async(mock_repo.exists(limit_filter, random_column="value 1"))
assert exists


async def test_count(repository_type: type[GenericAsyncMockRepository], audit_model_type: AuditModelType) -> None:
"""Test that the repository count returns the total record count."""

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -519,6 +519,22 @@ async def test_sqlalchemy_repo_exists(
mock_repo.session.commit.assert_not_called()


async def test_sqlalchemy_repo_exists_with_filter(
mock_repo: SQLAlchemyAsyncRepository,
monkeypatch: MonkeyPatch,
mock_repo_execute: AnyMock,
mock_repo_count: AnyMock,
) -> None:
"""Test expected method calls for exists operation. with filter argument"""
limit_filter = LimitOffset(limit=1, offset=0)
mock_repo_count.return_value = 1

exists = await maybe_async(mock_repo.exists(limit_filter, id="my-id"))

assert exists
mock_repo.session.commit.assert_not_called()


async def test_sqlalchemy_repo_count(
mock_repo: SQLAlchemyAsyncRepository,
monkeypatch: MonkeyPatch,
Expand Down