Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(schema): semantic nullability draft #3722

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion strawberry/annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
StrawberryList,
StrawberryObjectDefinition,
StrawberryOptional,
StrawberryRequired,
StrawberryTypeVar,
get_object_definition,
has_object_definition,
Expand All @@ -33,6 +34,7 @@
from strawberry.types.lazy_type import LazyType
from strawberry.types.private import is_private
from strawberry.types.scalar import ScalarDefinition
from strawberry.types.strict_non_null import STRAWBERRY_REQUIRED_TOKEN
from strawberry.types.unset import UNSET
from strawberry.utils.typing import eval_type, is_generic, is_type_var

Expand All @@ -41,7 +43,6 @@
from strawberry.types.field import StrawberryField
from strawberry.types.union import StrawberryUnion


ASYNC_TYPES = (
abc.AsyncGenerator,
abc.AsyncIterable,
Expand Down Expand Up @@ -160,6 +161,8 @@ def _resolve(self) -> Union[StrawberryType, type]:
return self.create_enum(evaled_type)
elif self._is_optional(evaled_type, args):
return self.create_optional(evaled_type)
elif self._is_required(evaled_type, args):
return self.create_required(evaled_type)
elif self._is_union(evaled_type, args):
return self.create_union(evaled_type, args)
elif is_type_var(evaled_type) or evaled_type is Self:
Expand Down Expand Up @@ -221,6 +224,14 @@ def create_optional(self, evaled_type: Any) -> StrawberryOptional:

return StrawberryOptional(of_type)

def create_required(self, evaled_type: Any) -> StrawberryRequired:
of_type = StrawberryAnnotation(
annotation=evaled_type,
namespace=self.namespace,
).resolve()

return StrawberryRequired(of_type)

def create_type_var(self, evaled_type: TypeVar) -> StrawberryTypeVar:
return StrawberryTypeVar(evaled_type)

Expand Down Expand Up @@ -300,6 +311,11 @@ def _is_optional(cls, annotation: Any, args: List[Any]) -> bool:
# A Union to be optional needs to have at least one None type
return any(x is type(None) for x in types)

@classmethod
def _is_required(cls, annotation: Any, args: List[Any]) -> bool:
"""Returns True if the annotation is Required[SomeType]."""
return STRAWBERRY_REQUIRED_TOKEN in args

@classmethod
def _is_list(cls, annotation: Any) -> bool:
"""Returns True if annotation is a List."""
Expand Down
8 changes: 8 additions & 0 deletions strawberry/exceptions/semantic_nullability.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from graphql.error.graphql_error import GraphQLError


class InvalidNullReturnError(GraphQLError):
def __init__(self) -> None:
super().__init__(
message="Expected non-null return type for semanticNonNull field, but got null",
)
33 changes: 33 additions & 0 deletions strawberry/field_extensions/semantic_nullability.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any

from strawberry.exceptions.semantic_nullability import InvalidNullReturnError
from strawberry.extensions import FieldExtension

if TYPE_CHECKING:
from strawberry.extensions.field_extension import (
AsyncExtensionResolver,
SyncExtensionResolver,
)
from strawberry.types import Info


class SemanticNonNullExtension(FieldExtension):
def resolve(
self, next_: SyncExtensionResolver, source: Any, info: Info, **kwargs: Any
) -> Any:
resolved = next_(source, info, **kwargs)
if resolved is not None:
return resolved
else:
raise InvalidNullReturnError()

async def resolve_async(
self, next_: AsyncExtensionResolver, source: Any, info: Info, **kwargs: Any
) -> Any:
resolved = await next_(source, info, **kwargs)
if resolved is not None:
return resolved
else:
raise InvalidNullReturnError()
1 change: 1 addition & 0 deletions strawberry/schema/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ class StrawberryConfig:
default_resolver: Callable[[Any, str], object] = getattr
relay_max_results: int = 100
disable_field_suggestions: bool = False
semantic_nullability_beta: bool = False
info_class: type[Info] = Info

def __post_init__(
Expand Down
4 changes: 3 additions & 1 deletion strawberry/schema/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,9 @@ class Query:
if has_object_definition(type_):
if type_.__strawberry_definition__.is_graphql_generic:
type_ = StrawberryAnnotation(type_).resolve() # noqa: PLW2901
graphql_type = self.schema_converter.from_maybe_optional(type_)
graphql_type = self.schema_converter.from_maybe_optional_or_required(
type_
)
if isinstance(graphql_type, GraphQLNonNull):
graphql_type = graphql_type.of_type
if not isinstance(graphql_type, GraphQLNamedType):
Expand Down
52 changes: 43 additions & 9 deletions strawberry/schema/schema_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
StrawberryList,
StrawberryObjectDefinition,
StrawberryOptional,
StrawberryRequired,
StrawberryType,
get_object_definition,
has_object_definition,
Expand All @@ -71,7 +72,9 @@
from strawberry.types.unset import UNSET
from strawberry.utils.await_maybe import await_maybe

from ..exceptions.semantic_nullability import InvalidNullReturnError
from ..extensions.field_extension import build_field_extension_resolvers
from ..types.semantic_non_null import SemanticNonNull
from . import compat
from .types.concrete_type import ConcreteType

Expand Down Expand Up @@ -252,7 +255,7 @@ def __init__(

def from_argument(self, argument: StrawberryArgument) -> GraphQLArgument:
argument_type = cast(
"GraphQLInputType", self.from_maybe_optional(argument.type)
"GraphQLInputType", self.from_maybe_optional_or_required(argument.type)
)
default_value = Undefined if argument.default is UNSET else argument.default

Expand Down Expand Up @@ -367,6 +370,9 @@ def from_schema_directive(self, cls: Type) -> GraphQLDirective:
},
)

def create_semantic_non_null_directive(self) -> SemanticNonNull:
return SemanticNonNull()

def from_field(
self,
field: StrawberryField,
Expand All @@ -378,12 +384,30 @@ def from_field(
resolver = self.from_resolver(field)
field_type = cast(
"GraphQLOutputType",
self.from_maybe_optional(
self.from_maybe_optional_or_required(
field.resolve_type(type_definition=type_definition)
),
)
subscribe = None

maybe_wrapped_resolver = resolver
if self.config.semantic_nullability_beta and getattr(
field_type, "strawberry_semantic_required_non_null", False
):
# TODO: Mid-Term, we'd want something like this:
# field.extensions.append(SemanticNonNullExtension()) #noqa
# However, we need to refactor from_resolver into two parts:
# - field type modifications (call extension.apply)
# - resolver building, so we can add the new semanticnonnullextension resolver later
field.directives.append(self.create_semantic_non_null_directive())

def resolver_wrapper(*args: Any, **kwargs: Any): # noqa
result = resolver(*args, **kwargs)
if result is None:
raise InvalidNullReturnError()

maybe_wrapped_resolver = resolver_wrapper

if field.is_subscription:
subscribe = resolver
resolver = lambda event, *_, **__: event # noqa: E731
Expand All @@ -396,7 +420,7 @@ def from_field(
return GraphQLField(
type_=field_type,
args=graphql_arguments,
resolve=resolver,
resolve=maybe_wrapped_resolver,
subscribe=subscribe,
description=field.description,
deprecation_reason=field.deprecation_reason,
Expand All @@ -413,7 +437,7 @@ def from_input_field(
) -> GraphQLInputField:
field_type = cast(
"GraphQLInputType",
self.from_maybe_optional(
self.from_maybe_optional_or_required(
field.resolve_type(type_definition=type_definition)
),
)
Expand Down Expand Up @@ -586,7 +610,7 @@ def resolve_type(
return graphql_interface

def from_list(self, type_: StrawberryList) -> GraphQLList:
of_type = self.from_maybe_optional(type_.of_type)
of_type = self.from_maybe_optional_or_required(type_.of_type)

return GraphQLList(of_type)

Expand Down Expand Up @@ -804,16 +828,26 @@ def from_scalar(self, scalar: Type) -> GraphQLScalarType:

return implementation

def from_maybe_optional(
def from_maybe_optional_or_required(
self, type_: Union[StrawberryType, type]
) -> Union[GraphQLNullableType, GraphQLNonNull]:
# ) -> Union[GraphQLNullableType, GraphQLNonNull, GraphQLSemanticNonNull]: TODO in the future this will include graphql semantic non null
NoneType = type(None)
if type_ is None or type_ is NoneType:
return self.from_type(type_)
graphql_core_type = self.from_type(type_)
elif isinstance(type_, StrawberryOptional):
return self.from_type(type_.of_type)
graphql_core_type = self.from_type(type_.of_type)
elif isinstance(type_, StrawberryRequired):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

issue (bug_risk): Missing return statement in StrawberryRequired branch

The code sets strawberry_semantic_required_non_null but doesn't return graphql_core_type, causing it to fall through to the else clause. Add 'return graphql_core_type' after setting the attribute.

graphql_core_type = GraphQLNonNull(self.from_type(type_.of_type))
graphql_core_type.strawberry_semantic_required_non_null = False
else:
return GraphQLNonNull(self.from_type(type_))
if self.config.semantic_nullability_beta:
graphql_core_type = self.from_type(type_)
else:
graphql_core_type = GraphQLNonNull(self.from_type(type_))
graphql_core_type.strawberry_semantic_required_non_null = True

return graphql_core_type

def from_type(self, type_: Union[StrawberryType, type]) -> GraphQLNullableType:
if compat.is_graphql_generic(type_):
Expand Down
3 changes: 3 additions & 0 deletions strawberry/types/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,9 @@ class StrawberryList(StrawberryContainer): ...
class StrawberryOptional(StrawberryContainer): ...


class StrawberryRequired(StrawberryContainer): ...


class StrawberryTypeVar(StrawberryType):
def __init__(self, type_var: TypeVar) -> None:
self.type_var = type_var
Expand Down
12 changes: 12 additions & 0 deletions strawberry/types/semantic_non_null.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from typing import Optional

from strawberry.schema_directive import Location, schema_directive


@schema_directive(
locations=[Location.FIELD_DEFINITION],
name="semanticNonNull",
print_definition=True,
)
class SemanticNonNull:
level: Optional[int] = None
8 changes: 8 additions & 0 deletions strawberry/types/strict_non_null.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from typing import TypeVar
from typing_extensions import Annotated

STRAWBERRY_REQUIRED_TOKEN = "strawberry_required"

T = TypeVar("T")

NonNull = Annotated[T, STRAWBERRY_REQUIRED_TOKEN]
60 changes: 60 additions & 0 deletions tests/schema/test_semantic_nullability.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import textwrap
from typing import List, Optional

import pytest

import strawberry
from strawberry.exceptions.semantic_nullability import InvalidNullReturnError
from strawberry.schema.config import StrawberryConfig
from strawberry.types.strict_non_null import NonNull


def test_semantic_nullability_enabled():
@strawberry.type
class Product:
upc: str
name: str
price: NonNull[int]
weight: Optional[int]

@strawberry.type
class Query:
@strawberry.field
def top_products(self, first: int) -> List[Product]:
return []

schema = strawberry.Schema(
query=Query, config=StrawberryConfig(semantic_nullability_beta=True)
)

expected_sdl = textwrap.dedent("""
directive @semanticNonNull(level: Int = null) on FIELD_DEFINITION | OBJECT | INTERFACE | SCALAR | ENUM

type Product {
upc: String @semanticNonNull(level: null)
name: String @semanticNonNull(level: null)
price: Int!
weight: Int @semanticNonNull(level: null)
}

type Query {
topProducts(first: Int): [Product] @semanticNonNull(level: null)
}
""").strip()

assert str(schema) == expected_sdl


def test_semantic_nullability_error_on_null():
@strawberry.type
class Query:
@strawberry.field
def greeting(self) -> str:
return None

schema = strawberry.Schema(
query=Query, config=StrawberryConfig(semantic_nullability_beta=True)
)

with pytest.raises(InvalidNullReturnError):
result = schema.execute_sync("{ greeting }")
Loading