Skip to content

Commit c075e3f

Browse files
authored
Selector provider (#32)
* Selector provider
1 parent 26881a5 commit c075e3f

File tree

5 files changed

+89
-0
lines changed

5 files changed

+89
-0
lines changed

docs/providers/selector.md

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
# Selector
2+
3+
- Selector provider chooses between a provider based on a key.
4+
- Resolves into a single dependency.
5+
6+
```python
7+
import os
8+
from typing import Protocol
9+
from that_depends import BaseContainer, providers
10+
11+
class StorageService(Protocol):
12+
...
13+
14+
class StorageServiceLocal(StorageService):
15+
...
16+
17+
class StorageServiceRemote(StorageService):
18+
...
19+
20+
class DIContainer(BaseContainer):
21+
storage_service = providers.Selector(
22+
lambda: os.getenv("STORAGE_BACKEND", "local"),
23+
local=providers.Factory(StorageServiceLocal),
24+
remote=providers.Factory(StorageServiceRemote),
25+
)
26+
```

tests/container.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
logger = logging.getLogger(__name__)
1010

11+
global_state_for_selector: typing.Literal["sync_resource", "async_resource", "missing"] = "sync_resource"
12+
1113

1214
def create_sync_resource() -> typing.Iterator[datetime.datetime]:
1315
logger.debug("Resource initiated")
@@ -57,6 +59,11 @@ class DIContainer(BaseContainer):
5759
sync_resource = providers.Resource(create_sync_resource)
5860
async_resource = providers.AsyncResource(create_async_resource)
5961
sequence = providers.List(sync_resource, async_resource)
62+
selector: providers.Selector[datetime.datetime] = providers.Selector(
63+
lambda: global_state_for_selector,
64+
sync_resource=sync_resource,
65+
async_resource=async_resource,
66+
)
6067

6168
simple_factory = providers.Factory(SimpleFactory, dep1="text", dep2=123)
6269
async_factory = providers.AsyncFactory(async_factory, async_resource.cast)

tests/test_main_providers.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,34 @@ async def test_list_provider() -> None:
5151
assert sequence == [sync_resource, async_resource]
5252

5353

54+
async def test_selector_provider_async() -> None:
55+
container.global_state_for_selector = "async_resource"
56+
selected = await DIContainer.selector()
57+
async_resource = await DIContainer.async_resource()
58+
59+
assert selected == async_resource
60+
61+
62+
async def test_selector_provider_async_missing() -> None:
63+
container.global_state_for_selector = "missing"
64+
with pytest.raises(RuntimeError):
65+
await DIContainer.selector()
66+
67+
68+
async def test_selector_provider_sync() -> None:
69+
container.global_state_for_selector = "sync_resource"
70+
selected = DIContainer.selector.sync_resolve()
71+
sync_resource = DIContainer.sync_resource.sync_resolve()
72+
73+
assert selected == sync_resource
74+
75+
76+
async def test_selector_provider_sync_missing() -> None:
77+
container.global_state_for_selector = "missing"
78+
with pytest.raises(RuntimeError):
79+
DIContainer.selector.sync_resolve()
80+
81+
5482
async def test_singleton_provider() -> None:
5583
singleton1 = await DIContainer.singleton()
5684
singleton2 = await DIContainer.singleton()

that_depends/providers/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
)
99
from that_depends.providers.factories import AsyncFactory, Factory
1010
from that_depends.providers.resources import AsyncResource, Resource
11+
from that_depends.providers.selector import Selector
1112
from that_depends.providers.singleton import Singleton
1213

1314

@@ -22,6 +23,7 @@
2223
"Factory",
2324
"List",
2425
"Resource",
26+
"Selector",
2527
"Singleton",
2628
"container_context",
2729
]

that_depends/providers/selector.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
import typing
2+
3+
from that_depends.providers.base import AbstractProvider
4+
5+
6+
T = typing.TypeVar("T")
7+
8+
9+
class Selector(AbstractProvider[T]):
10+
def __init__(self, selector: typing.Callable[[], str], **providers: AbstractProvider[T]) -> None:
11+
self._selector = selector
12+
self._providers = providers
13+
14+
async def async_resolve(self) -> T:
15+
selected_key = self._selector()
16+
if selected_key not in self._providers:
17+
msg = f"No provider matches {selected_key}"
18+
raise RuntimeError(msg)
19+
return await self._providers[selected_key].async_resolve()
20+
21+
def sync_resolve(self) -> T:
22+
selected_key = self._selector()
23+
if selected_key not in self._providers:
24+
msg = f"No provider matches {selected_key}"
25+
raise RuntimeError(msg)
26+
return self._providers[selected_key].sync_resolve()

0 commit comments

Comments
 (0)