Skip to content

Commit 561e818

Browse files
authored
Merge pull request ivankorobkov#26 from cselvaraj/fix/autoparams-keyword-only
Make autoparams work with keyword-only parameters.
2 parents b00a753 + 8048792 commit 561e818

File tree

3 files changed

+41
-61
lines changed

3 files changed

+41
-61
lines changed

setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
import sys
21
from distutils.core import setup
2+
import sys
33

44

55
def read_description():
@@ -9,7 +9,7 @@ def read_description():
99

1010
setup(
1111
name='Inject',
12-
version='3.5.0',
12+
version='3.5.1dev0',
1313
url='https://github.com/ivankorobkov/python-inject',
1414
license='Apache License 2.0',
1515

src/inject.py

Lines changed: 28 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
def my_config(binder):
77
binder.bind(Cache, RedisCache('localhost:1234'))
88
binder.bind_to_provider(CurrentUser, get_current_user)
9-
9+
1010
- Create a shared injector::
1111
inject.configure(my_config)
1212
@@ -45,49 +45,50 @@ def bar(foo, cache=None):
4545
and `inject.clear()` to clean-up on tear down.
4646
4747
Runtime bindings greatly reduce the required configuration by automatically creating singletons
48-
on first access. For example, below only the Config class requires binding configuration,
48+
on first access. For example, below only the Config class requires binding configuration,
4949
all other classes are runtime bindings::
5050
class Cache(object):
5151
config = inject.attr(Config)
52-
52+
5353
def __init__(self):
5454
self._redis = connect(self.config.redis_address)
55-
55+
5656
class Db(object):
5757
pass
58-
58+
5959
class UserRepo(object):
6060
cache = inject.attr(Cache)
6161
db = inject.attr(Db)
62-
62+
6363
def load(self, user_id):
6464
return cache.load('user', user_id) or db.load('user', user_id)
65-
65+
6666
class Config(object):
6767
def __init__(self, redis_address):
6868
self.redis_address = redis_address
69-
69+
7070
def my_config(binder):
7171
binder.bind(Config, load_config_file())
72-
72+
7373
inject.configure(my_config)
7474
7575
"""
76-
__version__ = '3.3.2'
76+
__version__ = '3.5.1dev0'
7777
__author__ = 'Ivan Korobkov <ivan.korobkov@gmail.com>'
7878
__license__ = 'Apache License 2.0'
7979
__url__ = 'https://github.com/ivan-korobkov/python-inject'
8080

81-
import logging
82-
import threading
81+
from functools import wraps
8382
import inspect
83+
import logging
8484
import sys
85-
from functools import wraps
85+
import threading
86+
8687

8788
logger = logging.getLogger('inject')
8889

8990
_INJECTOR = None # Shared injector instance.
90-
_INJECTOR_LOCK = threading.RLock() # Guards injector initialization.
91+
_INJECTOR_LOCK = threading.RLock() # Guards injector initialization.
9192
_BINDING_LOCK = threading.RLock() # Guards runtime bindings.
9293

9394

@@ -155,9 +156,9 @@ def param(name, cls=None):
155156

156157
def params(**args_to_classes):
157158
"""Return a decorator which injects args into a function.
158-
159+
159160
For example::
160-
161+
161162
@inject.params(cache=RedisCache, db=DbInterface)
162163
def sign_up(name, email, cache, db):
163164
pass
@@ -188,14 +189,12 @@ def autoparams_decorator(func):
188189

189190
full_args_spec = inspect.getfullargspec(func)
190191
annotations_items = full_args_spec.annotations.items()
192+
all_arg_names = frozenset(full_args_spec.args + full_args_spec.kwonlyargs)
193+
args_to_check = frozenset(selected_args) or all_arg_names
191194
args_annotated_types = {
192195
arg_name: annotated_type for arg_name, annotated_type in annotations_items
193-
if arg_name in full_args_spec.args
196+
if arg_name in args_to_check
194197
}
195-
if selected_args:
196-
keys_to_remove = set(args_annotated_types.keys()) - set(selected_args)
197-
for key in keys_to_remove:
198-
del args_annotated_types[key]
199198
return _ParametersInjection(**args_annotated_types)(func)
200199

201200
return autoparams_decorator
@@ -337,7 +336,7 @@ def __init__(self, name, cls=None):
337336
def __call__(self, func):
338337
@wraps(func)
339338
def injection_wrapper(*args, **kwargs):
340-
if not self._name in kwargs:
339+
if self._name not in kwargs:
341340
kwargs[self._name] = instance(self._cls or self._name)
342341
return func(*args, **kwargs)
343342

@@ -355,35 +354,16 @@ def __call__(self, func):
355354
arg_names = inspect.getargspec(func).args
356355
else:
357356
arg_names = inspect.getfullargspec(func).args
358-
params = self._params
357+
params_to_provide = self._params
359358

360359
@wraps(func)
361360
def injection_wrapper(*args, **kwargs):
362-
# arguments injected
363-
additional_args = []
364-
365-
# iterate over the positional arguments of the function definition
366-
i = len(args)
367-
while i < len(arg_names):
368-
arg_name = arg_names[i]
369-
370-
# stop when we do not have a parameter for the positional argument
371-
# or stop at the first keyword argument
372-
if arg_name not in params or arg_name in kwargs:
373-
break
374-
375-
# this parameter will be injected into the *args
376-
additional_args.append(instance(params[arg_name]))
377-
i += 1
378-
379-
if additional_args:
380-
args += tuple(additional_args)
381-
382-
# a list of all positional args that we have injected
383-
used_args = arg_names[:i]
384-
for name, cls in params.items():
385-
if not name in kwargs and not name in used_args:
386-
kwargs[name] = instance(cls)
361+
362+
provided_params = frozenset(arg_names[:len(args)]) | frozenset(kwargs.keys())
363+
for param, cls in params_to_provide.items():
364+
if param not in provided_params:
365+
kwargs[param] = instance(cls)
366+
387367
return func(*args, **kwargs)
388368

389369
return injection_wrapper

src/test_inject/test_autoparams.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ def test_func(val: int = None):
1616

1717
def test_autoparams_multi(self):
1818
@inject.autoparams()
19-
def test_func(a: 'A', b: 'B', c: 'C'):
19+
def test_func(a: 'A', b: 'B', *, c: 'C'):
2020
return a, b, c
2121

2222
def config(binder):
@@ -29,7 +29,7 @@ def config(binder):
2929
assert test_func() == (1, 2, 3)
3030
assert test_func(10) == (10, 2, 3)
3131
assert test_func(10, 20) == (10, 20, 3)
32-
assert test_func(10, 20, 30) == (10, 20, 30)
32+
assert test_func(10, 20, c=30) == (10, 20, 30)
3333
assert test_func(a='a') == ('a', 2, 3)
3434
assert test_func(b='b') == (1, 'b', 3)
3535
assert test_func(c='c') == (1, 2, 'c')
@@ -39,7 +39,7 @@ def config(binder):
3939

4040
def test_autoparams_with_defaults(self):
4141
@inject.autoparams()
42-
def test_func(a=1, b: 'B' = None, c: 'C' = 300):
42+
def test_func(a=1, b: 'B' = None, *, c: 'C' = 300):
4343
return a, b, c
4444

4545
def config(binder):
@@ -51,7 +51,7 @@ def config(binder):
5151
assert test_func() == (1, 2, 3)
5252
assert test_func(10) == (10, 2, 3)
5353
assert test_func(10, 20) == (10, 20, 3)
54-
assert test_func(10, 20, 30) == (10, 20, 30)
54+
assert test_func(10, 20, c=30) == (10, 20, 30)
5555
assert test_func(a='a') == ('a', 2, 3)
5656
assert test_func(b='b') == (1, 'b', 3)
5757
assert test_func(c='c') == (1, 2, 'c')
@@ -62,7 +62,7 @@ def config(binder):
6262
def test_autoparams_on_method(self):
6363
class Test:
6464
@inject.autoparams()
65-
def func(self, a=1, b: 'B' = None, c: 'C' = None):
65+
def func(self, a=1, b: 'B' = None, *, c: 'C' = None):
6666
return self, a, b, c
6767

6868
def config(binder):
@@ -75,7 +75,7 @@ def config(binder):
7575
assert test.func() == (test, 1, 2, 3)
7676
assert test.func(10) == (test, 10, 2, 3)
7777
assert test.func(10, 20) == (test, 10, 20, 3)
78-
assert test.func(10, 20, 30) == (test, 10, 20, 30)
78+
assert test.func(10, 20, c=30) == (test, 10, 20, 30)
7979
assert test.func(a='a') == (test, 'a', 2, 3)
8080
assert test.func(b='b') == (test, 1, 'b', 3)
8181
assert test.func(c='c') == (test, 1, 2, 'c')
@@ -88,7 +88,7 @@ class Test:
8888
# note inject must be *before* classmethod!
8989
@classmethod
9090
@inject.autoparams()
91-
def func(cls, a=1, b: 'B' = None, c: 'C' = None):
91+
def func(cls, a=1, b: 'B' = None, *, c: 'C' = None):
9292
return cls, a, b, c
9393

9494
def config(binder):
@@ -100,7 +100,7 @@ def config(binder):
100100
assert Test.func() == (Test, 1, 2, 3)
101101
assert Test.func(10) == (Test, 10, 2, 3)
102102
assert Test.func(10, 20) == (Test, 10, 20, 3)
103-
assert Test.func(10, 20, 30) == (Test, 10, 20, 30)
103+
assert Test.func(10, 20, c=30) == (Test, 10, 20, 30)
104104
assert Test.func(a='a') == (Test, 'a', 2, 3)
105105
assert Test.func(b='b') == (Test, 1, 'b', 3)
106106
assert Test.func(c='c') == (Test, 1, 2, 'c')
@@ -113,7 +113,7 @@ class Test:
113113
# note inject must be *before* classmethod!
114114
@classmethod
115115
@inject.autoparams()
116-
def func(cls, a=1, b: 'B' = None, c: 'C' = None):
116+
def func(cls, a=1, b: 'B' = None, *, c: 'C' = None):
117117
return cls, a, b, c
118118

119119
def config(binder):
@@ -126,7 +126,7 @@ def config(binder):
126126
assert test.func() == (Test, 1, 2, 3)
127127
assert test.func(10) == (Test, 10, 2, 3)
128128
assert test.func(10, 20) == (Test, 10, 20, 3)
129-
assert test.func(10, 20, 30) == (Test, 10, 20, 30)
129+
assert test.func(10, 20, c=30) == (Test, 10, 20, 30)
130130
assert test.func(a='a') == (Test, 'a', 2, 3)
131131
assert test.func(b='b') == (Test, 1, 'b', 3)
132132
assert test.func(c='c') == (Test, 1, 2, 'c')
@@ -136,7 +136,7 @@ def config(binder):
136136

137137
def test_autoparams_only_selected(self):
138138
@inject.autoparams('a', 'c')
139-
def test_func(a: 'A', b: 'B', c: 'C'):
139+
def test_func(a: 'A', b: 'B', *, c: 'C'):
140140
return a, b, c
141141

142142
def config(binder):

0 commit comments

Comments
 (0)