Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 66 additions & 0 deletions tests/test_attr_getter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import random
from dataclasses import dataclass, field

import pytest

from that_depends import providers
from that_depends.providers.attr_getter import _get_value_from_object_by_dotted_path


@dataclass
class Nested2:
some_const = 144


@dataclass
class Nested1:
nested2_attr: Nested2 = field(default_factory=Nested2)


@dataclass
class Settings:
some_str_value: str = "some_string_value"
some_int_value: int = 3453621
nested1_attr: Nested1 = field(default_factory=Nested1)


@dataclass
class NestingTestDTO: ...


@pytest.fixture()
def some_settings_provider() -> providers.Singleton[Settings]:
return providers.Singleton(Settings)


def test_attr_getter_with_zero_attribute_depth(some_settings_provider: providers.Singleton[Settings]) -> None:
attr_getter = some_settings_provider.some_str_value
assert attr_getter.sync_resolve() == Settings().some_str_value


def test_attr_getter_with_more_than_zero_attribute_depth(some_settings_provider: providers.Singleton[Settings]) -> None:
attr_getter = some_settings_provider.nested1_attr.nested2_attr.some_const
assert attr_getter.sync_resolve() == Nested2().some_const


@pytest.mark.parametrize(
("field_count", "test_field_name", "test_value"),
[(1, "test_field", "sdf6fF^SF(FF*4ffsf"), (5, "nested_field", -252625), (50, "50_lvl_field", 909234235)],
)
def test_nesting_levels(field_count: int, test_field_name: str, test_value: str | int) -> None:
obj = NestingTestDTO()
fields = [f"field_{i}" for i in range(1, field_count + 1)]
random.shuffle(fields)

attr_path = ".".join(fields) + f".{test_field_name}"
obj_copy = obj

while fields:
field_name = fields.pop(0)
setattr(obj_copy, field_name, NestingTestDTO())
obj_copy = obj_copy.__getattribute__(field_name)

setattr(obj_copy, test_field_name, test_value)

attr_value = _get_value_from_object_by_dotted_path(obj, attr_path)
assert attr_value == test_value
22 changes: 18 additions & 4 deletions that_depends/providers/attr_getter.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import typing
from operator import attrgetter

from that_depends.providers.base import AbstractProvider

Expand All @@ -7,15 +8,28 @@
P = typing.ParamSpec("P")


def _get_value_from_object_by_dotted_path(obj: typing.Any, path: str) -> typing.Any: # noqa: ANN401
attribute_getter = attrgetter(path)
return attribute_getter(obj)


class AttrGetter(AbstractProvider[T]):
__slots__ = "_provider", "_attr_name"
__slots__ = "_provider", "_attrs"

def __init__(self, provider: AbstractProvider[T], attr_name: str) -> None:
self._provider = provider
self._attr_name = attr_name
self._attrs = [attr_name]

def __getattr__(self, attr: str) -> "AttrGetter[T]":
self._attrs.append(attr)
return self

async def async_resolve(self) -> typing.Any: # noqa: ANN401
return getattr(await self._provider.async_resolve(), self._attr_name)
resolved_provider_object = await self._provider.async_resolve()
attribute_path = ".".join(self._attrs)
return _get_value_from_object_by_dotted_path(resolved_provider_object, attribute_path)

def sync_resolve(self) -> typing.Any: # noqa: ANN401
return getattr(self._provider.sync_resolve(), self._attr_name)
resolved_provider_object = self._provider.sync_resolve()
attribute_path = ".".join(self._attrs)
return _get_value_from_object_by_dotted_path(resolved_provider_object, attribute_path)