Skip to content

Backport PEP-696 specialisation on Python >=3.11.1 #397

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
May 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 68 additions & 0 deletions src/test_typing_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6402,6 +6402,34 @@ def test_typevartuple(self):
class A(Generic[Unpack[Ts]]): ...
Alias = Optional[Unpack[Ts]]

@skipIf(
sys.version_info < (3, 11, 1),
"Not yet backported for older versions of Python"
)
Comment on lines +6405 to +6408
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I verified locally that these tests pass with Python 3.11.1 but not with Python 3.11.0...

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems like the failures are just that we're missing the defaults from __args__, which doesn't seem too bad. (It may cause problems for introspection but shouldn't usually break creation of types that are normally in annotations.) I'm fine leaving those out for now.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, the test failures "just" indicate incorrect behaviour rather than crashes

def test_typevartuple_specialization(self):
T = TypeVar("T")
Ts = TypeVarTuple('Ts', default=Unpack[Tuple[str, int]])
self.assertEqual(Ts.__default__, Unpack[Tuple[str, int]])
class A(Generic[T, Unpack[Ts]]): ...
self.assertEqual(A[float].__args__, (float, str, int))
self.assertEqual(A[float, range].__args__, (float, range))
self.assertEqual(A[float, Unpack[tuple[int, ...]]].__args__, (float, Unpack[tuple[int, ...]]))

@skipIf(
sys.version_info < (3, 11, 1),
"Not yet backported for older versions of Python"
)
def test_typevar_and_typevartuple_specialization(self):
T = TypeVar("T")
U = TypeVar("U", default=float)
Ts = TypeVarTuple('Ts', default=Unpack[Tuple[str, int]])
self.assertEqual(Ts.__default__, Unpack[Tuple[str, int]])
class A(Generic[T, U, Unpack[Ts]]): ...
self.assertEqual(A[int].__args__, (int, float, str, int))
self.assertEqual(A[int, str].__args__, (int, str, str, int))
self.assertEqual(A[int, str, range].__args__, (int, str, range))
self.assertEqual(A[int, str, Unpack[tuple[int, ...]]].__args__, (int, str, Unpack[tuple[int, ...]]))

def test_no_default_after_typevar_tuple(self):
T = TypeVar("T", default=int)
Ts = TypeVarTuple("Ts")
Expand Down Expand Up @@ -6487,6 +6515,46 @@ def test_allow_default_after_non_default_in_alias(self):
a4 = Callable[[Unpack[Ts]], T]
self.assertEqual(a4.__args__, (Unpack[Ts], T))

@skipIf(
sys.version_info < (3, 11, 1),
"Not yet backported for older versions of Python"
)
def test_paramspec_specialization(self):
T = TypeVar("T")
P = ParamSpec('P', default=[str, int])
self.assertEqual(P.__default__, [str, int])
class A(Generic[T, P]): ...
self.assertEqual(A[float].__args__, (float, (str, int)))
self.assertEqual(A[float, [range]].__args__, (float, (range,)))

@skipIf(
sys.version_info < (3, 11, 1),
"Not yet backported for older versions of Python"
)
def test_typevar_and_paramspec_specialization(self):
T = TypeVar("T")
U = TypeVar("U", default=float)
P = ParamSpec('P', default=[str, int])
self.assertEqual(P.__default__, [str, int])
class A(Generic[T, U, P]): ...
self.assertEqual(A[float].__args__, (float, float, (str, int)))
self.assertEqual(A[float, int].__args__, (float, int, (str, int)))
self.assertEqual(A[float, int, [range]].__args__, (float, int, (range,)))

@skipIf(
sys.version_info < (3, 11, 1),
"Not yet backported for older versions of Python"
)
def test_paramspec_and_typevar_specialization(self):
T = TypeVar("T")
P = ParamSpec('P', default=[str, int])
U = TypeVar("U", default=float)
self.assertEqual(P.__default__, [str, int])
class A(Generic[T, P, U]): ...
self.assertEqual(A[float].__args__, (float, (str, int), float))
self.assertEqual(A[float, [range]].__args__, (float, (range,), float))
self.assertEqual(A[float, [range], int].__args__, (float, (range,), int))


class NoDefaultTests(BaseTestCase):
@skip_if_py313_beta_1
Expand Down
97 changes: 97 additions & 0 deletions src/typing_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1513,8 +1513,19 @@ def __new__(cls, name, *constraints, bound=None,
if infer_variance and (covariant or contravariant):
raise ValueError("Variance cannot be specified with infer_variance.")
typevar.__infer_variance__ = infer_variance

_set_default(typevar, default)
_set_module(typevar)

def _tvar_prepare_subst(alias, args):
if (
typevar.has_default()
and alias.__parameters__.index(typevar) == len(args)
):
args += (typevar.__default__,)
return args

typevar.__typing_prepare_subst__ = _tvar_prepare_subst
return typevar

def __init_subclass__(cls) -> None:
Expand Down Expand Up @@ -1613,6 +1624,24 @@ def __new__(cls, name, *, bound=None,

_set_default(paramspec, default)
_set_module(paramspec)

def _paramspec_prepare_subst(alias, args):
params = alias.__parameters__
i = params.index(paramspec)
if i == len(args) and paramspec.has_default():
args = [*args, paramspec.__default__]
if i >= len(args):
raise TypeError(f"Too few arguments for {alias}")
# Special case where Z[[int, str, bool]] == Z[int, str, bool] in PEP 612.
if len(params) == 1 and not typing._is_param_expr(args[0]):
assert i == 0
args = (args,)
# Convert lists to tuples to help other libraries cache the results.
elif isinstance(args[i], list):
args = (*args[:i], tuple(args[i]), *args[i + 1:])
return args

paramspec.__typing_prepare_subst__ = _paramspec_prepare_subst
return paramspec

def __init_subclass__(cls) -> None:
Expand Down Expand Up @@ -2311,6 +2340,17 @@ def __init__(self, getitem):
class _UnpackAlias(typing._GenericAlias, _root=True):
__class__ = typing.TypeVar

@property
def __typing_unpacked_tuple_args__(self):
assert self.__origin__ is Unpack
assert len(self.__args__) == 1
arg, = self.__args__
if isinstance(arg, (typing._GenericAlias, _types.GenericAlias)):
if arg.__origin__ is not tuple:
raise TypeError("Unpack[...] must be used with a tuple type")
return arg.__args__
return None

@_UnpackSpecialForm
def Unpack(self, parameters):
item = typing._type_check(parameters, f'{self._name} accepts only a single type.')
Expand Down Expand Up @@ -2340,6 +2380,16 @@ def _is_unpack(obj):

elif hasattr(typing, "TypeVarTuple"): # 3.11+

def _unpack_args(*args):
newargs = []
for arg in args:
subargs = getattr(arg, '__typing_unpacked_tuple_args__', None)
if subargs is not None and not (subargs and subargs[-1] is ...):
newargs.extend(subargs)
else:
newargs.append(arg)
return newargs

# Add default parameter - PEP 696
class TypeVarTuple(metaclass=_TypeVarLikeMeta):
"""Type variable tuple."""
Expand All @@ -2350,6 +2400,53 @@ def __new__(cls, name, *, default=NoDefault):
tvt = typing.TypeVarTuple(name)
_set_default(tvt, default)
_set_module(tvt)

def _typevartuple_prepare_subst(alias, args):
params = alias.__parameters__
typevartuple_index = params.index(tvt)
for param in params[typevartuple_index + 1:]:
if isinstance(param, TypeVarTuple):
raise TypeError(
f"More than one TypeVarTuple parameter in {alias}"
)

alen = len(args)
plen = len(params)
left = typevartuple_index
right = plen - typevartuple_index - 1
var_tuple_index = None
fillarg = None
for k, arg in enumerate(args):
if not isinstance(arg, type):
subargs = getattr(arg, '__typing_unpacked_tuple_args__', None)
if subargs and len(subargs) == 2 and subargs[-1] is ...:
if var_tuple_index is not None:
raise TypeError(
"More than one unpacked "
"arbitrary-length tuple argument"
)
var_tuple_index = k
fillarg = subargs[0]
if var_tuple_index is not None:
left = min(left, var_tuple_index)
right = min(right, alen - var_tuple_index - 1)
elif left + right > alen:
raise TypeError(f"Too few arguments for {alias};"
f" actual {alen}, expected at least {plen - 1}")
if left == alen - right and tvt.has_default():
replacement = _unpack_args(tvt.__default__)
else:
replacement = args[left: alen - right]

return (
*args[:left],
*([fillarg] * (typevartuple_index - left)),
replacement,
*([fillarg] * (plen - right - left - typevartuple_index - 1)),
*args[alen - right:],
)

tvt.__typing_prepare_subst__ = _typevartuple_prepare_subst
return tvt

def __init_subclass__(self, *args, **kwds):
Expand Down