Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add support for including pydantic computed fields #3798

Merged
merged 6 commits into from
Apr 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
Release type: minor

Adds the ability to include pydantic computed fields when using pydantic.type decorator.

Example:
```python
class UserModel(pydantic.BaseModel):
age: int

@computed_field
@property
def next_age(self) -> int:
return self.age + 1


@strawberry.experimental.pydantic.type(
UserModel, all_fields=True, include_computed=True
)
class User:
pass
```

Will allow `nextAge` to be requested from a user entity.
16 changes: 16 additions & 0 deletions docs/integrations/pydantic.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,22 @@ class UserType:
pass
```

By default, computed fields are excluded. To also include all computed fields
pass `include_computed=True` to the decorator.

```python
import strawberry

from .models import User


@strawberry.experimental.pydantic.type(
model=User, all_fields=True, include_computed=True
)
class UserType:
pass
```

## Input types

Input types are similar to types; we can create one by using the
Expand Down
38 changes: 34 additions & 4 deletions strawberry/experimental/pydantic/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from strawberry.experimental.pydantic.exceptions import UnsupportedTypeError

if TYPE_CHECKING:
from pydantic.fields import FieldInfo
from pydantic.fields import ComputedFieldInfo, FieldInfo

IS_PYDANTIC_V2: bool = PYDANTIC_VERSION.startswith("2.")
IS_PYDANTIC_V1: bool = not IS_PYDANTIC_V2
Expand Down Expand Up @@ -128,7 +128,33 @@

return PydanticUndefined

def get_model_fields(self, model: type[BaseModel]) -> dict[str, CompatModelField]:
def get_model_computed_fields(
self, model: type[BaseModel]
) -> dict[str, CompatModelField]:
computed_field_info: dict[str, ComputedFieldInfo] = model.model_computed_fields
new_fields = {}

Check warning on line 135 in strawberry/experimental/pydantic/_compat.py

View check run for this annotation

Codecov / codecov/patch

strawberry/experimental/pydantic/_compat.py#L134-L135

Added lines #L134 - L135 were not covered by tests
# Convert it into CompatModelField
for name, field in computed_field_info.items():
new_fields[name] = CompatModelField(

Check warning on line 138 in strawberry/experimental/pydantic/_compat.py

View check run for this annotation

Codecov / codecov/patch

strawberry/experimental/pydantic/_compat.py#L138

Added line #L138 was not covered by tests
name=name,
type_=field.return_type,
outer_type_=field.return_type,
default=None,
default_factory=None,
required=False,
alias=field.alias,
# v2 doesn't have allow_none
allow_none=False,
has_alias=field is not None,
description=field.description,
_missing_type=self.PYDANTIC_MISSING_TYPE,
is_v1=False,
)
return new_fields

Check warning on line 153 in strawberry/experimental/pydantic/_compat.py

View check run for this annotation

Codecov / codecov/patch

strawberry/experimental/pydantic/_compat.py#L153

Added line #L153 was not covered by tests

def get_model_fields(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

issue (complexity): Consider extracting the shared field conversion logic into a helper function with a flag to handle computed fields differently, and then call this helper from both get_model_fields and get_model_computed_fields functions.

Consider extracting the shared conversion logic into a single helper that toggles computed behavior via a flag. This reduces duplication while keeping both functionalities intact. For example:

def _convert_field(self, name: str, field: Any, *, computed: bool = False) -> CompatModelField:
    type_val = field.return_type if computed else field.annotation
    return CompatModelField(
        name=name,
        type_=type_val,
        outer_type_=type_val,
        default=None if computed else field.default,
        default_factory=None if computed else field.default_factory,  # type: ignore
        required=False if computed else field.is_required(),
        alias=field.alias,
        allow_none=False,
        has_alias=field is not None,
        description=field.description,
        _missing_type=self.PYDANTIC_MISSING_TYPE,
        is_v1=False,
    )

Then update your methods as follows:

def get_model_fields(self, model: type[BaseModel], include_computed: bool = False
) -> dict[str, CompatModelField]:
    basic_fields = {
        name: self._convert_field(name, field)
        for name, field in model.model_fields.items()
    }
    if include_computed:
        computed_fields = {
            name: self._convert_field(name, field, computed=True)
            for name, field in model.model_computed_fields.items()
        }
        basic_fields |= computed_fields
    return basic_fields

def get_model_computed_fields(self, model: type[BaseModel]) -> dict[str, CompatModelField]:
    return {
        name: self._convert_field(name, field, computed=True)
        for name, field in model.model_computed_fields.items()
    }

This approach maintains current functionality and reduces complexity by consolidating repeated logic.

self, model: type[BaseModel], include_computed: bool = False
) -> dict[str, CompatModelField]:
field_info: dict[str, FieldInfo] = model.model_fields
new_fields = {}
# Convert it into CompatModelField
Expand All @@ -148,6 +174,8 @@
_missing_type=self.PYDANTIC_MISSING_TYPE,
is_v1=False,
)
if include_computed:
new_fields |= self.get_model_computed_fields(model)

Check warning on line 178 in strawberry/experimental/pydantic/_compat.py

View check run for this annotation

Codecov / codecov/patch

strawberry/experimental/pydantic/_compat.py#L178

Added line #L178 was not covered by tests
return new_fields

@cached_property
Expand Down Expand Up @@ -175,7 +203,10 @@
def PYDANTIC_MISSING_TYPE(self) -> Any: # noqa: N802
return dataclasses.MISSING

def get_model_fields(self, model: type[BaseModel]) -> dict[str, CompatModelField]:
def get_model_fields(
self, model: type[BaseModel], include_computed: bool = False
) -> dict[str, CompatModelField]:
"""`include_computed` is a noop for PydanticV1Compat."""
new_fields = {}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
new_fields = {}
del include_computed
new_fields = {}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to keep linters happy :)

# Convert it into CompatModelField
for name, field in model.__fields__.items(): # type: ignore[attr-defined]
Expand Down Expand Up @@ -284,7 +315,6 @@
smart_deepcopy,
)


__all__ = [
"PydanticCompat",
"get_args",
Expand Down
8 changes: 6 additions & 2 deletions strawberry/experimental/pydantic/object_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,11 +126,12 @@ def type( # noqa: PLR0915
description: Optional[str] = None,
directives: Optional[Sequence[object]] = (),
all_fields: bool = False,
include_computed: bool = False,
use_pydantic_alias: bool = True,
) -> Callable[..., builtins.type[StrawberryTypeFromPydantic[PydanticModel]]]:
def wrap(cls: Any) -> builtins.type[StrawberryTypeFromPydantic[PydanticModel]]: # noqa: PLR0915
compat = PydanticCompat.from_model(model)
model_fields = compat.get_model_fields(model)
model_fields = compat.get_model_fields(model, include_computed=include_computed)
original_fields_set = set(fields) if fields else set()

if fields:
Expand Down Expand Up @@ -171,7 +172,10 @@ def wrap(cls: Any) -> builtins.type[StrawberryTypeFromPydantic[PydanticModel]]:
raise MissingFieldsListError(cls)

ensure_all_auto_fields_in_pydantic(
model=model, auto_fields=auto_fields_set, cls_name=cls.__name__
model=model,
auto_fields=auto_fields_set,
cls_name=cls.__name__,
include_computed=include_computed,
)

wrapped = _wrap_dataclass(cls)
Expand Down
10 changes: 8 additions & 2 deletions strawberry/experimental/pydantic/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,11 +117,17 @@ def get_default_factory_for_field(


def ensure_all_auto_fields_in_pydantic(
model: type[BaseModel], auto_fields: set[str], cls_name: str
model: type[BaseModel],
auto_fields: set[str],
cls_name: str,
include_computed: bool = False,
) -> None:
compat = PydanticCompat.from_model(model)
# Raise error if user defined a strawberry.auto field not present in the model
non_existing_fields = list(auto_fields - compat.get_model_fields(model).keys())
non_existing_fields = list(
auto_fields
- compat.get_model_fields(model, include_computed=include_computed).keys()
)

if non_existing_fields:
raise AutoFieldsNotInBaseModelError(
Expand Down
72 changes: 72 additions & 0 deletions tests/experimental/pydantic/schema/test_computed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import textwrap

import pydantic
import pytest
from pydantic.version import VERSION as PYDANTIC_VERSION

import strawberry

IS_PYDANTIC_V2: bool = PYDANTIC_VERSION.startswith("2.")

if IS_PYDANTIC_V2:
from pydantic import computed_field

Check warning on line 12 in tests/experimental/pydantic/schema/test_computed.py

View check run for this annotation

Codecov / codecov/patch

tests/experimental/pydantic/schema/test_computed.py#L12

Added line #L12 was not covered by tests


@pytest.mark.skipif(
not IS_PYDANTIC_V2, reason="Requires Pydantic v2 for computed_field"
)
def test_computed_field():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we're using sets it won't be an issue, but it might be nice to add a test that has include_computed=True and also declares the computed field next_age with strawberry.auto to make sure we don't have duplication.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pydantic doesn't seem to like me declaring both a computed field and another field, with or without the strawberry.auto type -- for example for

    class UserModel(pydantic.BaseModel):
        age: int
        next_age: strawberry.auto

        @computed_field
        @property
        def next_age(self) -> int:
            return self.age + 1

I get

ValueError: you can't override a field with a computed field

If you were thinking of something different though, let me know, happy to write another test!

Copy link
Contributor

@skilkis skilkis Mar 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ahh sorry for the confusion I meant something more along the lines of:

@strawberry.experimental.pydantic.type(
    UserModel, all_fields=True, include_computed=True
)
class User:
    next_age: strawberry.auto

This will touch the include_computed flow as well as the regular field specifier flow. Currently this test should pass fine as we are using a set to dedupe the fields, the purpose would be to check if ever this behavior changes in the future.

Not the most important test though for this PR as this can occur with or without computed fields, I'll already approve the PR, thanks again for contributing this! 🚀

class UserModel(pydantic.BaseModel):
age: int

Check warning on line 20 in tests/experimental/pydantic/schema/test_computed.py

View check run for this annotation

Codecov / codecov/patch

tests/experimental/pydantic/schema/test_computed.py#L19-L20

Added lines #L19 - L20 were not covered by tests

@computed_field
@property
def next_age(self) -> int:
return self.age + 1

Check warning on line 25 in tests/experimental/pydantic/schema/test_computed.py

View check run for this annotation

Codecov / codecov/patch

tests/experimental/pydantic/schema/test_computed.py#L22-L25

Added lines #L22 - L25 were not covered by tests

@strawberry.experimental.pydantic.type(

Check warning on line 27 in tests/experimental/pydantic/schema/test_computed.py

View check run for this annotation

Codecov / codecov/patch

tests/experimental/pydantic/schema/test_computed.py#L27

Added line #L27 was not covered by tests
UserModel, all_fields=True, include_computed=True
)
class User:
pass

Check warning on line 31 in tests/experimental/pydantic/schema/test_computed.py

View check run for this annotation

Codecov / codecov/patch

tests/experimental/pydantic/schema/test_computed.py#L30-L31

Added lines #L30 - L31 were not covered by tests

@strawberry.experimental.pydantic.type(UserModel, all_fields=True)
class UserNoComputed:
pass

Check warning on line 35 in tests/experimental/pydantic/schema/test_computed.py

View check run for this annotation

Codecov / codecov/patch

tests/experimental/pydantic/schema/test_computed.py#L33-L35

Added lines #L33 - L35 were not covered by tests

@strawberry.type
class Query:
@strawberry.field
def user(self) -> User:
return User.from_pydantic(UserModel(age=1))

Check warning on line 41 in tests/experimental/pydantic/schema/test_computed.py

View check run for this annotation

Codecov / codecov/patch

tests/experimental/pydantic/schema/test_computed.py#L37-L41

Added lines #L37 - L41 were not covered by tests

@strawberry.field
def user_no_computed(self) -> UserNoComputed:
return UserNoComputed.from_pydantic(UserModel(age=1))

Check warning on line 45 in tests/experimental/pydantic/schema/test_computed.py

View check run for this annotation

Codecov / codecov/patch

tests/experimental/pydantic/schema/test_computed.py#L43-L45

Added lines #L43 - L45 were not covered by tests

schema = strawberry.Schema(query=Query)

Check warning on line 47 in tests/experimental/pydantic/schema/test_computed.py

View check run for this annotation

Codecov / codecov/patch

tests/experimental/pydantic/schema/test_computed.py#L47

Added line #L47 was not covered by tests

expected_schema = """

Check warning on line 49 in tests/experimental/pydantic/schema/test_computed.py

View check run for this annotation

Codecov / codecov/patch

tests/experimental/pydantic/schema/test_computed.py#L49

Added line #L49 was not covered by tests
type Query {
user: User!
userNoComputed: UserNoComputed!
}

type User {
age: Int!
nextAge: Int!
}

type UserNoComputed {
age: Int!
}
"""

assert str(schema) == textwrap.dedent(expected_schema).strip()

Check warning on line 65 in tests/experimental/pydantic/schema/test_computed.py

View check run for this annotation

Codecov / codecov/patch

tests/experimental/pydantic/schema/test_computed.py#L65

Added line #L65 was not covered by tests

query = "{ user { age nextAge } }"

Check warning on line 67 in tests/experimental/pydantic/schema/test_computed.py

View check run for this annotation

Codecov / codecov/patch

tests/experimental/pydantic/schema/test_computed.py#L67

Added line #L67 was not covered by tests

result = schema.execute_sync(query)
assert not result.errors
assert result.data["user"]["age"] == 1
assert result.data["user"]["nextAge"] == 2

Check warning on line 72 in tests/experimental/pydantic/schema/test_computed.py

View check run for this annotation

Codecov / codecov/patch

tests/experimental/pydantic/schema/test_computed.py#L69-L72

Added lines #L69 - L72 were not covered by tests
Comment on lines +69 to +72
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (testing): Add test verifying the behavior when include_computed is False

It's important to explicitly test the negative case where include_computed=False. Verify that the computed field is not included in the schema and is not resolvable.

Suggested implementation:

    assert result.data["user"]["age"] == 1
    assert result.data["user"]["nextAge"] == 2

    # New test for behavior when include_computed is False
    def test_computed_not_included_when_disabled():
        # Assuming build_schema accepts an include_computed flag.
        schema = build_schema(include_computed=False)

        expected_schema = """
        type UserNoComputed {
          age: Int!
        }
        """
        # Check that the printed schema does not include computed fields.
        assert str(schema) == textwrap.dedent(expected_schema).strip()

        query = "{ user { age nextAge } }"

        # Execute a query that attempts to fetch the computed field.
        result = schema.execute_sync(query)

        # We expect an error since "nextAge" should not be resolvable.
        assert result.errors
        # Optionally, verify the error message indicates nextAge is not a valid field.
        assert "nextAge" in result.errors[0].message

Note:

  1. Ensure that the function build_schema exists and accepts the include_computed flag.
  2. Adjust the error message check as needed to match the actual error output from your GraphQL library.
  3. If you prefer, you could wrap the new test in a dedicated test class or module that groups computed field tests.

Loading