Skip to content

Commit 80d2b66

Browse files
committed
Allow Discriminator for discriminator in Field
1 parent 5611bda commit 80d2b66

File tree

3 files changed

+44
-7
lines changed

3 files changed

+44
-7
lines changed

sqlmodel/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
__version__ = "0.0.31"
22

33
# Re-export from SQLAlchemy
4+
from pydantic import Discriminator as Discriminator
5+
from pydantic import Tag as Tag
6+
7+
# Re-export from Pydantic
48
from sqlalchemy.engine import create_engine as create_engine
59
from sqlalchemy.engine import create_mock_engine as create_mock_engine
610
from sqlalchemy.engine import engine_from_config as engine_from_config

sqlmodel/main.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
overload,
2323
)
2424

25-
from pydantic import BaseModel, EmailStr
25+
from pydantic import BaseModel, Discriminator, EmailStr
2626
from pydantic.fields import FieldInfo as PydanticFieldInfo
2727
from sqlalchemy import (
2828
Boolean,
@@ -228,7 +228,7 @@ def Field(
228228
max_length: Optional[int] = None,
229229
allow_mutation: bool = True,
230230
regex: Optional[str] = None,
231-
discriminator: Optional[str] = None,
231+
discriminator: Union[str, Discriminator, None] = None,
232232
repr: bool = True,
233233
primary_key: Union[bool, UndefinedType] = Undefined,
234234
foreign_key: Any = Undefined,
@@ -271,7 +271,7 @@ def Field(
271271
max_length: Optional[int] = None,
272272
allow_mutation: bool = True,
273273
regex: Optional[str] = None,
274-
discriminator: Optional[str] = None,
274+
discriminator: Union[str, Discriminator, None] = None,
275275
repr: bool = True,
276276
primary_key: Union[bool, UndefinedType] = Undefined,
277277
foreign_key: str,
@@ -323,7 +323,7 @@ def Field(
323323
max_length: Optional[int] = None,
324324
allow_mutation: bool = True,
325325
regex: Optional[str] = None,
326-
discriminator: Optional[str] = None,
326+
discriminator: Union[str, Discriminator, None] = None,
327327
repr: bool = True,
328328
sa_column: Union[Column[Any], UndefinedType] = Undefined,
329329
schema_extra: Optional[dict[str, Any]] = None,
@@ -356,7 +356,7 @@ def Field(
356356
max_length: Optional[int] = None,
357357
allow_mutation: bool = True,
358358
regex: Optional[str] = None,
359-
discriminator: Optional[str] = None,
359+
discriminator: Union[str, Discriminator, None] = None,
360360
repr: bool = True,
361361
primary_key: Union[bool, UndefinedType] = Undefined,
362362
foreign_key: Any = Undefined,

tests/test_pydantic/test_field.py

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
from decimal import Decimal
2-
from typing import Literal, Optional, Union
2+
from typing import Annotated, Any, Literal, Optional, Union
33

44
import pytest
55
from pydantic import ValidationError
6-
from sqlmodel import Field, SQLModel
6+
from sqlmodel import Discriminator, Field, SQLModel, Tag
77

88

99
def test_decimal():
@@ -47,6 +47,39 @@ class Model(SQLModel):
4747
Model(pet={"pet_type": "dog"}, n=1) # type: ignore[arg-type]
4848

4949

50+
def test_discriminator_callable():
51+
# Example adapted from
52+
# [Pydantic docs](https://docs.pydantic.dev/latest/concepts/unions/#discriminated-unions-with-callable-discriminator):
53+
54+
class Pie(SQLModel):
55+
pass
56+
57+
class ApplePie(Pie):
58+
fruit: Literal["apple"] = "apple"
59+
60+
class PumpkinPie(Pie):
61+
filling: Literal["pumpkin"] = "pumpkin"
62+
63+
def get_discriminator_value(v: Any) -> str:
64+
if isinstance(v, dict):
65+
return v.get("fruit", v.get("filling"))
66+
return getattr(v, "fruit", getattr(v, "filling", None))
67+
68+
class ThanksgivingDinner(SQLModel):
69+
dessert: Union[
70+
Annotated[ApplePie, Tag("apple")],
71+
Annotated[PumpkinPie, Tag("pumpkin")],
72+
] = Field(
73+
discriminator=Discriminator(get_discriminator_value),
74+
)
75+
76+
apple_pie = ThanksgivingDinner.model_validate({"dessert": {"fruit": "apple"}})
77+
assert isinstance(apple_pie.dessert, ApplePie)
78+
79+
pumpkin_pie = ThanksgivingDinner.model_validate({"dessert": {"filling": "pumpkin"}})
80+
assert isinstance(pumpkin_pie.dessert, PumpkinPie)
81+
82+
5083
def test_repr():
5184
class Model(SQLModel):
5285
id: Optional[int] = Field(primary_key=True)

0 commit comments

Comments
 (0)