Skip to content

Correct query return types #4

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Oct 29, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 10 additions & 8 deletions db_wrapper/client/async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations
from typing import (
cast,
Any,
TypeVar,
Union,
Expand All @@ -10,8 +11,8 @@
List,
Dict)

import aiopg # type: ignore
from psycopg2.extras import register_uuid
import aiopg
from psycopg2.extras import register_uuid, RealDictRow
from psycopg2 import sql

from db_wrapper.connection import ConnectionParameters, connect
Expand All @@ -20,10 +21,6 @@
register_uuid()


# Generic doesn't need a more descriptive name
# pylint: disable=invalid-name
T = TypeVar('T')

Query = Union[str, sql.Composed]


Expand Down Expand Up @@ -57,6 +54,11 @@ async def _execute_query(
query: Query,
params: Optional[Dict[Hashable, Any]] = None,
) -> None:
# aiopg type is incorrect & thinks execute only takes str
# when in the query is passed through to psycopg2's
# cursor.execute which does accept sql.Composed objects.
query = cast(str, query)

if params:
await cursor.execute(query, params)
else:
Expand Down Expand Up @@ -88,7 +90,7 @@ async def execute_and_return(
self,
query: Query,
params: Optional[Dict[Hashable, Any]] = None,
) -> List[T]:
) -> List[RealDictRow]:
"""Execute the given SQL query & return the result.

Arguments:
Expand All @@ -102,5 +104,5 @@ async def execute_and_return(
async with self._connection.cursor() as cursor:
await self._execute_query(cursor, query, params)

result: List[T] = await cursor.fetchall()
result: List[RealDictRow] = await cursor.fetchall()
return result
20 changes: 8 additions & 12 deletions db_wrapper/client/sync_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@
from __future__ import annotations
from typing import (
Any,
TypeVar,
Union,
Optional,
Dict,
Hashable,
List,
Dict)
Optional,
Union,
)

from psycopg2.extras import register_uuid
from psycopg2.extras import register_uuid, RealDictRow
from psycopg2 import sql
# pylint can't seem to find the items in psycopg2 despite being available
from psycopg2._psycopg import cursor # pylint: disable=no-name-in-module
Expand All @@ -24,10 +24,6 @@
register_uuid()


# Generic doesn't need a more descriptive name
# pylint: disable=invalid-name
T = TypeVar('T')

Query = Union[str, sql.Composed]


Expand Down Expand Up @@ -60,7 +56,7 @@ def _execute_query(
params: Optional[Dict[Hashable, Any]] = None,
) -> None:
if params:
db_cursor.execute(query, params) # type: ignore
db_cursor.execute(query, params)
else:
db_cursor.execute(query)

Expand Down Expand Up @@ -88,7 +84,7 @@ def execute_and_return(
self,
query: Query,
params: Optional[Dict[Hashable, Any]] = None,
) -> List[T]:
) -> List[RealDictRow]:
"""Execute the given SQL query & return the result.

