Skip to content

Commit

Permalink
Add ReadOnly support for TypedDicts (#17644)
Browse files Browse the repository at this point in the history
Refs #17264

I will add docs in a separate PR.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
sobolevn and pre-commit-ci[bot] authored Sep 30, 2024
1 parent 1a2c8e2 commit 6726d77
Show file tree
Hide file tree
Showing 25 changed files with 652 additions and 74 deletions.
24 changes: 24 additions & 0 deletions docs/source/error_code_list.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1217,6 +1217,30 @@ If the code being checked is not syntactically valid, mypy issues a
syntax error. Most, but not all, syntax errors are *blocking errors*:
they can't be ignored with a ``# type: ignore`` comment.

.. _code-typeddict-readonly-mutated:

ReadOnly key of a TypedDict is mutated [typeddict-readonly-mutated]
-------------------------------------------------------------------

Consider this example:

.. code-block:: python
from datetime import datetime
from typing import TypedDict
from typing_extensions import ReadOnly
class User(TypedDict):
username: ReadOnly[str]
last_active: datetime
user: User = {'username': 'foobar', 'last_active': datetime.now()}
user['last_active'] = datetime.now() # ok
user['username'] = 'other' # error: ReadOnly TypedDict key "key" TypedDict is mutated [typeddict-readonly-mutated]
`PEP 705 <https://peps.python.org/pep-0705>`_ specifies
how ``ReadOnly`` special form works for ``TypedDict`` objects.

.. _code-misc:

Miscellaneous checks [misc]
Expand Down
14 changes: 9 additions & 5 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -986,6 +986,10 @@ def check_typeddict_call_with_kwargs(
always_present_keys: set[str],
) -> Type:
actual_keys = kwargs.keys()
if callee.to_be_mutated:
assigned_readonly_keys = actual_keys & callee.readonly_keys
if assigned_readonly_keys:
self.msg.readonly_keys_mutated(assigned_readonly_keys, context=context)
if not (
callee.required_keys <= always_present_keys and actual_keys <= callee.items.keys()
):
Expand Down Expand Up @@ -4349,7 +4353,7 @@ def visit_index_with_type(
else:
return self.nonliteral_tuple_index_helper(left_type, index)
elif isinstance(left_type, TypedDictType):
return self.visit_typeddict_index_expr(left_type, e.index)
return self.visit_typeddict_index_expr(left_type, e.index)[0]
elif isinstance(left_type, FunctionLike) and left_type.is_type_obj():
if left_type.type_object().is_enum:
return self.visit_enum_index_expr(left_type.type_object(), e.index, e)
Expand Down Expand Up @@ -4530,7 +4534,7 @@ def union_tuple_fallback_item(self, left_type: TupleType) -> Type:

def visit_typeddict_index_expr(
self, td_type: TypedDictType, index: Expression, setitem: bool = False
) -> Type:
) -> tuple[Type, set[str]]:
if isinstance(index, StrExpr):
key_names = [index.value]
else:
Expand All @@ -4553,17 +4557,17 @@ def visit_typeddict_index_expr(
key_names.append(key_type.value)
else:
self.msg.typeddict_key_must_be_string_literal(td_type, index)
return AnyType(TypeOfAny.from_error)
return AnyType(TypeOfAny.from_error), set()

value_types = []
for key_name in key_names:
value_type = td_type.items.get(key_name)
if value_type is None:
self.msg.typeddict_key_not_found(td_type, key_name, index, setitem)
return AnyType(TypeOfAny.from_error)
return AnyType(TypeOfAny.from_error), set()
else:
value_types.append(value_type)
return make_simplified_union(value_types)
return make_simplified_union(value_types), set(key_names)

def visit_enum_index_expr(
self, enum_type: TypeInfo, index: Expression, context: Context
Expand Down
5 changes: 4 additions & 1 deletion mypy/checkmember.py
Original file line number Diff line number Diff line change
Expand Up @@ -1185,9 +1185,12 @@ def analyze_typeddict_access(
if isinstance(mx.context, IndexExpr):
# Since we can get this during `a['key'] = ...`
# it is safe to assume that the context is `IndexExpr`.
item_type = mx.chk.expr_checker.visit_typeddict_index_expr(
item_type, key_names = mx.chk.expr_checker.visit_typeddict_index_expr(
typ, mx.context.index, setitem=True
)
assigned_readonly_keys = typ.readonly_keys & key_names
if assigned_readonly_keys:
mx.msg.readonly_keys_mutated(assigned_readonly_keys, context=mx.context)
else:
# It can also be `a.__setitem__(...)` direct call.
# In this case `item_type` can be `Any`,
Expand Down
2 changes: 1 addition & 1 deletion mypy/checkpattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,7 @@ def get_mapping_item_type(
with self.msg.filter_errors() as local_errors:
result: Type | None = self.chk.expr_checker.visit_typeddict_index_expr(
mapping_type, key
)
)[0]
has_local_errors = local_errors.has_new_errors()
# If we can't determine the type statically fall back to treating it as a normal
# mapping
Expand Down
4 changes: 3 additions & 1 deletion mypy/copytype.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,9 @@ def visit_tuple_type(self, t: TupleType) -> ProperType:
return self.copy_common(t, TupleType(t.items, t.partial_fallback, implicit=t.implicit))

def visit_typeddict_type(self, t: TypedDictType) -> ProperType:
return self.copy_common(t, TypedDictType(t.items, t.required_keys, t.fallback))
return self.copy_common(
t, TypedDictType(t.items, t.required_keys, t.readonly_keys, t.fallback)
)

def visit_literal_type(self, t: LiteralType) -> ProperType:
return self.copy_common(t, LiteralType(value=t.value, fallback=t.fallback))
Expand Down
3 changes: 3 additions & 0 deletions mypy/errorcodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,9 @@ def __hash__(self) -> int:
ANNOTATION_UNCHECKED = ErrorCode(
"annotation-unchecked", "Notify about type annotations in unchecked functions", "General"
)
TYPEDDICT_READONLY_MUTATED = ErrorCode(
"typeddict-readonly-mutated", "TypedDict's ReadOnly key is mutated", "General"
)
POSSIBLY_UNDEFINED: Final[ErrorCode] = ErrorCode(
"possibly-undefined",
"Warn about variables that are defined only in some execution paths",
Expand Down
2 changes: 1 addition & 1 deletion mypy/exprtotype.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ def expr_to_unanalyzed_type(
value, options, allow_new_syntax, expr
)
result = TypedDictType(
items, set(), Instance(MISSING_FALLBACK, ()), expr.line, expr.column
items, set(), set(), Instance(MISSING_FALLBACK, ()), expr.line, expr.column
)
result.extra_items_from = extra_items_from
return result
Expand Down
2 changes: 1 addition & 1 deletion mypy/fastparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -2130,7 +2130,7 @@ def visit_Dict(self, n: ast3.Dict) -> Type:
continue
return self.invalid_type(n)
items[item_name.value] = self.visit(value)
result = TypedDictType(items, set(), _dummy_fallback, n.lineno, n.col_offset)
result = TypedDictType(items, set(), set(), _dummy_fallback, n.lineno, n.col_offset)
result.extra_items_from = extra_items_from
return result

Expand Down
7 changes: 5 additions & 2 deletions mypy/join.py
Original file line number Diff line number Diff line change
Expand Up @@ -631,10 +631,13 @@ def visit_typeddict_type(self, t: TypedDictType) -> ProperType:
)
}
fallback = self.s.create_anonymous_fallback()
all_keys = set(items.keys())
# We need to filter by items.keys() since some required keys present in both t and
# self.s might be missing from the join if the types are incompatible.
required_keys = set(items.keys()) & t.required_keys & self.s.required_keys
return TypedDictType(items, required_keys, fallback)
required_keys = all_keys & t.required_keys & self.s.required_keys
# If one type has a key as readonly, we mark it as readonly for both:
readonly_keys = (t.readonly_keys | t.readonly_keys) & all_keys
return TypedDictType(items, required_keys, readonly_keys, fallback)
elif isinstance(self.s, Instance):
return join_types(self.s, t.fallback)
else:
Expand Down
12 changes: 11 additions & 1 deletion mypy/meet.py
Original file line number Diff line number Diff line change
Expand Up @@ -1017,7 +1017,8 @@ def visit_typeddict_type(self, t: TypedDictType) -> ProperType:
items = dict(item_list)
fallback = self.s.create_anonymous_fallback()
required_keys = t.required_keys | self.s.required_keys
return TypedDictType(items, required_keys, fallback)
readonly_keys = t.readonly_keys | self.s.readonly_keys
return TypedDictType(items, required_keys, readonly_keys, fallback)
elif isinstance(self.s, Instance) and is_subtype(t, self.s):
return t
else:
Expand Down Expand Up @@ -1139,6 +1140,9 @@ def typed_dict_mapping_overlap(
- TypedDict(x=str, y=str, total=False) doesn't overlap with Dict[str, int]
- TypedDict(x=int, y=str, total=False) overlaps with Dict[str, str]
* A TypedDict with at least one ReadOnly[] key does not overlap
with Dict or MutableMapping, because they assume mutable data.
As usual empty, dictionaries lie in a gray area. In general, List[str] and List[str]
are considered non-overlapping despite empty list belongs to both. However, List[int]
and List[Never] are considered overlapping.
Expand All @@ -1159,6 +1163,12 @@ def typed_dict_mapping_overlap(
assert isinstance(right, TypedDictType)
typed, other = right, left

mutable_mapping = next(
(base for base in other.type.mro if base.fullname == "typing.MutableMapping"), None
)
if mutable_mapping is not None and typed.readonly_keys:
return False

mapping = next(base for base in other.type.mro if base.fullname == "typing.Mapping")
other = map_instance_to_supertype(other, mapping)
key_type, value_type = get_proper_types(other.args)
Expand Down
20 changes: 17 additions & 3 deletions mypy/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -926,6 +926,17 @@ def invalid_index_type(
code=code,
)

def readonly_keys_mutated(self, keys: set[str], context: Context) -> None:
if len(keys) == 1:
suffix = "is"
else:
suffix = "are"
self.fail(
"ReadOnly {} TypedDict {} mutated".format(format_key_list(sorted(keys)), suffix),
code=codes.TYPEDDICT_READONLY_MUTATED,
context=context,
)

def too_few_arguments(
self, callee: CallableType, context: Context, argument_names: Sequence[str | None] | None
) -> None:
Expand Down Expand Up @@ -2613,10 +2624,13 @@ def format_literal_value(typ: LiteralType) -> str:
return format(typ.fallback)
items = []
for item_name, item_type in typ.items.items():
modifier = "" if item_name in typ.required_keys else "?"
modifier = ""
if item_name not in typ.required_keys:
modifier += "?"
if item_name in typ.readonly_keys:
modifier += "="
items.append(f"{item_name!r}{modifier}: {format(item_type)}")
s = f"TypedDict({{{', '.join(items)}}})"
return s
return f"TypedDict({{{', '.join(items)}}})"
elif isinstance(typ, LiteralType):
return f"Literal[{format_literal_value(typ)}]"
elif isinstance(typ, UnionType):
Expand Down
22 changes: 19 additions & 3 deletions mypy/plugins/default.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from functools import partial
from typing import Callable
from typing import Callable, Final

import mypy.errorcodes as codes
from mypy import message_registry
Expand Down Expand Up @@ -372,6 +372,10 @@ def typed_dict_setdefault_callback(ctx: MethodContext) -> Type:
)
return AnyType(TypeOfAny.from_error)

assigned_readonly_keys = ctx.type.readonly_keys & set(keys)
if assigned_readonly_keys:
ctx.api.msg.readonly_keys_mutated(assigned_readonly_keys, context=ctx.context)

default_type = ctx.arg_types[1][0]

value_types = []
Expand Down Expand Up @@ -415,13 +419,16 @@ def typed_dict_delitem_callback(ctx: MethodContext) -> Type:
return AnyType(TypeOfAny.from_error)

for key in keys:
if key in ctx.type.required_keys:
if key in ctx.type.required_keys or key in ctx.type.readonly_keys:
ctx.api.msg.typeddict_key_cannot_be_deleted(ctx.type, key, ctx.context)
elif key not in ctx.type.items:
ctx.api.msg.typeddict_key_not_found(ctx.type, key, ctx.context)
return ctx.default_return_type


_TP_DICT_MUTATING_METHODS: Final = frozenset({"update of TypedDict", "__ior__ of TypedDict"})


def typed_dict_update_signature_callback(ctx: MethodSigContext) -> CallableType:
"""Try to infer a better signature type for methods that update `TypedDict`.
Expand All @@ -436,10 +443,19 @@ def typed_dict_update_signature_callback(ctx: MethodSigContext) -> CallableType:
arg_type = arg_type.as_anonymous()
arg_type = arg_type.copy_modified(required_keys=set())
if ctx.args and ctx.args[0]:
with ctx.api.msg.filter_errors():
if signature.name in _TP_DICT_MUTATING_METHODS:
# If we want to mutate this object in place, we need to set this flag,
# it will trigger an extra check in TypedDict's checker.
arg_type.to_be_mutated = True
with ctx.api.msg.filter_errors(
filter_errors=lambda name, info: info.code != codes.TYPEDDICT_READONLY_MUTATED,
save_filtered_errors=True,
):
inferred = get_proper_type(
ctx.api.get_expression_type(ctx.args[0][0], type_context=arg_type)
)
if arg_type.to_be_mutated:
arg_type.to_be_mutated = False # Done!
possible_tds = []
if isinstance(inferred, TypedDictType):
possible_tds = [inferred]
Expand Down
1 change: 1 addition & 0 deletions mypy/plugins/proper_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ def is_special_target(right: ProperType) -> bool:
"mypy.types.ErasedType",
"mypy.types.DeletedType",
"mypy.types.RequiredType",
"mypy.types.ReadOnlyType",
):
# Special case: these are not valid targets for a type alias and thus safe.
# TODO: introduce a SyntheticType base to simplify this?
Expand Down
8 changes: 4 additions & 4 deletions mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -7169,7 +7169,7 @@ def type_analyzer(
allow_tuple_literal: bool = False,
allow_unbound_tvars: bool = False,
allow_placeholder: bool = False,
allow_required: bool = False,
allow_typed_dict_special_forms: bool = False,
allow_param_spec_literals: bool = False,
allow_unpack: bool = False,
report_invalid_types: bool = True,
Expand All @@ -7188,7 +7188,7 @@ def type_analyzer(
allow_tuple_literal=allow_tuple_literal,
report_invalid_types=report_invalid_types,
allow_placeholder=allow_placeholder,
allow_required=allow_required,
allow_typed_dict_special_forms=allow_typed_dict_special_forms,
allow_param_spec_literals=allow_param_spec_literals,
allow_unpack=allow_unpack,
prohibit_self_type=prohibit_self_type,
Expand All @@ -7211,7 +7211,7 @@ def anal_type(
allow_tuple_literal: bool = False,
allow_unbound_tvars: bool = False,
allow_placeholder: bool = False,
allow_required: bool = False,
allow_typed_dict_special_forms: bool = False,
allow_param_spec_literals: bool = False,
allow_unpack: bool = False,
report_invalid_types: bool = True,
Expand Down Expand Up @@ -7246,7 +7246,7 @@ def anal_type(
allow_unbound_tvars=allow_unbound_tvars,
allow_tuple_literal=allow_tuple_literal,
allow_placeholder=allow_placeholder,
allow_required=allow_required,
allow_typed_dict_special_forms=allow_typed_dict_special_forms,
allow_param_spec_literals=allow_param_spec_literals,
allow_unpack=allow_unpack,
report_invalid_types=report_invalid_types,
Expand Down
2 changes: 1 addition & 1 deletion mypy/semanal_shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def anal_type(
tvar_scope: TypeVarLikeScope | None = None,
allow_tuple_literal: bool = False,
allow_unbound_tvars: bool = False,
allow_required: bool = False,
allow_typed_dict_special_forms: bool = False,
allow_placeholder: bool = False,
report_invalid_types: bool = True,
prohibit_self_type: str | None = None,
Expand Down
Loading

0 comments on commit 6726d77

Please sign in to comment.