From 78bb6217da86a8d3b45087cd599ec4c5a46ebd27 Mon Sep 17 00:00:00 2001 From: ljnsn <82611987+ljnsn@users.noreply.github.com> Date: Thu, 1 Feb 2024 00:47:57 +0100 Subject: [PATCH] Add support for Annotated types (#219) This is a relatively simple and straightforward implementation with minimal changes and seems to work fine. Resolves #217 --- injector/__init__.py | 2 ++ injector_test.py | 72 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 74 insertions(+) diff --git a/injector/__init__.py b/injector/__init__.py index 4136f8f..b9a91bf 100644 --- a/injector/__init__.py +++ b/injector/__init__.py @@ -705,6 +705,8 @@ def _punch_through_alias(type_: Any) -> type: and type(type_).__name__ == 'NewType' ): return type_.__supertype__ + elif isinstance(type_, _AnnotatedAlias) and getattr(type_, '__metadata__', None) is not None: + return type_.__origin__ else: return type_ diff --git a/injector_test.py b/injector_test.py index 10087f2..3d98254 100644 --- a/injector_test.py +++ b/injector_test.py @@ -18,6 +18,11 @@ import traceback import warnings +if sys.version_info >= (3, 9): + from typing import Annotated +else: + from typing_extensions import Annotated + from typing import Dict, List, NewType import pytest @@ -1682,3 +1687,70 @@ def function1(a: int | str) -> None: pass assert get_bindings(function1) == {'a': Union[int, str]} + + +# test for https://github.com/python-injector/injector/issues/217 +def test_annotated_instance_integration_works(): + UserID = Annotated[int, "user_id"] + + def configure(binder): + binder.bind(UserID, to=123) + + injector = Injector([configure]) + assert injector.get(UserID) == 123 + + +def test_annotated_class_integration_works(): + class Shape(abc.ABC): + pass + + class Circle(Shape): + pass + + first = Annotated[Shape, "first"] + + def configure(binder): + binder.bind(first, to=Circle) + + injector = Injector([configure]) + assert isinstance(injector.get(first), Circle) + + +def test_annotated_meta_separate_bindings(): + first = Annotated[int, "first"] + second = Annotated[int, "second"] + + def configure(binder): + binder.bind(first, to=123) + binder.bind(second, to=456) + + injector = Injector([configure]) + assert injector.get(first) == 123 + assert injector.get(second) == 456 + assert injector.get(first) != injector.get(second) + + +def test_annotated_origin_separate_bindings(): + UserID = Annotated[int, "user_id"] + + def configure(binder): + binder.bind(UserID, to=123) + binder.bind(int, to=456) + + injector = Injector([configure]) + assert injector.get(UserID) == 123 + assert injector.get(int) == 456 + assert injector.get(UserID) != injector.get(int) + + +def test_annotated_non_comparable_types(): + foo = Annotated[int, float("nan")] + bar = Annotated[int, object()] + + def configure(binder): + binder.bind(foo, to=123) + binder.bind(bar, to=456) + + injector = Injector([configure]) + assert injector.get(foo) == 123 + assert injector.get(bar) == 456