Skip to content

Commit a74e69f

Browse files
bpo-44653: Support typing.Union in parameter substitution of the union type
1 parent 8f50f44 commit a74e69f

File tree

3 files changed

+57
-4
lines changed

3 files changed

+57
-4
lines changed

Lib/test/test_types.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -772,6 +772,21 @@ def test_union_parameter_chaining(self):
772772
self.assertEqual((list[T] | list[S])[int, T], list[int] | list[T])
773773
self.assertEqual((list[T] | list[S])[int, int], list[int])
774774

775+
def test_union_parameter_substitution_union(self):
776+
T = typing.TypeVar("T")
777+
res = (int | T)[str | list]
778+
self.assertEqual(res, int | str | list)
779+
self.assertIsInstance(res, types.Union)
780+
res = (int | T)[str | int]
781+
self.assertEqual(res, int | str)
782+
self.assertIsInstance(res, types.Union)
783+
res = (int | T)[typing.Union[str, list]]
784+
self.assertEqual(res, int | str | list)
785+
self.assertIsInstance(res, types.Union)
786+
res = (int | T)[typing.Union[str, int]]
787+
self.assertEqual(res, int | str)
788+
self.assertIsInstance(res, types.Union)
789+
775790
def test_union_parameter_substitution_errors(self):
776791
T = typing.TypeVar("T")
777792
x = int | T
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Support ``typing.Union`` in parameter substitution of the union type.

Objects/unionobject.c

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,12 @@ is_new_type(PyObject *obj)
273273
return is_typing_module(obj);
274274
}
275275

276+
static int
277+
is_typing_union(PyObject *obj)
278+
{
279+
return is_typing_name(obj, "_UnionGenericAlias");
280+
}
281+
276282
// Emulates short-circuiting behavior of the ``||`` operator
277283
// while also checking negative values.
278284
#define CHECK_RES(res) { \
@@ -429,6 +435,19 @@ static PyMethodDef union_methods[] = {
429435
{0}};
430436

431437

438+
static PyObject *
439+
from_typing_union(PyObject *obj)
440+
{
441+
_Py_IDENTIFIER(__args__);
442+
PyObject *args = _PyObject_GetAttrId(obj, &PyId___args__);
443+
if (args == NULL) {
444+
return NULL;
445+
}
446+
PyObject *result = make_union(args);
447+
Py_DECREF(args);
448+
return result;
449+
}
450+
432451
static PyObject *
433452
union_getitem(PyObject *self, PyObject *item)
434453
{
@@ -447,16 +466,34 @@ union_getitem(PyObject *self, PyObject *item)
447466
}
448467

449468
// Check arguments are unionable.
469+
assert(Py_REFCNT(newargs) == 1);
450470
Py_ssize_t nargs = PyTuple_GET_SIZE(newargs);
451471
for (Py_ssize_t iarg = 0; iarg < nargs; iarg++) {
452472
PyObject *arg = PyTuple_GET_ITEM(newargs, iarg);
453-
int is_arg_unionable = is_unionable(arg);
454-
if (is_arg_unionable <= 0) {
455-
Py_DECREF(newargs);
456-
if (is_arg_unionable == 0) {
473+
int r = is_unionable(arg);
474+
if (r == 0) {
475+
r = is_typing_union(arg);
476+
if (r == 0) {
457477
PyErr_Format(PyExc_TypeError,
458478
"Each union argument must be a type, got %.100R", arg);
479+
Py_DECREF(newargs);
480+
return NULL;
481+
}
482+
if (r > 0) {
483+
// Replace typing.Union with types.Union.
484+
PyObject *newarg = from_typing_union(arg);
485+
if (newarg == NULL) {
486+
Py_DECREF(newargs);
487+
return NULL;
488+
}
489+
assert(Py_REFCNT(newargs) == 1);
490+
assert(PyTuple_GET_ITEM(newargs, iarg) == arg);
491+
PyTuple_SET_ITEM(newargs, iarg, newarg);
492+
Py_DECREF(arg);
459493
}
494+
}
495+
if (r < 0) {
496+
Py_DECREF(newargs);
460497
return NULL;
461498
}
462499
}

0 commit comments

Comments
 (0)