Skip to content

Commit 15a7ec8

Browse files
fix: compat with Python 3.14
1 parent afc14f2 commit 15a7ec8

File tree

2 files changed

+12
-7
lines changed

2 files changed

+12
-7
lines changed

src/openai/_models.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import os
44
import inspect
5+
import weakref
56
from typing import TYPE_CHECKING, Any, Type, Tuple, Union, Generic, TypeVar, Callable, Optional, cast
67
from datetime import date, datetime
78
from typing_extensions import (
@@ -598,6 +599,9 @@ class CachedDiscriminatorType(Protocol):
598599
__discriminator__: DiscriminatorDetails
599600

600601

602+
DISCRIMINATOR_CACHE: weakref.WeakKeyDictionary[type, DiscriminatorDetails] = weakref.WeakKeyDictionary()
603+
604+
601605
class DiscriminatorDetails:
602606
field_name: str
603607
"""The name of the discriminator field in the variant class, e.g.
@@ -640,8 +644,9 @@ def __init__(
640644

641645

642646
def _build_discriminated_union_meta(*, union: type, meta_annotations: tuple[Any, ...]) -> DiscriminatorDetails | None:
643-
if isinstance(union, CachedDiscriminatorType):
644-
return union.__discriminator__
647+
cached = DISCRIMINATOR_CACHE.get(union)
648+
if cached is not None:
649+
return cached
645650

646651
discriminator_field_name: str | None = None
647652

@@ -694,7 +699,7 @@ def _build_discriminated_union_meta(*, union: type, meta_annotations: tuple[Any,
694699
discriminator_field=discriminator_field_name,
695700
discriminator_alias=discriminator_alias,
696701
)
697-
cast(CachedDiscriminatorType, union).__discriminator__ = details
702+
DISCRIMINATOR_CACHE.setdefault(union, details)
698703
return details
699704

700705

tests/test_models.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
from openai._utils import PropertyInfo
1111
from openai._compat import PYDANTIC_V1, parse_obj, model_dump, model_json
12-
from openai._models import BaseModel, construct_type
12+
from openai._models import DISCRIMINATOR_CACHE, BaseModel, construct_type
1313

1414

1515
class BasicModel(BaseModel):
@@ -809,7 +809,7 @@ class B(BaseModel):
809809

810810
UnionType = cast(Any, Union[A, B])
811811

812-
assert not hasattr(UnionType, "__discriminator__")
812+
assert not DISCRIMINATOR_CACHE.get(UnionType)
813813

814814
m = construct_type(
815815
value={"type": "b", "data": "foo"}, type_=cast(Any, Annotated[UnionType, PropertyInfo(discriminator="type")])
@@ -818,7 +818,7 @@ class B(BaseModel):
818818
assert m.type == "b"
819819
assert m.data == "foo" # type: ignore[comparison-overlap]
820820

821-
discriminator = UnionType.__discriminator__
821+
discriminator = DISCRIMINATOR_CACHE.get(UnionType)
822822
assert discriminator is not None
823823

824824
m = construct_type(
@@ -830,7 +830,7 @@ class B(BaseModel):
830830

831831
# if the discriminator details object stays the same between invocations then
832832
# we hit the cache
833-
assert UnionType.__discriminator__ is discriminator
833+
assert DISCRIMINATOR_CACHE.get(UnionType) is discriminator
834834

835835

836836
@pytest.mark.skipif(PYDANTIC_V1, reason="TypeAliasType is not supported in Pydantic v1")

0 commit comments

Comments
 (0)