-
-
Notifications
You must be signed in to change notification settings - Fork 3k
Allow unpacking of TypedDict into TypedDict #13353
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 9 commits
2c3850a
0d32e60
aaa4b20
9dda0f6
e755d0e
74c31eb
4f89b41
63e4ca8
c4af4fd
df33b86
c96fd8a
4b9a739
114e12a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -617,13 +617,11 @@ def check_typeddict_call( | |
args: List[Expression], | ||
context: Context, | ||
) -> Type: | ||
if len(args) >= 1 and all([ak == ARG_NAMED for ak in arg_kinds]): | ||
# ex: Point(x=42, y=1337) | ||
assert all(arg_name is not None for arg_name in arg_names) | ||
item_names = cast(List[str], arg_names) | ||
if all([ak in {ARG_NAMED, ARG_STAR2} for ak in arg_kinds]): | ||
# ex: Point(x=42, y=1337, **other_point) | ||
item_args = args | ||
return self.check_typeddict_call_with_kwargs( | ||
callee, dict(zip(item_names, item_args)), context | ||
callee, list(zip(arg_names, item_args)), context | ||
) | ||
|
||
if len(args) == 1 and arg_kinds[0] == ARG_POS: | ||
|
@@ -635,44 +633,50 @@ def check_typeddict_call( | |
# ex: Point(dict(x=42, y=1337)) | ||
return self.check_typeddict_call_with_dict(callee, unique_arg.analyzed, context) | ||
|
||
if len(args) == 0: | ||
# ex: EmptyDict() | ||
return self.check_typeddict_call_with_kwargs(callee, {}, context) | ||
|
||
self.chk.fail(message_registry.INVALID_TYPEDDICT_ARGS, context) | ||
return AnyType(TypeOfAny.from_error) | ||
|
||
def validate_typeddict_kwargs(self, kwargs: DictExpr) -> "Optional[Dict[str, Expression]]": | ||
item_args = [item[1] for item in kwargs.items] | ||
|
||
item_names = [] # List[str] | ||
def validate_typeddict_kwargs( | ||
self, kwargs: DictExpr | ||
) -> Optional[List[Tuple[Optional[str], Expression]]]: | ||
"""Validate kwargs for TypedDict constructor, e.g. Point({'x': 1, 'y': 2}). | ||
Check that all items have string literal keys or are using unpack operator (**) | ||
""" | ||
items: List[Tuple[Optional[str], Expression]] = [] | ||
for item_name_expr, item_arg in kwargs.items: | ||
# If unpack operator (**) was used, name will be None | ||
if item_name_expr is None: | ||
items.append((None, item_arg)) | ||
continue | ||
literal_value = None | ||
if item_name_expr: | ||
key_type = self.accept(item_name_expr) | ||
values = try_getting_str_literals(item_name_expr, key_type) | ||
if values and len(values) == 1: | ||
literal_value = values[0] | ||
key_type = self.accept(item_name_expr) | ||
values = try_getting_str_literals(item_name_expr, key_type) | ||
if values and len(values) == 1: | ||
literal_value = values[0] | ||
if literal_value is None: | ||
key_context = item_name_expr or item_arg | ||
self.chk.fail(message_registry.TYPEDDICT_KEY_MUST_BE_STRING_LITERAL, key_context) | ||
return None | ||
else: | ||
item_names.append(literal_value) | ||
return dict(zip(item_names, item_args)) | ||
items.append((literal_value, item_arg)) | ||
return items | ||
|
||
def match_typeddict_call_with_dict( | ||
self, callee: TypedDictType, kwargs: DictExpr, context: Context | ||
) -> bool: | ||
def match_typeddict_call_with_dict(self, callee: TypedDictType, kwargs: DictExpr) -> bool: | ||
"""Check that kwargs is valid set of TypedDict items, contains all required keys of callee, and has no extraneous keys""" | ||
validated_kwargs = self.validate_typeddict_kwargs(kwargs=kwargs) | ||
if validated_kwargs is not None: | ||
return callee.required_keys <= set(validated_kwargs.keys()) <= set(callee.items.keys()) | ||
return ( | ||
callee.required_keys | ||
<= set(dict(validated_kwargs).keys()) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think that >>> set() <= {'a': 1}.keys()
True There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You're right! Interesting to see that |
||
<= set(callee.items.keys()) | ||
) | ||
else: | ||
return False | ||
|
||
def check_typeddict_call_with_dict( | ||
self, callee: TypedDictType, kwargs: DictExpr, context: Context | ||
) -> Type: | ||
"""Check TypedDict constructor of format Point({'x': 1, 'y': 2})""" | ||
validated_kwargs = self.validate_typeddict_kwargs(kwargs=kwargs) | ||
if validated_kwargs is not None: | ||
return self.check_typeddict_call_with_kwargs( | ||
|
@@ -682,30 +686,67 @@ def check_typeddict_call_with_dict( | |
return AnyType(TypeOfAny.from_error) | ||
|
||
def check_typeddict_call_with_kwargs( | ||
self, callee: TypedDictType, kwargs: Dict[str, Expression], context: Context | ||
self, | ||
callee: TypedDictType, | ||
kwargs: List[Tuple[Optional[str], Expression]], | ||
context: Context, | ||
) -> Type: | ||
if not (callee.required_keys <= set(kwargs.keys()) <= set(callee.items.keys())): | ||
expected_keys = [ | ||
key | ||
for key in callee.items.keys() | ||
if key in callee.required_keys or key in kwargs.keys() | ||
] | ||
actual_keys = kwargs.keys() | ||
"""Check TypedDict constructor of format Point(x=1, y=2)""" | ||
# Infer types of item values and expand unpack operators | ||
items: Dict[str, Tuple[Expression, Type]] = {} | ||
sure_keys: List[str] = [] | ||
maybe_keys: List[str] = [] # Will contain non-required items of unpacked TypedDicts | ||
for key, value_expr in kwargs: | ||
if key is not None: | ||
# Regular key and value | ||
value_type = self.accept(value_expr, callee.items.get(key)) | ||
items[key] = (value_expr, value_type) | ||
sure_keys.append(key) | ||
else: | ||
# Unpack operator (**) was used; unpack all items of the type of this expression into items list | ||
value_type = self.accept(value_expr, callee) | ||
proper_type = get_proper_type(value_type) | ||
if isinstance(proper_type, TypedDictType): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This will not work for unions of TypedDicts (or other more complex types). I'm open to suggestions to improve this |
||
for nested_key, nested_value_type in proper_type.items.items(): | ||
items[nested_key] = (value_expr, nested_value_type) | ||
if nested_key in proper_type.required_keys: | ||
sure_keys.append(nested_key) | ||
else: | ||
maybe_keys.append(nested_key) | ||
else: | ||
# Fail when trying to unpack anything but TypedDict | ||
assert not self.chk.check_subtype( | ||
subtype=value_type, | ||
supertype=self.chk.named_type("typing._TypedDict"), | ||
context=value_expr, | ||
msg=message_registry.INCOMPATIBLE_TYPES, | ||
subtype_label="unpacked expression has type", | ||
supertype_label="expected", | ||
code=codes.TYPEDDICT_ITEM, | ||
) | ||
|
||
if not ( | ||
callee.required_keys | ||
<= set(sure_keys) | ||
<= set(sure_keys + maybe_keys) | ||
<= set(callee.items.keys()) | ||
): | ||
self.msg.unexpected_typeddict_keys( | ||
callee, expected_keys=expected_keys, actual_keys=list(actual_keys), context=context | ||
callee, actual_sure_keys=sure_keys, actual_maybe_keys=maybe_keys, context=context | ||
) | ||
return AnyType(TypeOfAny.from_error) | ||
|
||
# Check item value types | ||
for (item_name, item_expected_type) in callee.items.items(): | ||
if item_name in kwargs: | ||
item_value = kwargs[item_name] | ||
self.chk.check_simple_assignment( | ||
lvalue_type=item_expected_type, | ||
rvalue=item_value, | ||
context=item_value, | ||
if item_name in items: | ||
item_value_expr, item_actual_type = items[item_name] | ||
self.chk.check_subtype( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm calling |
||
subtype=item_actual_type, | ||
supertype=item_expected_type, | ||
context=item_value_expr, | ||
msg=message_registry.INCOMPATIBLE_TYPES, | ||
lvalue_name=f'TypedDict item "{item_name}"', | ||
rvalue_name="expression", | ||
subtype_label="expression has type", | ||
supertype_label=f'TypedDict item "{item_name}" has type', | ||
code=codes.TYPEDDICT_ITEM, | ||
) | ||
|
||
|
@@ -3997,7 +4038,7 @@ def find_typeddict_context( | |
for item in context.items: | ||
item_context = self.find_typeddict_context(item, dict_expr) | ||
if item_context is not None and self.match_typeddict_call_with_dict( | ||
item_context, dict_expr, dict_expr | ||
item_context, dict_expr | ||
): | ||
items.append(item_context) | ||
if len(items) == 1: | ||
|
Uh oh!
There was an error while loading. Please reload this page.