diff --git a/src/graphql/type/definition.py b/src/graphql/type/definition.py index e6ef8d41..d9fe289d 100644 --- a/src/graphql/type/definition.py +++ b/src/graphql/type/definition.py @@ -1,6 +1,5 @@ from __future__ import annotations # Python < 3.10 -import warnings from enum import Enum from typing import ( TYPE_CHECKING, @@ -232,6 +231,23 @@ class GraphQLNamedType(GraphQLType): ast_node: Optional[TypeDefinitionNode] extension_ast_nodes: Tuple[TypeExtensionNode, ...] + reserved_types: Dict[str, "GraphQLNamedType"] = {} + + def __new__(cls, name: str, *_args: Any, **_kwargs: Any) -> "GraphQLNamedType": + if name in cls.reserved_types: + raise TypeError(f"Redefinition of reserved type {name!r}") + return super().__new__(cls) + + def __reduce__(self) -> Tuple[Callable, Tuple]: + return self._get_instance, (self.name, tuple(self.to_kwargs().items())) + + @classmethod + def _get_instance(cls, name: str, args: Tuple) -> "GraphQLNamedType": + try: + return cls.reserved_types[name] + except KeyError: + return cls(**dict(args)) + def __init__( self, name: str, @@ -348,28 +364,6 @@ def serialize_odd(value: Any) -> int: ast_node: Optional[ScalarTypeDefinitionNode] extension_ast_nodes: Tuple[ScalarTypeExtensionNode, ...] - specified_types: Mapping[str, GraphQLScalarType] = {} - - def __new__(cls, name: str, *args: Any, **kwargs: Any) -> "GraphQLScalarType": - if name and name in cls.specified_types: - warnings.warn( - f"Redefinition of specified scalar type {name!r}", - RuntimeWarning, - stacklevel=2, - ) - return cls.specified_types[name] - return super().__new__(cls) - - def __reduce__(self) -> Tuple[Callable, Tuple]: - return self._get_instance, (self.name, tuple(self.to_kwargs().items())) - - @classmethod - def _get_instance(cls, name: str, args: Tuple) -> "GraphQLScalarType": - try: - return cls.specified_types[name] - except KeyError: - return cls(**dict(args)) - def __init__( self, name: str, diff --git a/src/graphql/type/introspection.py b/src/graphql/type/introspection.py index 0121f7bc..17922d21 100644 --- a/src/graphql/type/introspection.py +++ b/src/graphql/type/introspection.py @@ -684,3 +684,7 @@ def type_name(_source, info, **_args): def is_introspection_type(type_: GraphQLNamedType) -> bool: """Check whether the given named GraphQL type is an introspection type.""" return type_.name in introspection_types + + +# register the introspection types to avoid redefinition +GraphQLNamedType.reserved_types.update(introspection_types) diff --git a/src/graphql/type/scalars.py b/src/graphql/type/scalars.py index 270bca75..3f7263c1 100644 --- a/src/graphql/type/scalars.py +++ b/src/graphql/type/scalars.py @@ -322,5 +322,5 @@ def is_specified_scalar_type(type_: GraphQLNamedType) -> bool: return type_.name in specified_scalar_types -# store the specified instances as class attribute to avoid redefinition -GraphQLScalarType.specified_types = specified_scalar_types +# register the scalar types to avoid redefinition +GraphQLNamedType.reserved_types.update(specified_scalar_types) diff --git a/src/graphql/utilities/build_client_schema.py b/src/graphql/utilities/build_client_schema.py index 30974c3b..1f6694b1 100644 --- a/src/graphql/utilities/build_client_schema.py +++ b/src/graphql/utilities/build_client_schema.py @@ -136,7 +136,7 @@ def build_scalar_def( ) -> GraphQLScalarType: name = scalar_introspection["name"] try: - return GraphQLScalarType.specified_types[name] + return cast(GraphQLScalarType, GraphQLScalarType.reserved_types[name]) except KeyError: return GraphQLScalarType( name=name, @@ -165,12 +165,16 @@ def build_implementations_list( def build_object_def( object_introspection: IntrospectionObjectType, ) -> GraphQLObjectType: - return GraphQLObjectType( - name=object_introspection["name"], - description=object_introspection.get("description"), - interfaces=lambda: build_implementations_list(object_introspection), - fields=lambda: build_field_def_map(object_introspection), - ) + name = object_introspection["name"] + try: + return cast(GraphQLObjectType, GraphQLObjectType.reserved_types[name]) + except KeyError: + return GraphQLObjectType( + name=name, + description=object_introspection.get("description"), + interfaces=lambda: build_implementations_list(object_introspection), + fields=lambda: build_field_def_map(object_introspection), + ) def build_interface_def( interface_introspection: IntrospectionInterfaceType, @@ -204,18 +208,22 @@ def build_enum_def(enum_introspection: IntrospectionEnumType) -> GraphQLEnumType "Introspection result missing enumValues:" f" {inspect(enum_introspection)}." ) - return GraphQLEnumType( - name=enum_introspection["name"], - description=enum_introspection.get("description"), - values={ - value_introspect["name"]: GraphQLEnumValue( - value=value_introspect["name"], - description=value_introspect.get("description"), - deprecation_reason=value_introspect.get("deprecationReason"), - ) - for value_introspect in enum_introspection["enumValues"] - }, - ) + name = enum_introspection["name"] + try: + return cast(GraphQLEnumType, GraphQLEnumType.reserved_types[name]) + except KeyError: + return GraphQLEnumType( + name=name, + description=enum_introspection.get("description"), + values={ + value_introspect["name"]: GraphQLEnumValue( + value=value_introspect["name"], + description=value_introspect.get("description"), + deprecation_reason=value_introspect.get("deprecationReason"), + ) + for value_introspect in enum_introspection["enumValues"] + }, + ) def build_input_object_def( input_object_introspection: IntrospectionInputObjectType, diff --git a/tests/type/test_definition.py b/tests/type/test_definition.py index 28df74da..24973086 100644 --- a/tests/type/test_definition.py +++ b/tests/type/test_definition.py @@ -44,6 +44,7 @@ GraphQLScalarType, GraphQLString, GraphQLUnionType, + introspection_types, ) @@ -1915,3 +1916,11 @@ def fields_have_repr(): repr(GraphQLField(GraphQLList(GraphQLInt))) == ">>" ) + + +def describe_type_system_introspection_types(): + def cannot_redefine_introspection_types(): + for name, introspection_type in introspection_types.items(): + assert introspection_type.name == name + with raises(TypeError, match=f"Redefinition of reserved type '{name}'"): + introspection_type.__class__(**introspection_type.to_kwargs()) diff --git a/tests/type/test_scalars.py b/tests/type/test_scalars.py index 8b148e73..f2a45a67 100644 --- a/tests/type/test_scalars.py +++ b/tests/type/test_scalars.py @@ -2,7 +2,7 @@ from math import inf, nan, pi from typing import Any -from pytest import raises, warns +from pytest import raises from graphql.error import GraphQLError from graphql.language import parse_value as parse_value_to_ast @@ -175,11 +175,8 @@ def serializes(): assert str(exc_info.value) == "Int cannot represent non-integer value: [5]" def cannot_be_redefined(): - with warns( - RuntimeWarning, match="Redefinition of specified scalar type 'Int'" - ): - redefined_int = GraphQLScalarType(name="Int") - assert redefined_int == GraphQLInt + with raises(TypeError, match="Redefinition of reserved type 'Int'"): + GraphQLScalarType(name="Int") def pickles(): assert pickle.loads(pickle.dumps(GraphQLInt)) is GraphQLInt @@ -308,11 +305,8 @@ def serializes(): ) def cannot_be_redefined(): - with warns( - RuntimeWarning, match="Redefinition of specified scalar type 'Float'" - ): - redefined_float = GraphQLScalarType(name="Float") - assert redefined_float == GraphQLFloat + with raises(TypeError, match="Redefinition of reserved type 'Float'"): + GraphQLScalarType(name="Float") def pickles(): assert pickle.loads(pickle.dumps(GraphQLFloat)) is GraphQLFloat @@ -424,11 +418,8 @@ def __str__(self): ) def cannot_be_redefined(): - with warns( - RuntimeWarning, match="Redefinition of specified scalar type 'String'" - ): - redefined_string = GraphQLScalarType(name="String") - assert redefined_string == GraphQLString + with raises(TypeError, match="Redefinition of reserved type 'String'"): + GraphQLScalarType(name="String") def pickles(): assert pickle.loads(pickle.dumps(GraphQLString)) is GraphQLString @@ -576,11 +567,8 @@ def serializes(): ) def cannot_be_redefined(): - with warns( - RuntimeWarning, match="Redefinition of specified scalar type 'Boolean'" - ): - redefined_boolean = GraphQLScalarType(name="Boolean") - assert redefined_boolean == GraphQLBoolean + with raises(TypeError, match="Redefinition of reserved type 'Boolean'"): + GraphQLScalarType(name="Boolean") def pickles(): assert pickle.loads(pickle.dumps(GraphQLBoolean)) is GraphQLBoolean @@ -707,11 +695,8 @@ def __str__(self): assert str(exc_info.value) == "ID cannot represent value: ['abc']" def cannot_be_redefined(): - with warns( - RuntimeWarning, match="Redefinition of specified scalar type 'ID'" - ): - redefined_id = GraphQLScalarType(name="ID") - assert redefined_id == GraphQLID + with raises(TypeError, match="Redefinition of reserved type 'ID'"): + GraphQLScalarType(name="ID") def pickles(): assert pickle.loads(pickle.dumps(GraphQLID)) is GraphQLID diff --git a/tests/type/test_schema.py b/tests/type/test_schema.py index 9222cc85..efd44f86 100644 --- a/tests/type/test_schema.py +++ b/tests/type/test_schema.py @@ -20,6 +20,7 @@ GraphQLInt, GraphQLInterfaceType, GraphQLList, + GraphQLNamedType, GraphQLObjectType, GraphQLScalarType, GraphQLSchema, @@ -332,14 +333,14 @@ def check_that_query_mutation_and_subscription_are_graphql_types(): def describe_a_schema_must_contain_uniquely_named_types(): def rejects_a_schema_which_redefines_a_built_in_type(): # temporarily allow redefinition of the String scalar type - specified_types = GraphQLScalarType.specified_types - GraphQLScalarType.specified_types = {} + reserved_types = GraphQLNamedType.reserved_types + GraphQLScalarType.reserved_types = {} try: # create a redefined String scalar type FakeString = GraphQLScalarType("String") finally: # protect from redefinition again - GraphQLScalarType.specified_types = specified_types + GraphQLScalarType.reserved_types = reserved_types QueryType = GraphQLObjectType( "Query", diff --git a/tests/utilities/test_build_ast_schema.py b/tests/utilities/test_build_ast_schema.py index 5bffa79a..bb0dc561 100644 --- a/tests/utilities/test_build_ast_schema.py +++ b/tests/utilities/test_build_ast_schema.py @@ -1,4 +1,5 @@ import pickle +import sys from collections import namedtuple from copy import deepcopy from typing import Union @@ -1190,46 +1191,45 @@ def rejects_invalid_ast(): assert str(exc_info.value) == "Must provide valid Document AST." def describe_deepcopy_and_pickle(): # pragma: no cover - star_wars_sdl = print_schema(star_wars_schema) + sdl = print_schema(star_wars_schema) - def can_deep_copy_star_wars_schema(): - # create a schema from the star wars SDL - schema = build_schema(star_wars_sdl, assume_valid_sdl=True) + def can_deep_copy_schema(): + schema = build_schema(sdl, assume_valid_sdl=True) # create a deepcopy of the schema copied = deepcopy(schema) # check that printing the copied schema gives the same SDL - assert print_schema(copied) == star_wars_sdl + assert print_schema(copied) == sdl def can_pickle_and_unpickle_star_wars_schema(): # create a schema from the star wars SDL - schema = build_schema(star_wars_sdl, assume_valid_sdl=True) + schema = build_schema(sdl, assume_valid_sdl=True) # check that the schema can be pickled # (particularly, there should be no recursion error, # or errors because of trying to pickle lambdas or local functions) dumped = pickle.dumps(schema) # check that the pickle size is reasonable - assert len(dumped) < 25 * len(star_wars_sdl) + assert len(dumped) < 25 * len(sdl) loaded = pickle.loads(dumped) # check that printing the unpickled schema gives the same SDL - assert print_schema(loaded) == star_wars_sdl + assert print_schema(loaded) == sdl # check that pickling again creates the same result dumped = pickle.dumps(schema) - assert len(dumped) < 25 * len(star_wars_sdl) + assert len(dumped) < 25 * len(sdl) loaded = pickle.loads(dumped) - assert print_schema(loaded) == star_wars_sdl + assert print_schema(loaded) == sdl - def can_deep_copy_pickled_star_wars_schema(): + def can_deep_copy_pickled_schema(): # create a schema from the star wars SDL - schema = build_schema(star_wars_sdl, assume_valid_sdl=True) + schema = build_schema(sdl, assume_valid_sdl=True) # pickle and unpickle the schema loaded = pickle.loads(pickle.dumps(schema)) # create a deepcopy of the unpickled schema copied = deepcopy(loaded) # check that printing the copied schema gives the same SDL - assert print_schema(copied) == star_wars_sdl + assert print_schema(copied) == sdl @mark.slow def describe_deepcopy_and_pickle_big(): # pragma: no cover @@ -1250,36 +1250,50 @@ def can_pickle_and_unpickle_big_schema(big_schema_sdl): # noqa: F811 # use our printing conventions big_schema_sdl = cycle_sdl(big_schema_sdl) - # create a schema from the big SDL - schema = build_schema(big_schema_sdl, assume_valid_sdl=True) - # check that the schema can be pickled - # (particularly, there should be no recursion error, - # or errors because of trying to pickle lambdas or local functions) - dumped = pickle.dumps(schema) + limit = sys.getrecursionlimit() + sys.setrecursionlimit(max(limit, 4000)) # needed for pickle - # check that the pickle size is reasonable - assert len(dumped) < 25 * len(big_schema_sdl) - loaded = pickle.loads(dumped) + try: + # create a schema from the big SDL + schema = build_schema(big_schema_sdl, assume_valid_sdl=True) + # check that the schema can be pickled + # (particularly, there should be no recursion error, + # or errors because of trying to pickle lambdas or local functions) + dumped = pickle.dumps(schema) - # check that printing the unpickled schema gives the same SDL - assert print_schema(loaded) == big_schema_sdl + # check that the pickle size is reasonable + assert len(dumped) < 25 * len(big_schema_sdl) + loaded = pickle.loads(dumped) - # check that pickling again creates the same result - dumped = pickle.dumps(schema) - assert len(dumped) < 25 * len(big_schema_sdl) - loaded = pickle.loads(dumped) - assert print_schema(loaded) == big_schema_sdl + # check that printing the unpickled schema gives the same SDL + assert print_schema(loaded) == big_schema_sdl + + # check that pickling again creates the same result + dumped = pickle.dumps(schema) + assert len(dumped) < 25 * len(big_schema_sdl) + loaded = pickle.loads(dumped) + assert print_schema(loaded) == big_schema_sdl + + finally: + sys.setrecursionlimit(limit) @mark.timeout(60 * timeout_factor) def can_deep_copy_pickled_big_schema(big_schema_sdl): # noqa: F811 # use our printing conventions big_schema_sdl = cycle_sdl(big_schema_sdl) - # create a schema from the big SDL - schema = build_schema(big_schema_sdl, assume_valid_sdl=True) - # pickle and unpickle the schema - loaded = pickle.loads(pickle.dumps(schema)) - # create a deepcopy of the unpickled schema - copied = deepcopy(loaded) - # check that printing the copied schema gives the same SDL - assert print_schema(copied) == big_schema_sdl + limit = sys.getrecursionlimit() + sys.setrecursionlimit(max(limit, 4000)) # needed for pickle + + try: + # create a schema from the big SDL + schema = build_schema(big_schema_sdl, assume_valid_sdl=True) + # pickle and unpickle the schema + loaded = pickle.loads(pickle.dumps(schema)) + # create a deepcopy of the unpickled schema + copied = deepcopy(loaded) + # check that printing the copied schema gives the same SDL + assert print_schema(copied) == big_schema_sdl + + finally: + sys.setrecursionlimit(limit) diff --git a/tests/utilities/test_introspection_from_schema.py b/tests/utilities/test_introspection_from_schema.py index 96ec968f..878ac0fb 100644 --- a/tests/utilities/test_introspection_from_schema.py +++ b/tests/utilities/test_introspection_from_schema.py @@ -1,12 +1,20 @@ +import pickle +import sys +from copy import deepcopy + +from pytest import mark + from graphql.type import GraphQLField, GraphQLObjectType, GraphQLSchema, GraphQLString from graphql.utilities import ( IntrospectionQuery, build_client_schema, + build_schema, introspection_from_schema, print_schema, ) -from ..utils import dedent +from ..fixtures import big_schema_introspection_result, big_schema_sdl # noqa: F401 +from ..utils import dedent, timeout_factor def introspection_to_sdl(introspection: IntrospectionQuery) -> str: @@ -60,3 +68,109 @@ def converts_a_simple_schema_without_description(): } """ ) + + def describe_deepcopy_and_pickle(): # pragma: no cover + # introspect the schema + introspected_schema = introspection_from_schema(schema) + introspection_size = len(str(introspected_schema)) + + def can_deep_copy_schema(): + # create a deepcopy of the schema + copied = deepcopy(schema) + # check that introspecting the copied schema gives the same result + assert introspection_from_schema(copied) == introspected_schema + + def can_pickle_and_unpickle_schema(): + # check that the schema can be pickled + # (particularly, there should be no recursion error, + # or errors because of trying to pickle lambdas or local functions) + dumped = pickle.dumps(schema) + + # check that the pickle size is reasonable + assert len(dumped) < 5 * introspection_size + loaded = pickle.loads(dumped) + + # check that introspecting the unpickled schema gives the same result + assert introspection_from_schema(loaded) == introspected_schema + + # check that pickling again creates the same result + dumped = pickle.dumps(schema) + assert len(dumped) < 5 * introspection_size + loaded = pickle.loads(dumped) + assert introspection_from_schema(loaded) == introspected_schema + + def can_deep_copy_pickled_schema(): + # pickle and unpickle the schema + loaded = pickle.loads(pickle.dumps(schema)) + # create a deepcopy of the unpickled schema + copied = deepcopy(loaded) + # check that introspecting the copied schema gives the same result + assert introspection_from_schema(copied) == introspected_schema + + @mark.slow + def describe_deepcopy_and_pickle_big(): # pragma: no cover + @mark.timeout(20 * timeout_factor) + def can_deep_copy_big_schema(big_schema_sdl): # noqa: F811 + # introspect the original big schema + big_schema = build_schema(big_schema_sdl) + expected_introspection = introspection_from_schema(big_schema) + + # create a deepcopy of the schema + copied = deepcopy(big_schema) + # check that introspecting the copied schema gives the same result + assert introspection_from_schema(copied) == expected_introspection + + @mark.timeout(60 * timeout_factor) + def can_pickle_and_unpickle_big_schema(big_schema_sdl): # noqa: F811 + # introspect the original big schema + big_schema = build_schema(big_schema_sdl) + expected_introspection = introspection_from_schema(big_schema) + size_introspection = len(str(expected_introspection)) + + limit = sys.getrecursionlimit() + sys.setrecursionlimit(max(limit, 4000)) # needed for pickle + + try: + # check that the schema can be pickled + # (particularly, there should be no recursion error, + # or errors because of trying to pickle lambdas or local functions) + dumped = pickle.dumps(big_schema) + + # check that the pickle size is reasonable + assert len(dumped) < 5 * size_introspection + loaded = pickle.loads(dumped) + + # check that introspecting the pickled schema gives the same result + assert introspection_from_schema(loaded) == expected_introspection + + # check that pickling again creates the same result + dumped = pickle.dumps(loaded) + assert len(dumped) < 5 * size_introspection + loaded = pickle.loads(dumped) + + # check that introspecting the re-pickled schema gives the same result + assert introspection_from_schema(loaded) == expected_introspection + + finally: + sys.setrecursionlimit(limit) + + @mark.timeout(60 * timeout_factor) + def can_deep_copy_pickled_big_schema(big_schema_sdl): # noqa: F811 + # introspect the original big schema + big_schema = build_schema(big_schema_sdl) + expected_introspection = introspection_from_schema(big_schema) + + limit = sys.getrecursionlimit() + sys.setrecursionlimit(max(limit, 4000)) # needed for pickle + + try: + # pickle and unpickle the schema + loaded = pickle.loads(pickle.dumps(big_schema)) + # create a deepcopy of the unpickled schema + copied = deepcopy(loaded) + + # check that introspecting the copied schema gives the same result + assert introspection_from_schema(copied) == expected_introspection + + finally: + sys.setrecursionlimit(limit) diff --git a/tox.ini b/tox.ini index 3029f3b4..8258d6e2 100644 --- a/tox.ini +++ b/tox.ini @@ -53,10 +53,12 @@ commands = deps = pytest>=7.1,<8 pytest-asyncio>=0.20,<1 - pytest-benchmark>=3.4,<4 + pytest-benchmark>=4,<5 pytest-cov>=4,<5 pytest-describe>=2,<3 pytest-timeout>=2,<3 py37: typing-extensions>=4.3,<5 commands = + # to also run the time-consuming tests: tox -e py310 -- --run-slow + # to run the benchmarks: tox -e py310 -- -k benchmarks --benchmark-enable pytest tests {posargs: --cov-report=term-missing --cov=graphql --cov=tests --cov-fail-under=100}