Arguments:
Expand All @@ -102,5 +98,5 @@ def execute_and_return(
with self._connection.cursor() as db_cursor:
self._execute_query(db_cursor, query, params)

result: List[T] = db_cursor.fetchall()
result: List[RealDictRow] = db_cursor.fetchall()
return result
1 change: 1 addition & 0 deletions db_wrapper/model/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Convenience objects to simplify database interactions w/ given interface."""

from psycopg2.extras import RealDictRow
from .async_model import (
AsyncModel,
AsyncCreate,
Expand Down
97 changes: 65 additions & 32 deletions db_wrapper/model/async_model.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
"""Asynchronous Model objects."""

from typing import Any, Dict, List
from typing import Any, Dict, List, Type
from uuid import UUID

from psycopg2.extras import RealDictRow

from db_wrapper.client import AsyncClient
from .base import (
ensure_exactly_one,
sql,
T,
CreateABC,
ReadABC,
UpdateABC,
DeleteABC,
ModelABC,
sql,
)


Expand All @@ -23,16 +25,22 @@ class AsyncCreate(CreateABC[T]):

_client: AsyncClient

def __init__(self, client: AsyncClient, table: sql.Composable) -> None:
super().__init__(table)
def __init__(
self,
client: AsyncClient,
table: sql.Composable,
return_constructor: Type[T]
) -> None:
super().__init__(table, return_constructor)
self._client = client

async def one(self, item: T) -> T:
"""Create one new record with a given item."""
result: List[T] = await self._client.execute_and_return(
self._query_one(item))
query_result: List[RealDictRow] = \
await self._client.execute_and_return(self._query_one(item))
result: T = self._return_constructor(**query_result[0])

return result[0]
return result


class AsyncRead(ReadABC[T]):
Expand All @@ -42,19 +50,27 @@ class AsyncRead(ReadABC[T]):

_client: AsyncClient

def __init__(self, client: AsyncClient, table: sql.Composable) -> None:
super().__init__(table)
def __init__(
self,
client: AsyncClient,
table: sql.Composable,
return_constructor: Type[T]
) -> None:
super().__init__(table, return_constructor)
self._client = client

async def one_by_id(self, id_value: UUID) -> T:
"""Read a row by it's id."""
result: List[T] = await self._client.execute_and_return(
self._query_one_by_id(id_value))
query_result: List[RealDictRow] = \
await self._client.execute_and_return(
self._query_one_by_id(id_value))

# Should only return one item from DB
ensure_exactly_one(result)
ensure_exactly_one(query_result)

result: T = self._return_constructor(**query_result[0])

return result[0]
return result


class AsyncUpdate(UpdateABC[T]):
Expand All @@ -64,11 +80,16 @@ class AsyncUpdate(UpdateABC[T]):

_client: AsyncClient

def __init__(self, client: AsyncClient, table: sql.Composable) -> None:
super().__init__(table)
def __init__(
self,
client: AsyncClient,
table: sql.Composable,
return_constructor: Type[T]
) -> None:
super().__init__(table, return_constructor)
self._client = client

async def one_by_id(self, id_value: str, changes: Dict[str, Any]) -> T:
async def one_by_id(self, id_value: UUID, changes: Dict[str, Any]) -> T:
"""Apply changes to row with given id.

Arguments:
Expand All @@ -79,12 +100,14 @@ async def one_by_id(self, id_value: str, changes: Dict[str, Any]) -> T:
Returns:
full value of row updated
"""
result: List[T] = await self._client.execute_and_return(
self._query_one_by_id(id_value, changes))
query_result: List[RealDictRow] = \
await self._client.execute_and_return(
self._query_one_by_id(id_value, changes))

ensure_exactly_one(result)
ensure_exactly_one(query_result)
result: T = self._return_constructor(**query_result[0])

return result[0]
return result


class AsyncDelete(DeleteABC[T]):
Expand All @@ -94,19 +117,26 @@ class AsyncDelete(DeleteABC[T]):

_client: AsyncClient

def __init__(self, client: AsyncClient, table: sql.Composable) -> None:
super().__init__(table)
def __init__(
self,
client: AsyncClient,
table: sql.Composable,
return_constructor: Type[T]
) -> None:
super().__init__(table, return_constructor)
self._client = client

async def one_by_id(self, id_value: str) -> T:
"""Delete one record with matching ID."""
result: List[T] = await self._client.execute_and_return(
self._query_one_by_id(id_value))
query_result: List[RealDictRow] = \
await self._client.execute_and_return(
self._query_one_by_id(id_value))

# Should only return one item from DB
ensure_exactly_one(result)
ensure_exactly_one(query_result)
result = self._return_constructor(**query_result[0])

return result[0]
return result


class AsyncModel(ModelABC[T]):
Expand All @@ -122,19 +152,22 @@ class AsyncModel(ModelABC[T]):
_update: AsyncUpdate[T]
_delete: AsyncDelete[T]

# PENDS python 3.9 support in pylint
# pylint: disable=unsubscriptable-object
def __init__(
self,
client: AsyncClient,
table: str,
return_constructor: Type[T],
) -> None:
super().__init__(client, table)

self._create = AsyncCreate[T](self.client, self.table)
self._read = AsyncRead[T](self.client, self.table)
self._update = AsyncUpdate[T](self.client, self.table)
self._delete = AsyncDelete[T](self.client, self.table)
self._create = AsyncCreate[T](
self.client, self.table, return_constructor)
self._read = AsyncRead[T](
self.client, self.table, return_constructor)
self._update = AsyncUpdate[T](
self.client, self.table, return_constructor)
self._delete = AsyncDelete[T](
self.client, self.table, return_constructor)

@property
def create(self) -> AsyncCreate[T]:
Expand Down
Loading