Skip to content

Backport evaluate_forward_ref() changes #611

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 29, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
on Python versions <3.10. PEP 604 was introduced in Python 3.10, and
`typing_extensions` does not generally attempt to backport PEP-604 methods
to prior versions.
- Further update `typing_extensions.evaluate_forward_ref` with changes in Python 3.14.

# Release 4.14.0rc1 (May 24, 2025)

Expand Down
171 changes: 149 additions & 22 deletions src/test_typing_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8944,7 +8944,147 @@ def test_pep_695_generics_with_future_annotations_nested_in_function(self):
set(results.generic_func.__type_params__)
)

class TestEvaluateForwardRefs(BaseTestCase):

class EvaluateForwardRefTests(BaseTestCase):
def test_evaluate_forward_ref(self):
int_ref = typing_extensions.ForwardRef('int')
self.assertIs(typing_extensions.evaluate_forward_ref(int_ref), int)
self.assertIs(
typing_extensions.evaluate_forward_ref(int_ref, type_params=()),
int,
)
self.assertIs(
typing_extensions.evaluate_forward_ref(int_ref, format=typing_extensions.Format.VALUE),
int,
)
self.assertIs(
typing_extensions.evaluate_forward_ref(
int_ref, format=typing_extensions.Format.FORWARDREF,
),
int,
)
self.assertEqual(
typing_extensions.evaluate_forward_ref(
int_ref, format=typing_extensions.Format.STRING,
),
'int',
)

def test_evaluate_forward_ref_undefined(self):
missing = typing_extensions.ForwardRef('missing')
with self.assertRaises(NameError):
typing_extensions.evaluate_forward_ref(missing)
self.assertIs(
typing_extensions.evaluate_forward_ref(
missing, format=typing_extensions.Format.FORWARDREF,
),
missing,
)
self.assertEqual(
typing_extensions.evaluate_forward_ref(
missing, format=typing_extensions.Format.STRING,
),
"missing",
)

def test_evaluate_forward_ref_nested(self):
ref = typing_extensions.ForwardRef("Union[int, list['str']]")
ns = {"Union": Union}
if sys.version_info >= (3, 11):
expected = Union[int, list[str]]
else:
expected = Union[int, list['str']] # TODO: evaluate nested forward refs in Python < 3.11
self.assertEqual(
typing_extensions.evaluate_forward_ref(ref, globals=ns),
expected,
)
self.assertEqual(
typing_extensions.evaluate_forward_ref(
ref, globals=ns, format=typing_extensions.Format.FORWARDREF
),
expected,
)
self.assertEqual(
typing_extensions.evaluate_forward_ref(ref, format=typing_extensions.Format.STRING),
"Union[int, list['str']]",
)

why = typing_extensions.ForwardRef('"\'str\'"')
self.assertIs(typing_extensions.evaluate_forward_ref(why), str)

@skipUnless(sys.version_info >= (3, 10), "Relies on PEP 604")
def test_evaluate_forward_ref_nested_pep604(self):
ref = typing_extensions.ForwardRef("int | list['str']")
if sys.version_info >= (3, 11):
expected = int | list[str]
else:
expected = int | list['str'] # TODO: evaluate nested forward refs in Python < 3.11
self.assertEqual(
typing_extensions.evaluate_forward_ref(ref),
expected,
)
self.assertEqual(
typing_extensions.evaluate_forward_ref(ref, format=typing_extensions.Format.FORWARDREF),
expected,
)
self.assertEqual(
typing_extensions.evaluate_forward_ref(ref, format=typing_extensions.Format.STRING),
"int | list['str']",
)

def test_evaluate_forward_ref_none(self):
none_ref = typing_extensions.ForwardRef('None')
self.assertIs(typing_extensions.evaluate_forward_ref(none_ref), None)

def test_globals(self):
A = "str"
ref = typing_extensions.ForwardRef('list[A]')
with self.assertRaises(NameError):
typing_extensions.evaluate_forward_ref(ref)
self.assertEqual(
typing_extensions.evaluate_forward_ref(ref, globals={'A': A}),
list[str] if sys.version_info >= (3, 11) else list['str'],
)

