Skip to content

Commit

Permalink
Fix introspection issue when using pickle (#173)
Browse files Browse the repository at this point in the history
  • Loading branch information
Cito committed Nov 2, 2022
1 parent b0c6973 commit 0a01ce8
Show file tree
Hide file tree
Showing 10 changed files with 243 additions and 112 deletions.
40 changes: 17 additions & 23 deletions src/graphql/type/definition.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations # Python < 3.10

import warnings
from enum import Enum
from typing import (
TYPE_CHECKING,
Expand Down Expand Up @@ -232,6 +231,23 @@ class GraphQLNamedType(GraphQLType):
ast_node: Optional[TypeDefinitionNode]
extension_ast_nodes: Tuple[TypeExtensionNode, ...]

reserved_types: Dict[str, "GraphQLNamedType"] = {}

def __new__(cls, name: str, *_args: Any, **_kwargs: Any) -> "GraphQLNamedType":
if name in cls.reserved_types:
raise TypeError(f"Redefinition of reserved type {name!r}")
return super().__new__(cls)

def __reduce__(self) -> Tuple[Callable, Tuple]:
return self._get_instance, (self.name, tuple(self.to_kwargs().items()))

@classmethod
def _get_instance(cls, name: str, args: Tuple) -> "GraphQLNamedType":
try:
return cls.reserved_types[name]
except KeyError:
return cls(**dict(args))

def __init__(
self,
name: str,
Expand Down Expand Up @@ -348,28 +364,6 @@ def serialize_odd(value: Any) -> int:
ast_node: Optional[ScalarTypeDefinitionNode]
extension_ast_nodes: Tuple[ScalarTypeExtensionNode, ...]

specified_types: Mapping[str, GraphQLScalarType] = {}

def __new__(cls, name: str, *args: Any, **kwargs: Any) -> "GraphQLScalarType":
if name and name in cls.specified_types:
warnings.warn(
f"Redefinition of specified scalar type {name!r}",
RuntimeWarning,
stacklevel=2,
)
return cls.specified_types[name]
return super().__new__(cls)

def __reduce__(self) -> Tuple[Callable, Tuple]:
return self._get_instance, (self.name, tuple(self.to_kwargs().items()))

@classmethod
def _get_instance(cls, name: str, args: Tuple) -> "GraphQLScalarType":
try:
return cls.specified_types[name]
except KeyError:
return cls(**dict(args))

def __init__(
self,
name: str,
Expand Down
4 changes: 4 additions & 0 deletions src/graphql/type/introspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -684,3 +684,7 @@ def type_name(_source, info, **_args):
def is_introspection_type(type_: GraphQLNamedType) -> bool:
"""Check whether the given named GraphQL type is an introspection type."""
return type_.name in introspection_types


# register the introspection types to avoid redefinition
GraphQLNamedType.reserved_types.update(introspection_types)
4 changes: 2 additions & 2 deletions src/graphql/type/scalars.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,5 +322,5 @@ def is_specified_scalar_type(type_: GraphQLNamedType) -> bool:
return type_.name in specified_scalar_types


# store the specified instances as class attribute to avoid redefinition
GraphQLScalarType.specified_types = specified_scalar_types
# register the scalar types to avoid redefinition
GraphQLNamedType.reserved_types.update(specified_scalar_types)
46 changes: 27 additions & 19 deletions src/graphql/utilities/build_client_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def build_scalar_def(
) -> GraphQLScalarType:
name = scalar_introspection["name"]
try:
return GraphQLScalarType.specified_types[name]
return cast(GraphQLScalarType, GraphQLScalarType.reserved_types[name])
except KeyError:
return GraphQLScalarType(
name=name,
Expand Down Expand Up @@ -165,12 +165,16 @@ def build_implementations_list(
def build_object_def(
object_introspection: IntrospectionObjectType,
) -> GraphQLObjectType:
return GraphQLObjectType(
name=object_introspection["name"],
description=object_introspection.get("description"),
interfaces=lambda: build_implementations_list(object_introspection),
fields=lambda: build_field_def_map(object_introspection),
)
name = object_introspection["name"]
try:
return cast(GraphQLObjectType, GraphQLObjectType.reserved_types[name])
except KeyError:
return GraphQLObjectType(
name=name,
description=object_introspection.get("description"),
interfaces=lambda: build_implementations_list(object_introspection),
fields=lambda: build_field_def_map(object_introspection),
)

def build_interface_def(
interface_introspection: IntrospectionInterfaceType,
Expand Down Expand Up @@ -204,18 +208,22 @@ def build_enum_def(enum_introspection: IntrospectionEnumType) -> GraphQLEnumType
"Introspection result missing enumValues:"
f" {inspect(enum_introspection)}."
)
return GraphQLEnumType(
name=enum_introspection["name"],
description=enum_introspection.get("description"),
values={
value_introspect["name"]: GraphQLEnumValue(
value=value_introspect["name"],
description=value_introspect.get("description"),
deprecation_reason=value_introspect.get("deprecationReason"),
)
for value_introspect in enum_introspection["enumValues"]
},
)
name = enum_introspection["name"]
try:
return cast(GraphQLEnumType, GraphQLEnumType.reserved_types[name])
except KeyError:
return GraphQLEnumType(
name=name,
description=enum_introspection.get("description"),
values={
value_introspect["name"]: GraphQLEnumValue(
value=value_introspect["name"],
description=value_introspect.get("description"),
deprecation_reason=value_introspect.get("deprecationReason"),
)
for value_introspect in enum_introspection["enumValues"]
},
)

def build_input_object_def(
input_object_introspection: IntrospectionInputObjectType,
Expand Down
9 changes: 9 additions & 0 deletions tests/type/test_definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
GraphQLScalarType,
GraphQLString,
GraphQLUnionType,
introspection_types,
)


Expand Down Expand Up @@ -1915,3 +1916,11 @@ def fields_have_repr():
repr(GraphQLField(GraphQLList(GraphQLInt)))
== "<GraphQLField <GraphQLList <GraphQLScalarType 'Int'>>>"
)


def describe_type_system_introspection_types():
def cannot_redefine_introspection_types():
for name, introspection_type in introspection_types.items():
assert introspection_type.name == name
with raises(TypeError, match=f"Redefinition of reserved type '{name}'"):
introspection_type.__class__(**introspection_type.to_kwargs())
37 changes: 11 additions & 26 deletions tests/type/test_scalars.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from math import inf, nan, pi
from typing import Any

from pytest import raises, warns
from pytest import raises

from graphql.error import GraphQLError
from graphql.language import parse_value as parse_value_to_ast
Expand Down Expand Up @@ -175,11 +175,8 @@ def serializes():
assert str(exc_info.value) == "Int cannot represent non-integer value: [5]"

def cannot_be_redefined():
with warns(
RuntimeWarning, match="Redefinition of specified scalar type 'Int'"
):
redefined_int = GraphQLScalarType(name="Int")
assert redefined_int == GraphQLInt
with raises(TypeError, match="Redefinition of reserved type 'Int'"):
GraphQLScalarType(name="Int")

def pickles():
assert pickle.loads(pickle.dumps(GraphQLInt)) is GraphQLInt
Expand Down Expand Up @@ -308,11 +305,8 @@ def serializes():
)

def cannot_be_redefined():
with warns(
RuntimeWarning, match="Redefinition of specified scalar type 'Float'"
):
redefined_float = GraphQLScalarType(name="Float")
assert redefined_float == GraphQLFloat
with raises(TypeError, match="Redefinition of reserved type 'Float'"):
GraphQLScalarType(name="Float")

def pickles():
assert pickle.loads(pickle.dumps(GraphQLFloat)) is GraphQLFloat
Expand Down Expand Up @@ -424,11 +418,8 @@ def __str__(self):
)

def cannot_be_redefined():
with warns(
RuntimeWarning, match="Redefinition of specified scalar type 'String'"
):
redefined_string = GraphQLScalarType(name="String")
assert redefined_string == GraphQLString
with raises(TypeError, match="Redefinition of reserved type 'String'"):
GraphQLScalarType(name="String")

def pickles():
assert pickle.loads(pickle.dumps(GraphQLString)) is GraphQLString
Expand Down Expand Up @@ -576,11 +567,8 @@ def serializes():
)

def cannot_be_redefined():
with warns(
RuntimeWarning, match="Redefinition of specified scalar type 'Boolean'"
):
redefined_boolean = GraphQLScalarType(name="Boolean")
assert redefined_boolean == GraphQLBoolean
with raises(TypeError, match="Redefinition of reserved type 'Boolean'"):
GraphQLScalarType(name="Boolean")

def pickles():
assert pickle.loads(pickle.dumps(GraphQLBoolean)) is GraphQLBoolean
Expand Down Expand Up @@ -707,11 +695,8 @@ def __str__(self):
assert str(exc_info.value) == "ID cannot represent value: ['abc']"

def cannot_be_redefined():
with warns(
RuntimeWarning, match="Redefinition of specified scalar type 'ID'"
):
redefined_id = GraphQLScalarType(name="ID")
assert redefined_id == GraphQLID
with raises(TypeError, match="Redefinition of reserved type 'ID'"):
GraphQLScalarType(name="ID")

def pickles():
assert pickle.loads(pickle.dumps(GraphQLID)) is GraphQLID
7 changes: 4 additions & 3 deletions tests/type/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
GraphQLInt,
GraphQLInterfaceType,
GraphQLList,
GraphQLNamedType,
GraphQLObjectType,
GraphQLScalarType,
GraphQLSchema,
Expand Down Expand Up @@ -332,14 +333,14 @@ def check_that_query_mutation_and_subscription_are_graphql_types():
def describe_a_schema_must_contain_uniquely_named_types():
def rejects_a_schema_which_redefines_a_built_in_type():
# temporarily allow redefinition of the String scalar type
specified_types = GraphQLScalarType.specified_types
GraphQLScalarType.specified_types = {}
reserved_types = GraphQLNamedType.reserved_types
GraphQLScalarType.reserved_types = {}
try:
# create a redefined String scalar type
FakeString = GraphQLScalarType("String")
finally:
# protect from redefinition again
GraphQLScalarType.specified_types = specified_types
GraphQLScalarType.reserved_types = reserved_types

QueryType = GraphQLObjectType(
"Query",
Expand Down
Loading

0 comments on commit 0a01ce8

Please sign in to comment.