@@ -15,11 +15,15 @@ def get_function_hook(self, fullname: str) -> Optional[Callable[[mypy.plugin.Fun
15
15
return _attribute_instantiation_hook
16
16
return None
17
17
18
- def get_attribute_hook (self , fullname : str
19
- ) -> Optional [Callable [[mypy .plugin .AttributeContext ], mypy .types .Type ]]:
20
- sym = self .lookup_fully_qualified (fullname )
21
- if sym and sym .type and _is_attribute_marked_nullable (sym .type ):
22
- return lambda ctx : mypy .types .UnionType ([ctx .default_attr_type , mypy .types .NoneType ()])
18
+ def get_method_signature_hook (self , fullname : str
19
+ ) -> Optional [Callable [[mypy .plugin .MethodSigContext ], mypy .types .CallableType ]]:
20
+ class_name , method_name = fullname .rsplit ('.' , 1 )
21
+ sym = self .lookup_fully_qualified (class_name )
22
+ if sym and _is_attribute_type_node (sym .node ):
23
+ if method_name == '__get__' :
24
+ return _get_method_sig_hook
25
+ elif method_name == '__set__' :
26
+ return _set_method_sig_hook
23
27
return None
24
28
25
29
@@ -48,6 +52,48 @@ def _get_bool_literal(n: mypy.nodes.Node) -> Optional[bool]:
48
52
}.get (n .fullname or '' ) if isinstance (n , mypy .nodes .NameExpr ) else None
49
53
50
54
55
+ def _make_optional (t : mypy .types .Type ) -> mypy .types .UnionType :
56
+ return mypy .types .UnionType ([t , mypy .types .NoneType ()])
57
+
58
+
59
+ def _unwrap_optional (t : mypy .types .Type ) -> mypy .types .Type :
60
+ if not isinstance (t , mypy .types .UnionType ):
61
+ return t
62
+ t = mypy .types .UnionType ([item for item in t .items if not isinstance (item , mypy .types .NoneType )])
63
+ if len (t .items ) == 0 :
64
+ return mypy .types .NoneType ()
65
+ elif len (t .items ) == 1 :
66
+ return t .items [0 ]
67
+ else :
68
+ return t
69
+
70
+
71
+ def _get_method_sig_hook (ctx : mypy .plugin .MethodSigContext ) -> mypy .types .CallableType :
72
+ sig = ctx .default_signature
73
+ if not _is_attribute_marked_nullable (ctx .type ):
74
+ return sig
75
+ try :
76
+ (instance_type , owner_type ) = sig .arg_types
77
+ except ValueError :
78
+ return sig
79
+ if not isinstance (instance_type , mypy .types .AnyType ): # instance attribute access
80
+ return sig
81
+ return sig .copy_modified (ret_type = _make_optional (sig .ret_type ))
82
+
83
+
84
+ def _set_method_sig_hook (ctx : mypy .plugin .MethodSigContext ) -> mypy .types .CallableType :
85
+ sig = ctx .default_signature
86
+ if _is_attribute_marked_nullable (ctx .type ):
87
+ return sig
88
+ try :
89
+ (instance_type , value_type ) = sig .arg_types
90
+ except ValueError :
91
+ return sig
92
+ if not isinstance (instance_type , mypy .types .AnyType ): # instance attribute access
93
+ return sig
94
+ return sig .copy_modified (arg_types = [instance_type , _unwrap_optional (value_type )])
95
+
96
+
51
97
def _attribute_instantiation_hook (ctx : mypy .plugin .FunctionContext ) -> mypy .types .Type :
52
98
"""
53
99
Handles attribute instantiation, e.g. MyAttribute(null=True)
0 commit comments