def test_owner(self):
ref = typing_extensions.ForwardRef("A")

with self.assertRaises(NameError):
typing_extensions.evaluate_forward_ref(ref)

# We default to the globals of `owner`,
# so it no longer raises `NameError`
self.assertIs(
typing_extensions.evaluate_forward_ref(ref, owner=Loop), A
)

@skipUnless(sys.version_info >= (3, 14), "Not yet implemented in Python < 3.14")
def test_inherited_owner(self):
# owner passed to evaluate_forward_ref
ref = typing_extensions.ForwardRef("list['A']")
self.assertEqual(
typing_extensions.evaluate_forward_ref(ref, owner=Loop),
list[A],
)

# owner set on the ForwardRef
ref = typing_extensions.ForwardRef("list['A']", owner=Loop)
self.assertEqual(
typing_extensions.evaluate_forward_ref(ref),
list[A],
)

@skipUnless(sys.version_info >= (3, 14), "Not yet implemented in Python < 3.14")
def test_partial_evaluation(self):
ref = typing_extensions.ForwardRef("list[A]")
with self.assertRaises(NameError):
typing_extensions.evaluate_forward_ref(ref)

self.assertEqual(
typing_extensions.evaluate_forward_ref(ref, format=typing_extensions.Format.FORWARDREF),
list[EqualToForwardRef('A')],
)

def test_global_constant(self):
if sys.version_info[:3] > (3, 10, 0):
self.assertTrue(_FORWARD_REF_HAS_CLASS)
Expand Down Expand Up @@ -9107,30 +9247,17 @@ class Y(Generic[Tx]):
self.assertEqual(get_args(evaluated_ref3), (Z[str],))

def test_invalid_special_forms(self):
# tests _lax_type_check to raise errors the same way as the typing module.
# Regex capture "< class 'module.name'> and "module.name"
with self.assertRaisesRegex(
TypeError, r"Plain .*Protocol('>)? is not valid as type argument"
):
evaluate_forward_ref(typing.ForwardRef("Protocol"), globals=vars(typing))
with self.assertRaisesRegex(
TypeError, r"Plain .*Generic('>)? is not valid as type argument"
):
evaluate_forward_ref(typing.ForwardRef("Generic"), globals=vars(typing))
with self.assertRaisesRegex(TypeError, r"Plain typing(_extensions)?\.Final is not valid as type argument"):
evaluate_forward_ref(typing.ForwardRef("Final"), globals=vars(typing))
with self.assertRaisesRegex(TypeError, r"Plain typing(_extensions)?\.ClassVar is not valid as type argument"):
evaluate_forward_ref(typing.ForwardRef("ClassVar"), globals=vars(typing))
for name in ("Protocol", "Final", "ClassVar", "Generic"):
with self.subTest(name=name):
self.assertIs(
evaluate_forward_ref(typing.ForwardRef(name), globals=vars(typing)),
getattr(typing, name),
)
if _FORWARD_REF_HAS_CLASS:
self.assertIs(evaluate_forward_ref(typing.ForwardRef("Final", is_class=True), globals=vars(typing)), Final)
self.assertIs(evaluate_forward_ref(typing.ForwardRef("ClassVar", is_class=True), globals=vars(typing)), ClassVar)
with self.assertRaisesRegex(TypeError, r"Plain typing(_extensions)?\.Final is not valid as type argument"):
evaluate_forward_ref(typing.ForwardRef("Final", is_argument=False), globals=vars(typing))
with self.assertRaisesRegex(TypeError, r"Plain typing(_extensions)?\.ClassVar is not valid as type argument"):
evaluate_forward_ref(typing.ForwardRef("ClassVar", is_argument=False), globals=vars(typing))
else:
self.assertIs(evaluate_forward_ref(typing.ForwardRef("Final", is_argument=False), globals=vars(typing)), Final)
self.assertIs(evaluate_forward_ref(typing.ForwardRef("ClassVar", is_argument=False), globals=vars(typing)), ClassVar)
self.assertIs(evaluate_forward_ref(typing.ForwardRef("Final", is_argument=False), globals=vars(typing)), Final)
self.assertIs(evaluate_forward_ref(typing.ForwardRef("ClassVar", is_argument=False), globals=vars(typing)), ClassVar)


