diff --git a/src/utilities/extendSchema.ts b/src/utilities/extendSchema.ts index 15b1983413..d5ccdf3a6b 100644 --- a/src/utilities/extendSchema.ts +++ b/src/utilities/extendSchema.ts @@ -1,3 +1,4 @@ +import { AccumulatorMap } from '../jsutils/AccumulatorMap'; import { inspect } from '../jsutils/inspect'; import { invariant } from '../jsutils/invariant'; import { keyMap } from '../jsutils/keyMap'; @@ -29,10 +30,6 @@ import type { UnionTypeExtensionNode, } from '../language/ast'; import { Kind } from '../language/kinds'; -import { - isTypeDefinitionNode, - isTypeExtensionNode, -} from '../language/predicates'; import type { GraphQLArgumentConfig, @@ -131,7 +128,25 @@ export function extendSchemaImpl( ): GraphQLSchemaNormalizedConfig { // Collect the type definitions and extensions found in the document. const typeDefs: Array = []; - const typeExtensionsMap = Object.create(null); + + const scalarExtensions = new AccumulatorMap< + string, + ScalarTypeExtensionNode + >(); + const objectExtensions = new AccumulatorMap< + string, + ObjectTypeExtensionNode + >(); + const interfaceExtensions = new AccumulatorMap< + string, + InterfaceTypeExtensionNode + >(); + const unionExtensions = new AccumulatorMap(); + const enumExtensions = new AccumulatorMap(); + const inputObjectExtensions = new AccumulatorMap< + string, + InputObjectTypeExtensionNode + >(); // New directives and types are separate because a directives and types can // have the same name. For example, a type named "skip". @@ -141,33 +156,57 @@ export function extendSchemaImpl( // Schema extensions are collected which may add additional operation types. const schemaExtensions: Array = []; + let isSchemaChanged = false; for (const def of documentAST.definitions) { - if (def.kind === Kind.SCHEMA_DEFINITION) { - schemaDef = def; - } else if (def.kind === Kind.SCHEMA_EXTENSION) { - schemaExtensions.push(def); - } else if (isTypeDefinitionNode(def)) { - typeDefs.push(def); - } else if (isTypeExtensionNode(def)) { - const extendedTypeName = def.name.value; - const existingTypeExtensions = typeExtensionsMap[extendedTypeName]; - typeExtensionsMap[extendedTypeName] = existingTypeExtensions - ? existingTypeExtensions.concat([def]) - : [def]; - } else if (def.kind === Kind.DIRECTIVE_DEFINITION) { - directiveDefs.push(def); + switch (def.kind) { + case Kind.SCHEMA_DEFINITION: + schemaDef = def; + break; + case Kind.SCHEMA_EXTENSION: + schemaExtensions.push(def); + break; + case Kind.DIRECTIVE_DEFINITION: + directiveDefs.push(def); + break; + + // Type Definitions + case Kind.SCALAR_TYPE_DEFINITION: + case Kind.OBJECT_TYPE_DEFINITION: + case Kind.INTERFACE_TYPE_DEFINITION: + case Kind.UNION_TYPE_DEFINITION: + case Kind.ENUM_TYPE_DEFINITION: + case Kind.INPUT_OBJECT_TYPE_DEFINITION: + typeDefs.push(def); + break; + + // Type System Extensions + case Kind.SCALAR_TYPE_EXTENSION: + scalarExtensions.add(def.name.value, def); + break; + case Kind.OBJECT_TYPE_EXTENSION: + objectExtensions.add(def.name.value, def); + break; + case Kind.INTERFACE_TYPE_EXTENSION: + interfaceExtensions.add(def.name.value, def); + break; + case Kind.UNION_TYPE_EXTENSION: + unionExtensions.add(def.name.value, def); + break; + case Kind.ENUM_TYPE_EXTENSION: + enumExtensions.add(def.name.value, def); + break; + case Kind.INPUT_OBJECT_TYPE_EXTENSION: + inputObjectExtensions.add(def.name.value, def); + break; + default: + continue; } + isSchemaChanged = true; } // If this document contains no new types, extensions, or directives then // return the same unmodified GraphQLSchema instance. - if ( - Object.keys(typeExtensionsMap).length === 0 && - typeDefs.length === 0 && - directiveDefs.length === 0 && - schemaExtensions.length === 0 && - schemaDef == null - ) { + if (!isSchemaChanged) { return schemaConfig; } @@ -275,7 +314,7 @@ export function extendSchemaImpl( type: GraphQLInputObjectType, ): GraphQLInputObjectType { const config = type.toConfig(); - const extensions = typeExtensionsMap[config.name] ?? []; + const extensions = inputObjectExtensions.get(config.name) ?? []; return new GraphQLInputObjectType({ ...config, @@ -292,7 +331,7 @@ export function extendSchemaImpl( function extendEnumType(type: GraphQLEnumType): GraphQLEnumType { const config = type.toConfig(); - const extensions = typeExtensionsMap[type.name] ?? []; + const extensions = enumExtensions.get(type.name) ?? []; return new GraphQLEnumType({ ...config, @@ -306,7 +345,7 @@ export function extendSchemaImpl( function extendScalarType(type: GraphQLScalarType): GraphQLScalarType { const config = type.toConfig(); - const extensions = typeExtensionsMap[config.name] ?? []; + const extensions = scalarExtensions.get(config.name) ?? []; let specifiedByURL = config.specifiedByURL; for (const extensionNode of extensions) { @@ -322,7 +361,7 @@ export function extendSchemaImpl( function extendObjectType(type: GraphQLObjectType): GraphQLObjectType { const config = type.toConfig(); - const extensions = typeExtensionsMap[config.name] ?? []; + const extensions = objectExtensions.get(config.name) ?? []; return new GraphQLObjectType({ ...config, @@ -342,7 +381,7 @@ export function extendSchemaImpl( type: GraphQLInterfaceType, ): GraphQLInterfaceType { const config = type.toConfig(); - const extensions = typeExtensionsMap[config.name] ?? []; + const extensions = interfaceExtensions.get(config.name) ?? []; return new GraphQLInterfaceType({ ...config, @@ -360,7 +399,7 @@ export function extendSchemaImpl( function extendUnionType(type: GraphQLUnionType): GraphQLUnionType { const config = type.toConfig(); - const extensions = typeExtensionsMap[config.name] ?? []; + const extensions = unionExtensions.get(config.name) ?? []; return new GraphQLUnionType({ ...config, @@ -579,10 +618,10 @@ export function extendSchemaImpl( function buildType(astNode: TypeDefinitionNode): GraphQLNamedType { const name = astNode.name.value; - const extensionASTNodes = typeExtensionsMap[name] ?? []; switch (astNode.kind) { case Kind.OBJECT_TYPE_DEFINITION: { + const extensionASTNodes = objectExtensions.get(name) ?? []; const allNodes = [astNode, ...extensionASTNodes]; return new GraphQLObjectType({ @@ -595,6 +634,7 @@ export function extendSchemaImpl( }); } case Kind.INTERFACE_TYPE_DEFINITION: { + const extensionASTNodes = interfaceExtensions.get(name) ?? []; const allNodes = [astNode, ...extensionASTNodes]; return new GraphQLInterfaceType({ @@ -607,6 +647,7 @@ export function extendSchemaImpl( }); } case Kind.ENUM_TYPE_DEFINITION: { + const extensionASTNodes = enumExtensions.get(name) ?? []; const allNodes = [astNode, ...extensionASTNodes]; return new GraphQLEnumType({ @@ -618,6 +659,7 @@ export function extendSchemaImpl( }); } case Kind.UNION_TYPE_DEFINITION: { + const extensionASTNodes = unionExtensions.get(name) ?? []; const allNodes = [astNode, ...extensionASTNodes]; return new GraphQLUnionType({ @@ -629,6 +671,7 @@ export function extendSchemaImpl( }); } case Kind.SCALAR_TYPE_DEFINITION: { + const extensionASTNodes = scalarExtensions.get(name) ?? []; return new GraphQLScalarType({ name, description: astNode.description?.value, @@ -638,6 +681,7 @@ export function extendSchemaImpl( }); } case Kind.INPUT_OBJECT_TYPE_DEFINITION: { + const extensionASTNodes = inputObjectExtensions.get(name) ?? []; const allNodes = [astNode, ...extensionASTNodes]; return new GraphQLInputObjectType({