|
2 | 2 |
|
3 | 3 | import os |
4 | 4 | import inspect |
| 5 | +import weakref |
5 | 6 | from typing import TYPE_CHECKING, Any, Type, Tuple, Union, Generic, TypeVar, Callable, Optional, cast |
6 | 7 | from datetime import date, datetime |
7 | 8 | from typing_extensions import ( |
@@ -598,6 +599,9 @@ class CachedDiscriminatorType(Protocol): |
598 | 599 | __discriminator__: DiscriminatorDetails |
599 | 600 |
|
600 | 601 |
|
| 602 | +DISCRIMINATOR_CACHE: weakref.WeakKeyDictionary[type, DiscriminatorDetails] = weakref.WeakKeyDictionary() |
| 603 | + |
| 604 | + |
601 | 605 | class DiscriminatorDetails: |
602 | 606 | field_name: str |
603 | 607 | """The name of the discriminator field in the variant class, e.g. |
@@ -640,8 +644,9 @@ def __init__( |
640 | 644 |
|
641 | 645 |
|
642 | 646 | def _build_discriminated_union_meta(*, union: type, meta_annotations: tuple[Any, ...]) -> DiscriminatorDetails | None: |
643 | | - if isinstance(union, CachedDiscriminatorType): |
644 | | - return union.__discriminator__ |
| 647 | + cached = DISCRIMINATOR_CACHE.get(union) |
| 648 | + if cached is not None: |
| 649 | + return cached |
645 | 650 |
|
646 | 651 | discriminator_field_name: str | None = None |
647 | 652 |
|
@@ -694,7 +699,7 @@ def _build_discriminated_union_meta(*, union: type, meta_annotations: tuple[Any, |
694 | 699 | discriminator_field=discriminator_field_name, |
695 | 700 | discriminator_alias=discriminator_alias, |
696 | 701 | ) |
697 | | - cast(CachedDiscriminatorType, union).__discriminator__ = details |
| 702 | + DISCRIMINATOR_CACHE.setdefault(union, details) |
698 | 703 | return details |
699 | 704 |
|
700 | 705 |
|
|
0 commit comments