class TestSentinels(BaseTestCase):
Expand Down
85 changes: 8 additions & 77 deletions src/typing_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4060,57 +4060,6 @@ def _eval_with_owner(
forward_ref.__forward_value__ = value
return value

def _lax_type_check(
value, msg, is_argument=True, *, module=None, allow_special_forms=False
):
"""
A lax Python 3.11+ like version of typing._type_check
"""
if hasattr(typing, "_type_convert"):
if (
sys.version_info >= (3, 10, 3)
or (3, 9, 10) < sys.version_info[:3] < (3, 10)
):
# allow_special_forms introduced later cpython/#30926 (bpo-46539)
type_ = typing._type_convert(
value,
module=module,
allow_special_forms=allow_special_forms,
)
# module was added with bpo-41249 before is_class (bpo-46539)
elif "__forward_module__" in typing.ForwardRef.__slots__:
type_ = typing._type_convert(value, module=module)
else:
type_ = typing._type_convert(value)
else:
if value is None:
return type(None)
if isinstance(value, str):
return ForwardRef(value)
type_ = value
invalid_generic_forms = (Generic, Protocol)
if not allow_special_forms:
invalid_generic_forms += (ClassVar,)
if is_argument:
invalid_generic_forms += (Final,)
if (
isinstance(type_, typing._GenericAlias)
and get_origin(type_) in invalid_generic_forms
):
raise TypeError(f"{type_} is not valid as type argument") from None
if type_ in (Any, LiteralString, NoReturn, Never, Self, TypeAlias):
return type_
if allow_special_forms and type_ in (ClassVar, Final):
return type_
if (
isinstance(type_, (_SpecialForm, typing._SpecialForm))
or type_ in (Generic, Protocol)
):
raise TypeError(f"Plain {type_} is not valid as type argument") from None
if type(type_) is tuple: # lax version with tuple instead of callable
raise TypeError(f"{msg} Got {type_!r:.100}.")
return type_

def evaluate_forward_ref(
forward_ref,
*,
Expand Down Expand Up @@ -4163,24 +4112,15 @@ def evaluate_forward_ref(
else:
raise

msg = "Forward references must evaluate to types."
if not _FORWARD_REF_HAS_CLASS:
allow_special_forms = not forward_ref.__forward_is_argument__
else:
allow_special_forms = forward_ref.__forward_is_class__
type_ = _lax_type_check(
value,
msg,
is_argument=forward_ref.__forward_is_argument__,
allow_special_forms=allow_special_forms,
)
if isinstance(value, str):
value = ForwardRef(value)

# Recursively evaluate the type
if isinstance(type_, ForwardRef):
if getattr(type_, "__forward_module__", True) is not None:
if isinstance(value, ForwardRef):
if getattr(value, "__forward_module__", True) is not None:
globals = None
return evaluate_forward_ref(
type_,
value,
globals=globals,
locals=locals,
type_params=type_params, owner=owner,
Expand All @@ -4194,28 +4134,19 @@ def evaluate_forward_ref(
locals[tvar.__name__] = tvar
if sys.version_info < (3, 12, 5):
return typing._eval_type(
type_,
value,
globals,
locals,
recursive_guard=_recursive_guard | {forward_ref.__forward_arg__},
)
if sys.version_info < (3, 14):
else:
return typing._eval_type(
type_,
value,
globals,
locals,
type_params,
recursive_guard=_recursive_guard | {forward_ref.__forward_arg__},
)
return typing._eval_type(
type_,
globals,
locals,
type_params,
recursive_guard=_recursive_guard | {forward_ref.__forward_arg__},
format=format,
owner=owner,
)


class Sentinel:
Expand Down
Loading