Skip to content

Commit bb5663e

Browse files
committed
Fix setter
1 parent c9b9354 commit bb5663e

File tree

3 files changed

+64
-8
lines changed

3 files changed

+64
-8
lines changed

pynamodb_mypy/plugin.py

Lines changed: 51 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,15 @@ def get_function_hook(self, fullname: str) -> Optional[Callable[[mypy.plugin.Fun
1515
return _attribute_instantiation_hook
1616
return None
1717

18-
def get_attribute_hook(self, fullname: str
19-
) -> Optional[Callable[[mypy.plugin.AttributeContext], mypy.types.Type]]:
20-
sym = self.lookup_fully_qualified(fullname)
21-
if sym and sym.type and _is_attribute_marked_nullable(sym.type):
22-
return lambda ctx: mypy.types.UnionType([ctx.default_attr_type, mypy.types.NoneType()])
18+
def get_method_signature_hook(self, fullname: str
19+
) -> Optional[Callable[[mypy.plugin.MethodSigContext], mypy.types.CallableType]]:
20+
class_name, method_name = fullname.rsplit('.', 1)
21+
sym = self.lookup_fully_qualified(class_name)
22+
if sym and _is_attribute_type_node(sym.node):
23+
if method_name == '__get__':
24+
return _get_method_sig_hook
25+
elif method_name == '__set__':
26+
return _set_method_sig_hook
2327
return None
2428

2529

@@ -48,6 +52,48 @@ def _get_bool_literal(n: mypy.nodes.Node) -> Optional[bool]:
4852
}.get(n.fullname or '') if isinstance(n, mypy.nodes.NameExpr) else None
4953

5054

55+
def _make_optional(t: mypy.types.Type) -> mypy.types.UnionType:
56+
return mypy.types.UnionType([t, mypy.types.NoneType()])
57+
58+
59+
def _unwrap_optional(t: mypy.types.Type) -> mypy.types.Type:
60+
if not isinstance(t, mypy.types.UnionType):
61+
return t
62+
t = mypy.types.UnionType([item for item in t.items if not isinstance(item, mypy.types.NoneType)])
63+
if len(t.items) == 0:
64+
return mypy.types.NoneType()
65+
elif len(t.items) == 1:
66+
return t.items[0]
67+
else:
68+
return t
69+
70+
71+
def _get_method_sig_hook(ctx: mypy.plugin.MethodSigContext) -> mypy.types.CallableType:
72+
sig = ctx.default_signature
73+
if not _is_attribute_marked_nullable(ctx.type):
74+
return sig
75+
try:
76+
(instance_type, owner_type) = sig.arg_types
77+
except ValueError:
78+
return sig
79+
if not isinstance(instance_type, mypy.types.AnyType): # instance attribute access
80+
return sig
81+
return sig.copy_modified(ret_type=_make_optional(sig.ret_type))
82+
83+
84+
def _set_method_sig_hook(ctx: mypy.plugin.MethodSigContext) -> mypy.types.CallableType:
85+
sig = ctx.default_signature
86+
if _is_attribute_marked_nullable(ctx.type):
87+
return sig
88+
try:
89+
(instance_type, value_type) = sig.arg_types
90+
except ValueError:
91+
return sig
92+
if not isinstance(instance_type, mypy.types.AnyType): # instance attribute access
93+
return sig
94+
return sig.copy_modified(arg_types=[instance_type, _unwrap_optional(value_type)])
95+
96+
5197
def _attribute_instantiation_hook(ctx: mypy.plugin.FunctionContext) -> mypy.types.Type:
5298
"""
5399
Handles attribute instantiation, e.g. MyAttribute(null=True)

tests/mypy_helpers.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ def _run_mypy(program: str, *, use_pdb: bool) -> Iterable[str]:
2424
mypy_args = [
2525
f.name,
2626
'--show-traceback',
27+
'--raise-exceptions',
28+
'--show-error-codes',
2729
'--config-file', config_file,
2830
]
2931
if use_pdb:

tests/test_plugin.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,14 @@ class MyModel(Model):
1212
reveal_type(MyModel().my_attr) # N: Revealed type is 'builtins.float*'
1313
reveal_type(MyModel().my_nullable_attr) # N: Revealed type is 'Union[builtins.float*, None]'
1414
reveal_type(MyModel().my_not_nullable_attr) # N: Revealed type is 'builtins.float*'
15+
16+
my_model = MyModel()
17+
my_model.my_attr = None
18+
my_model.my_nullable_attr = None
19+
my_model.my_not_nullable_attr = None
20+
my_model.my_attr = 42
21+
my_model.my_nullable_attr = 42
22+
my_model.my_not_nullable_attr = 42
1523
""")
1624

1725

@@ -30,7 +38,7 @@ class MyModel(Model):
3038
reveal_type(MyModel().my_nullable_attr) # N: Revealed type is 'Union[builtins.str*, None]'
3139
3240
MyModel().my_attr.lower()
33-
MyModel().my_nullable_attr.lower() # E: Item "None" of "Optional[str]" has no attribute "lower"
41+
MyModel().my_nullable_attr.lower() # E: Item "None" of "Optional[str]" has no attribute "lower" [union-attr]
3442
""")
3543

3644

@@ -64,7 +72,7 @@ def test_unexpected_number_of_nulls(assert_mypy_output):
6472
from pynamodb.models import Model
6573
6674
class MyModel(Model):
67-
my_attr = NumberAttribute(True, True, True, null=True) # E: "NumberAttribute" gets multiple values for keyword argument "null"
75+
my_attr = NumberAttribute(True, True, True, null=True) # E: "NumberAttribute" gets multiple values for keyword argument "null" [misc]
6876
6977
reveal_type(MyModel().my_attr) # N: Revealed type is 'builtins.float*'
7078
""") # noqa: E501
@@ -76,7 +84,7 @@ def test_unexpected_value_of_null(assert_mypy_output):
7684
from pynamodb.models import Model
7785
7886
class MyModel(Model):
79-
my_attr = NumberAttribute(null=bool(5)) # E: 'null' argument is not constant False or True, cannot deduce optionality
87+
my_attr = NumberAttribute(null=bool(5)) # E: 'null' argument is not constant False or True, cannot deduce optionality [misc]
8088
8189
reveal_type(MyModel().my_attr) # N: Revealed type is 'builtins.float*'
8290
""") # noqa: E501

0 commit comments

Comments
 (0)