Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merging connect kwargs and init kwargs with priority #31

Merged
merged 7 commits into from
Jun 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion propan/__about__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from unittest.mock import Mock

__version__ = "0.1.2.10"
__version__ = "0.1.2.11"


INSTALL_MESSAGE = (
Expand Down
17 changes: 12 additions & 5 deletions propan/brokers/_model/broker_usecase.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from fast_depends.construct import get_dependant
from fast_depends.model import Dependant
from fast_depends.utils import args_to_kwargs
from pydantic.fields import ModelField
from typing_extensions import Self

Expand Down Expand Up @@ -46,7 +47,7 @@
Wrapper,
)
from propan.utils import apply_types, context
from propan.utils.functions import to_async
from propan.utils.functions import get_function_arguments, to_async

T = TypeVar("T")

Expand Down Expand Up @@ -83,13 +84,19 @@ def __init__(

async def connect(self, *args: Any, **kwargs: Any) -> Any:
if self._connection is None:
_args = args or self._connection_args
_kwargs = kwargs or self._connection_kwargs
self._connection = await self._connect(*_args, **_kwargs)
arguments = get_function_arguments(self.__init__) # type: ignore
init_kwargs = args_to_kwargs(
arguments,
*self._connection_args,
**self._connection_kwargs,
)
connect_kwargs = args_to_kwargs(arguments, *args, **kwargs)
_kwargs = {**init_kwargs, **connect_kwargs}
self._connection = await self._connect(**_kwargs)
return self._connection

@abstractmethod
async def _connect(self, *args: Any, **kwargs: Any) -> Any:
async def _connect(self, **kwargs: Any) -> Any:
raise NotImplementedError()

@abstractmethod
Expand Down
2 changes: 0 additions & 2 deletions propan/brokers/kafka/kafka_broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,9 @@ def __init__(

async def _connect(
self,
bootstrap_servers: Union[str, List[str]] = "localhost",
**kwargs: Any,
) -> AIOKafkaConsumer:
kwargs["client_id"] = kwargs.get("client_id", "propan-" + __version__)
kwargs["bootstrap_servers"] = bootstrap_servers

producer = AIOKafkaProducer(**kwargs)
context.set_global("producer", producer)
Expand Down
2 changes: 1 addition & 1 deletion propan/brokers/kafka/kafka_broker.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,8 @@ class KafkaBroker(BrokerUsecase):
) -> None: ...
async def connect(
self,
bootstrap_servers: Union[str, List[str]] = "localhost",
*,
bootstrap_servers: Union[str, List[str]] = "localhost",
# both
loop: Optional[AbstractEventLoop] = None,
client_id: str = "propan-" + __version__,
Expand Down
10 changes: 5 additions & 5 deletions propan/brokers/nats/nats_broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,12 @@ class NatsBroker(BrokerUsecase):

def __init__(
self,
*args: Any,
servers: Union[str, List[str]] = ["nats://localhost:4222"], # noqa: B006
*,
log_fmt: Optional[str] = None,
**kwargs: AnyDict,
):
super().__init__(*args, log_fmt=log_fmt, **kwargs)
) -> None:
super().__init__(servers, log_fmt=log_fmt, **kwargs)

self._connection = None

Expand All @@ -42,7 +43,7 @@ def __init__(

async def _connect(
self,
*args: Any,
*,
url: Optional[str] = None,
error_cb: Optional[ErrorCallback] = None,
reconnected_cb: Optional[Callback] = None,
Expand All @@ -51,7 +52,6 @@ async def _connect(
if url is not None:
kwargs["servers"] = kwargs.pop("servers", []) + [url]
return await nats.connect(
*args,
error_cb=self.log_connection_broken(error_cb),
reconnected_cb=self.log_reconnected(reconnected_cb),
**kwargs,
Expand Down
3 changes: 2 additions & 1 deletion propan/brokers/nats/nats_broker.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class NatsBroker(BrokerUsecase):
def __init__(
self,
servers: Union[str, List[str]] = ["nats://localhost:4222"], # noqa: B006
*,
error_cb: Optional[ErrorCallback] = None,
disconnected_cb: Optional[Callback] = None,
closed_cb: Optional[Callback] = None,
Expand Down Expand Up @@ -68,14 +69,14 @@ class NatsBroker(BrokerUsecase):
inbox_prefix: Union[str, bytes] = DEFAULT_INBOX_PREFIX,
pending_size: int = DEFAULT_PENDING_SIZE,
flush_timeout: Optional[float] = None,
*,
logger: Optional[logging.Logger] = access_logger,
log_level: int = logging.INFO,
log_fmt: Optional[str] = None,
apply_types: bool = True,
) -> None: ...
async def connect(
self,
*,
servers: Union[str, List[str]] = ["nats://localhost:4222"], # noqa: B006
error_cb: Optional[ErrorCallback] = None,
disconnected_cb: Optional[Callback] = None,
Expand Down
13 changes: 7 additions & 6 deletions propan/brokers/rabbit/rabbit_broker.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import asyncio
from functools import wraps
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union
from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union
from uuid import uuid4

import aio_pika
import aiormq
from aio_pika.abc import DeliveryMode
from yarl import URL

from propan.brokers._model import BrokerUsecase
from propan.brokers._model.schemas import PropanMessage
Expand All @@ -29,12 +30,13 @@ class RabbitBroker(BrokerUsecase):

def __init__(
self,
*args: Tuple[Any, ...],
consumers: Optional[int] = None,
url: Union[str, URL, None] = None,
*,
log_fmt: Optional[str] = None,
consumers: Optional[int] = None,
**kwargs: AnyDict,
) -> None:
super().__init__(*args, log_fmt=log_fmt, **kwargs)
super().__init__(url, log_fmt=log_fmt, **kwargs)
self._max_consumers = consumers

self._channel = None
Expand All @@ -53,11 +55,10 @@ async def close(self) -> None:

async def _connect(
self,
*args: Any,
**kwargs: Any,
) -> aio_pika.RobustConnection:
connection = await aio_pika.connect_robust(
*args, **kwargs, loop=asyncio.get_event_loop()
**kwargs, loop=asyncio.get_event_loop()
)

if self._channel is None: # pragma: no branch
Expand Down
3 changes: 2 additions & 1 deletion propan/brokers/rabbit/rabbit_broker.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class RabbitBroker(BrokerUsecase):
def __init__(
self,
url: Union[str, URL, None] = None,
*,
host: str = "localhost",
port: int = 5672,
login: str = "guest",
Expand All @@ -40,7 +41,6 @@ class RabbitBroker(BrokerUsecase):
ssl_context: Optional[SSLContext] = None,
timeout: aio_pika.abc.TimeoutType = None,
client_properties: Optional[FieldTable] = None,
*,
logger: Optional[logging.Logger] = access_logger,
log_level: int = logging.INFO,
log_fmt: Optional[str] = None,
Expand Down Expand Up @@ -76,6 +76,7 @@ class RabbitBroker(BrokerUsecase):
"""
async def connect(
self,
*,
url: Union[str, URL, None] = None,
host: str = "localhost",
port: int = 5672,
Expand Down
14 changes: 2 additions & 12 deletions propan/brokers/redis/redis_broker.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
import logging
from functools import wraps
from typing import Any, Callable, Coroutine, Dict, List, NoReturn, Optional, TypeVar
from typing import Any, Callable, Dict, List, NoReturn, Optional, TypeVar
from uuid import uuid4

from redis.asyncio.client import PubSub, Redis
Expand Down Expand Up @@ -32,8 +32,8 @@ class RedisBroker(BrokerUsecase):
def __init__(
self,
url: str = "redis://localhost:6379",
polling_interval: float = 1.0,
*,
polling_interval: float = 1.0,
log_fmt: Optional[str] = None,
**kwargs: Any,
) -> None:
Expand All @@ -51,16 +51,6 @@ async def _connect(
pool = ConnectionPool(**url_options)
return Redis(connection_pool=pool)

async def connect(
self,
url: Optional[str] = None,
*args: Any,
**kwargs: Any,
) -> Coroutine[Any, Any, Any]:
if url is not None:
kwargs["url"] = url
return await super().connect(*args, **kwargs)

async def close(self) -> None:
for h in self.handlers:
if h.task is not None: # pragma: no branch
Expand Down
3 changes: 2 additions & 1 deletion propan/brokers/redis/redis_broker.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ class RedisBroker(BrokerUsecase):
def __init__(
self,
url: str = "redis://localhost:6379",
polling_interval: float = 1.0,
*,
polling_interval: float = 1.0,
host: str = "localhost",
port: Union[str, int] = 6379,
username: Optional[str] = None,
Expand Down Expand Up @@ -70,6 +70,7 @@ class RedisBroker(BrokerUsecase):
"""
async def connect(
self,
*,
url: str = "redis://localhost:6379",
host: str = "localhost",
port: Union[str, int] = 6379,
Expand Down
2 changes: 1 addition & 1 deletion propan/brokers/sqs/sqs_broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def __init__(
self.response_queue = response_queue
self.response_callbacks = {}

async def _connect(self, url: Optional[str] = None, **kwargs: Any) -> AioBaseClient:
async def _connect(self, *, url: str, **kwargs: Any) -> AioBaseClient:
session = get_session()
client: AioBaseClient = await session._create_client(
service_name="sqs", endpoint_url=url, **kwargs
Expand Down
2 changes: 1 addition & 1 deletion propan/brokers/sqs/sqs_broker.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ class SQSBroker(BrokerUsecase):
""""""
async def connect(
self,
url: str = "http://localhost:9324/",
*,
url: str = "http://localhost:9324/",
region_name: Optional[str] = None,
api_version: Optional[str] = None,
use_ssl: bool = True,
Expand Down
17 changes: 15 additions & 2 deletions propan/utils/functions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import inspect
from functools import wraps
from typing import Awaitable, Callable, TypeVar, cast
from typing import Awaitable, Callable, List, TypeVar, cast

from fast_depends.injector import run_async as call_or_await
from typing_extensions import ParamSpec
Expand All @@ -9,7 +10,6 @@
"to_async",
)


T = TypeVar("T")
P = ParamSpec("P")

Expand All @@ -21,3 +21,16 @@ async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
return cast(T, r)

return wrapper


def get_function_arguments(func: Callable[P, T]) -> List[str]:
signature = inspect.signature(func)

arg_kinds = [
inspect.Parameter.POSITIONAL_ONLY,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
]

return [
param.name for param in signature.parameters.values() if param.kind in arg_kinds
]
8 changes: 8 additions & 0 deletions tests/brokers/base/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,11 @@ async def test_connect_by_url_priority(self, settings):
assert await broker.connect(*args, **kwargs)
assert await self.ping(broker)
await broker.close()

@pytest.mark.asyncio
async def test_connect_merge_args_and_kwargs(self, settings):
args, kwargs = self.get_broker_args(settings)
broker = self.broker(*args)
assert await broker.connect(**kwargs)
assert await self.ping(broker)
await broker.close()
6 changes: 6 additions & 0 deletions tests/brokers/kafka/test_connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,9 @@
@pytest.mark.kafka
class TestKafkaConnect(BrokerConnectionTestcase):
broker = KafkaBroker

@pytest.mark.asyncio
async def test_connect_merge_args_and_kwargs(self, settings):
broker = self.broker("fake-url") # will be ignored
assert await broker.connect(bootstrap_servers=settings.url)
await broker.close()
6 changes: 6 additions & 0 deletions tests/brokers/nats/test_connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,9 @@
@pytest.mark.nats
class TestNatsConnect(BrokerConnectionTestcase):
broker = NatsBroker

@pytest.mark.asyncio
async def test_connect_merge_args_and_kwargs(self, settings):
broker = self.broker("fake-url") # will be ignored
assert await broker.connect(servers=settings.url)
await broker.close()
17 changes: 17 additions & 0 deletions tests/brokers/rabbit/test_connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,20 @@ async def test_connection_by_params(self, settings):
port=settings.port,
)
await broker.close()

@pytest.mark.asyncio
async def test_connect_merge_kwargs_with_priority(self, settings):
broker = self.broker(host="fake-host", port=5677) # kwargs will be ignored
assert await broker.connect(
host=settings.host,
login=settings.login,
password=settings.password,
port=settings.port,
)
await broker.close()

@pytest.mark.asyncio
async def test_connect_merge_args_and_kwargs(self, settings):
broker = self.broker("fake-url") # will be ignored
assert await broker.connect(url=settings.url)
await broker.close()
15 changes: 15 additions & 0 deletions tests/brokers/redis/test_connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,18 @@ async def test_init_connect_by_raw_data(self, settings):
port=settings.port,
) as broker:
assert await self.ping(broker)

@pytest.mark.asyncio
async def test_connect_merge_kwargs_with_priority(self, settings):
broker = self.broker(host="fake-host", port=6377) # kwargs will be ignored
assert await broker.connect(
host=settings.host,
port=settings.port,
)
await broker.close()

@pytest.mark.asyncio
async def test_connect_merge_args_and_kwargs(self, settings):
broker = self.broker("fake-url") # will be ignored
assert await broker.connect(url=settings.url)
await broker.close()
7 changes: 7 additions & 0 deletions tests/brokers/sqs/test_connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,10 @@ def get_broker_args(self, settings):
"region_name": settings.region_name,
"config": AioConfig(signature_version=UNSIGNED),
}

@pytest.mark.asyncio
async def test_connect_merge_args_and_kwargs(self, settings):
args, kwargs = self.get_broker_args(settings)
broker = self.broker("fake-url") # will be ignored
assert await broker.connect(url=settings.url, **kwargs)
await broker.close()