From 5d69b713fa88e171c23fc39a598820d94453d465 Mon Sep 17 00:00:00 2001 From: Andreas Berger Date: Thu, 20 May 2021 09:19:19 +0200 Subject: [PATCH] Refactor schema augmentation to have hooks for customization (#224) Due to these changes, schema augmentation is no longer performed on the compiled graphql schema, but on the given type definitions. This will greatly simplify the extension of the neo4j schema with custom resolvers, such as with https://netflix.github.io/dgs/. --- .../org/neo4j/graphql/AugmentationHandler.kt | 324 +++++++++++++++++- .../kotlin/org/neo4j/graphql/BuildingEnv.kt | 280 --------------- .../org/neo4j/graphql/ExtensionFunctions.kt | 17 +- .../org/neo4j/graphql/GraphQLExtensions.kt | 163 ++++----- .../kotlin/org/neo4j/graphql/Predicates.kt | 44 +-- .../org/neo4j/graphql/RelationDirection.kt | 14 - .../org/neo4j/graphql/RelationshipInfo.kt | 80 +++++ .../kotlin/org/neo4j/graphql/SchemaBuilder.kt | 323 ++++++++--------- .../kotlin/org/neo4j/graphql/Translator.kt | 8 +- .../graphql/handler/AugmentFieldHandler.kt | 71 ++++ .../neo4j/graphql/handler/BaseDataFetcher.kt | 27 +- .../handler/BaseDataFetcherForContainer.kt | 12 +- .../graphql/handler/CreateTypeHandler.kt | 48 +-- .../graphql/handler/CypherDirectiveHandler.kt | 31 +- .../neo4j/graphql/handler/DeleteHandler.kt | 56 +-- .../graphql/handler/MergeOrUpdateHandler.kt | 64 ++-- .../org/neo4j/graphql/handler/QueryHandler.kt | 74 ++-- .../handler/projection/ProjectionBase.kt | 17 +- .../handler/relation/BaseRelationHandler.kt | 160 +++++---- .../handler/relation/CreateRelationHandler.kt | 53 ++- .../relation/CreateRelationTypeHandler.kt | 113 ++++-- .../handler/relation/DeleteRelationHandler.kt | 40 +-- .../graphql/utils/GraphQLSchemaTestSuite.kt | 4 + .../test/resources/augmentation-tests.adoc | 4 +- .../config/Neo4jConfiguration.kt | 1 - 25 files changed, 1128 insertions(+), 900 deletions(-) delete mode 100644 core/src/main/kotlin/org/neo4j/graphql/BuildingEnv.kt delete mode 100644 core/src/main/kotlin/org/neo4j/graphql/RelationDirection.kt create mode 100644 core/src/main/kotlin/org/neo4j/graphql/RelationshipInfo.kt create mode 100644 core/src/main/kotlin/org/neo4j/graphql/handler/AugmentFieldHandler.kt diff --git a/core/src/main/kotlin/org/neo4j/graphql/AugmentationHandler.kt b/core/src/main/kotlin/org/neo4j/graphql/AugmentationHandler.kt index 08afd4f2..f875b2c7 100644 --- a/core/src/main/kotlin/org/neo4j/graphql/AugmentationHandler.kt +++ b/core/src/main/kotlin/org/neo4j/graphql/AugmentationHandler.kt @@ -1,17 +1,331 @@ package org.neo4j.graphql +import graphql.language.* +import graphql.language.TypeDefinition import graphql.schema.DataFetcher -import graphql.schema.GraphQLFieldDefinition -import graphql.schema.GraphQLFieldsContainer +import graphql.schema.idl.ScalarInfo +import graphql.schema.idl.TypeDefinitionRegistry +import org.atteo.evo.inflector.English +import org.neo4j.graphql.handler.projection.ProjectionBase -abstract class AugmentationHandler(val schemaConfig: SchemaConfig) { +/** + * A base class for augmenting a TypeDefinitionRegistry. There a re 2 steps in augmenting the types: + * 1. augmenting the type by creating the relevant query / mutation fields and adding filtering and sorting to the relation fields + * 2. generating a data fetcher based on a field definition. The field may be an augmented field (from step 1) + * but can also be a user specified query / mutation field + */ +abstract class AugmentationHandler( + val schemaConfig: SchemaConfig, + val typeDefinitionRegistry: TypeDefinitionRegistry, + val neo4jTypeDefinitionRegistry: TypeDefinitionRegistry +) { enum class OperationType { QUERY, MUTATION } - open fun augmentType(type: GraphQLFieldsContainer, buildingEnv: BuildingEnv) {} + /** + * The 1st step in enhancing a schema. This method creates relevant query / mutation fields and / or adds filtering and sorting to the relation fields of the given type + * @param type the type for which the schema should be enhanced / augmented + */ + open fun augmentType(type: ImplementingTypeDefinition<*>) {} - abstract fun createDataFetcher(operationType: OperationType, fieldDefinition: GraphQLFieldDefinition): DataFetcher? + /** + * The 2nd step is creating a data fetcher based on a field definition. The field may be an augmented field (from step 1) + * but can also be a user specified query / mutation field + * @param operationType the type of the field + * @param fieldDefinition the filed to create the data fetcher for + * @return a data fetcher for the field or null if not applicable + */ + abstract fun createDataFetcher(operationType: OperationType, fieldDefinition: FieldDefinition): DataFetcher? + protected fun buildFieldDefinition( + prefix: String, + resultType: ImplementingTypeDefinition<*>, + scalarFields: List, + nullableResult: Boolean, + forceOptionalProvider: (field: FieldDefinition) -> Boolean = { false } + ): FieldDefinition.Builder { + var type: Type<*> = TypeName(resultType.name) + if (!nullableResult) { + type = NonNullType(type) + } + return FieldDefinition.newFieldDefinition() + .name("$prefix${resultType.name}") + .inputValueDefinitions(getInputValueDefinitions(scalarFields, forceOptionalProvider)) + .type(type) + } + + protected fun getInputValueDefinitions( + relevantFields: List, + forceOptionalProvider: (field: FieldDefinition) -> Boolean): List { + return relevantFields.map { field -> + var type = getInputType(field.type) + type = if (forceOptionalProvider(field)) { + (type as? NonNullType)?.type ?: type + } else { + type + } + input(field.name, type) + } + } + + protected fun addQueryField(fieldDefinition: FieldDefinition) { + addOperation(typeDefinitionRegistry.queryTypeName(), fieldDefinition) + } + + protected fun addMutationField(fieldDefinition: FieldDefinition) { + addOperation(typeDefinitionRegistry.mutationTypeName(), fieldDefinition) + } + + protected fun isRootType(type: ImplementingTypeDefinition<*>): Boolean { + return type.name == typeDefinitionRegistry.queryTypeName() + || type.name == typeDefinitionRegistry.mutationTypeName() + || type.name == typeDefinitionRegistry.subscriptionTypeName() + } + + /** + * add the given operation to the corresponding rootType + */ + private fun addOperation(rootTypeName: String, fieldDefinition: FieldDefinition) { + val rootType = typeDefinitionRegistry.getType(rootTypeName)?.unwrap() + if (rootType == null) { + typeDefinitionRegistry.add(ObjectTypeDefinition.newObjectTypeDefinition() + .name(rootTypeName) + .fieldDefinition(fieldDefinition) + .build()) + } else { + val existingRootType = (rootType as? ObjectTypeDefinition + ?: throw IllegalStateException("root type $rootTypeName is not an object type but ${rootType.javaClass}")) + if (existingRootType.fieldDefinitions.find { it.name == fieldDefinition.name } != null) { + return // definition already exists, we don't override it + } + typeDefinitionRegistry.remove(rootType) + typeDefinitionRegistry.add(rootType.transform { it.fieldDefinition(fieldDefinition) }) + } + } + + protected fun addFilterType(type: ImplementingTypeDefinition<*>, createdTypes: MutableSet = mutableSetOf()): String { + val filterName = if (schemaConfig.useWhereFilter) type.name + "Where" else "_${type.name}Filter" + if (createdTypes.contains(filterName)) { + return filterName + } + val existingFilterType = typeDefinitionRegistry.getType(filterName).unwrap() + if (existingFilterType != null) { + return (existingFilterType as? InputObjectTypeDefinition)?.name + ?: throw IllegalStateException("Filter type $filterName is already defined but not an input type") + } + createdTypes.add(filterName) + val builder = InputObjectTypeDefinition.newInputObjectDefinition() + .name(filterName) + listOf("AND", "OR", "NOT").forEach { + builder.inputValueDefinition(InputValueDefinition.newInputValueDefinition() + .name(it) + .type(ListType(NonNullType(TypeName(filterName)))) + .build()) + } + type.fieldDefinitions + .filter { it.dynamicPrefix() == null } // TODO currently we do not support filtering on dynamic properties + .forEach { field -> + val typeDefinition = field.type.resolve() + ?: throw IllegalArgumentException("type ${field.type.name()} cannot be resolved") + val filterType = when { + field.type.inner().isNeo4jType() -> getInputType(field.type).name()!! + typeDefinition is ScalarTypeDefinition -> typeDefinition.name + typeDefinition is EnumTypeDefinition -> typeDefinition.name + typeDefinition is ImplementingTypeDefinition -> { + when { + field.type.inner().isNeo4jType() -> typeDefinition.name + else -> addFilterType(typeDefinition, createdTypes) + } + } + else -> throw IllegalArgumentException("${field.type.name()} is neither an object nor an interface") + } + + if (field.isRelationship()) { + RelationOperator.createRelationFilterFields(type, field, filterType, builder) + } else { + FieldOperator.forType(typeDefinition, field.type.inner().isNeo4jType()) + .forEach { op -> builder.addFilterField(op.fieldName(field.name), op.list, filterType, field.description) } + if (typeDefinition.isNeo4jSpatialType() == true) { + val distanceFilterType = getSpatialDistanceFilter(neo4jTypeDefinitionRegistry.getUnwrappedType(filterType) as TypeDefinition<*>) + FieldOperator.forType(distanceFilterType, true) + .forEach { op -> builder.addFilterField(op.fieldName(field.name + NEO4j_POINT_DISTANCE_FILTER_SUFFIX), op.list, NEO4j_POINT_DISTANCE_FILTER) } + } + } + + } + typeDefinitionRegistry.add(builder.build()) + return filterName + } + + private fun getSpatialDistanceFilter(pointType: TypeDefinition<*>): InputObjectTypeDefinition { + return addInputType(NEO4j_POINT_DISTANCE_FILTER, listOf( + input("distance", NonNullType(TypeFloat)), + input("point", NonNullType(TypeName(pointType.name))) + )) + } + + protected fun addOptions(type: ImplementingTypeDefinition<*>): String { + val optionsName = "${type.name}Options" + val optionsType = typeDefinitionRegistry.getType(optionsName)?.unwrap() + if (optionsType != null) { + return (optionsType as? InputObjectTypeDefinition)?.name + ?: throw IllegalStateException("Ordering type $type.name is already defined but not an input type") + } + val sortTypeName = addSortInputType(type) + val optionsTypeBuilder = InputObjectTypeDefinition.newInputObjectDefinition().name(optionsName) + if (sortTypeName != null) { + optionsTypeBuilder.inputValueDefinition(input( + ProjectionBase.SORT, + ListType(NonNullType(TypeName(sortTypeName))), + "Specify one or more $sortTypeName objects to sort ${English.plural(type.name)} by. The sorts will be applied in the order in which they are arranged in the array.") + ) + } + optionsTypeBuilder + .inputValueDefinition(input(ProjectionBase.LIMIT, TypeInt, "Defines the maximum amount of records returned")) + .inputValueDefinition(input(ProjectionBase.SKIP, TypeInt, "Defines the amount of records to be skipped")) + .build() + typeDefinitionRegistry.add(optionsTypeBuilder.build()) + return optionsName + } + + private fun addSortInputType(type: ImplementingTypeDefinition<*>): String? { + val sortTypeName = "${type.name}Sort" + val sortType = typeDefinitionRegistry.getType(sortTypeName)?.unwrap() + if (sortType != null) { + return (sortType as? InputObjectTypeDefinition)?.name + ?: throw IllegalStateException("Ordering type $type.name is already defined but not an input type") + } + val relevantFields = type.getScalarFields() + if (relevantFields.isEmpty()) { + return null + } + val builder = InputObjectTypeDefinition.newInputObjectDefinition() + .name(sortTypeName) + .description("Fields to sort ${type.name}s by. The order in which sorts are applied is not guaranteed when specifying many fields in one MovieSort object.".asDescription()) + for (relevantField in relevantFields) { + builder.inputValueDefinition(input(relevantField.name, TypeName("SortDirection"))) + } + typeDefinitionRegistry.add(builder.build()) + return sortTypeName + } + + protected fun addOrdering(type: ImplementingTypeDefinition<*>): String? { + val orderingName = "_${type.name}Ordering" + var existingOrderingType = typeDefinitionRegistry.getType(orderingName)?.unwrap() + if (existingOrderingType != null) { + return (existingOrderingType as? EnumTypeDefinition)?.name + ?: throw IllegalStateException("Ordering type $type.name is already defined but not an input type") + } + val sortingFields = type.getScalarFields() + if (sortingFields.isEmpty()) { + return null + } + existingOrderingType = EnumTypeDefinition.newEnumTypeDefinition() + .name(orderingName) + .enumValueDefinitions(sortingFields.flatMap { fd -> + listOf("_asc", "_desc") + .map { + EnumValueDefinition + .newEnumValueDefinition() + .name(fd.name + it) + .build() + } + }) + .build() + typeDefinitionRegistry.add(existingOrderingType) + return orderingName + } + + private fun addInputType(inputName: String, relevantFields: List): InputObjectTypeDefinition { + var inputType = typeDefinitionRegistry.getType(inputName)?.unwrap() + if (inputType != null) { + return inputType as? InputObjectTypeDefinition + ?: throw IllegalStateException("Filter type $inputName is already defined but not an input type") + } + inputType = getInputType(inputName, relevantFields) + typeDefinitionRegistry.add(inputType) + return inputType + } + + private fun getInputType(inputName: String, relevantFields: List): InputObjectTypeDefinition { + return InputObjectTypeDefinition.newInputObjectDefinition() + .name(inputName) + .inputValueDefinitions(relevantFields) + .build() + } + + private fun getInputType(type: Type<*>): Type<*> { + if (type.inner().isNeo4jType()) { + return neo4jTypeDefinitions + .find { it.typeDefinition == type.name() } + ?.let { TypeName(it.inputDefinition) } + ?: throw IllegalArgumentException("Cannot find input type for ${type.name()}") + } + return type + } + + private fun getTypeFromAnyRegistry(name: String?): TypeDefinition<*>? = typeDefinitionRegistry.getUnwrappedType(name) + ?: neo4jTypeDefinitionRegistry.getUnwrappedType(name) + + fun ImplementingTypeDefinition<*>.relationship(): RelationshipInfo>? = RelationshipInfo.create(this, neo4jTypeDefinitionRegistry) + + fun ImplementingTypeDefinition<*>.getScalarFields(): List = fieldDefinitions + .filter { it.type.inner().isScalar() || it.type.inner().isNeo4jType() } + .sortedByDescending { it.type.inner().isID() } + + fun ImplementingTypeDefinition<*>.getFieldDefinition(name: String) = this.fieldDefinitions.find { it.name == name } + fun ImplementingTypeDefinition<*>.getIdField() = this.fieldDefinitions.find { it.type.inner().isID() } + + fun Type<*>.resolve(): TypeDefinition<*>? = getTypeFromAnyRegistry(name()) + fun Type<*>.isScalar(): Boolean = resolve() is ScalarTypeDefinition + fun Type<*>.isNeo4jType(): Boolean = name() + ?.takeIf { + !ScalarInfo.GRAPHQL_SPECIFICATION_SCALARS_DEFINITIONS.containsKey(it) + && it.startsWith("_Neo4j") // TODO remove this check by refactoring neo4j input types + } + ?.let { neo4jTypeDefinitionRegistry.getUnwrappedType(it) } != null + + fun Type<*>.isID(): Boolean = name() == "ID" + + + fun FieldDefinition.isNativeId(): Boolean = name == ProjectionBase.NATIVE_ID + fun FieldDefinition.dynamicPrefix(): String? = + getDirectiveArgument(DirectiveConstants.DYNAMIC, DirectiveConstants.DYNAMIC_PREFIX, null) + fun FieldDefinition.isRelationship(): Boolean = + !type.inner().isNeo4jType() && type.resolve() is ImplementingTypeDefinition<*> + + + fun TypeDefinitionRegistry.getUnwrappedType(name: String?): TypeDefinition>? = getType(name)?.unwrap() + + fun DirectivesContainer<*>.cypherDirective(): CypherDirective? = if (hasDirective(DirectiveConstants.CYPHER)) { + CypherDirective( + getMandatoryDirectiveArgument(DirectiveConstants.CYPHER, DirectiveConstants.CYPHER_STATEMENT), + getMandatoryDirectiveArgument(DirectiveConstants.CYPHER, DirectiveConstants.CYPHER_PASS_THROUGH, false) + ) + } else { + null + } + + fun DirectivesContainer<*>.hasDirective(name: String): Boolean = getDirective(name) != null + + fun DirectivesContainer<*>.getDirectiveArgument(directiveName: String, argumentName: String, defaultValue: T? = null): T? = + getDirectiveArgument(neo4jTypeDefinitionRegistry, directiveName, argumentName, defaultValue) + + private fun DirectivesContainer<*>.getMandatoryDirectiveArgument(directiveName: String, argumentName: String, defaultValue: T? = null): T = + getDirectiveArgument(directiveName, argumentName, defaultValue) + ?: throw IllegalStateException("No default value for @${directiveName}::$argumentName") + + fun input(name: String, type: Type<*>, description: String? = null): InputValueDefinition { + val input = InputValueDefinition + .newInputValueDefinition() + .name(name) + .type(type) + if (description != null) { + input.description(description.asDescription()) + } + return input + .build() + } } diff --git a/core/src/main/kotlin/org/neo4j/graphql/BuildingEnv.kt b/core/src/main/kotlin/org/neo4j/graphql/BuildingEnv.kt deleted file mode 100644 index e978a515..00000000 --- a/core/src/main/kotlin/org/neo4j/graphql/BuildingEnv.kt +++ /dev/null @@ -1,280 +0,0 @@ -package org.neo4j.graphql - -import graphql.Scalars -import graphql.schema.* -import org.atteo.evo.inflector.English -import org.neo4j.graphql.handler.projection.ProjectionBase - -class BuildingEnv( - val types: MutableMap, - private val sourceSchema: GraphQLSchema, - val schemaConfig: SchemaConfig -) { - - private val typesForRelation = types.values - .filterIsInstance() - .filter { it.getDirective(DirectiveConstants.RELATION) != null } - .map { it.getDirectiveArgument(DirectiveConstants.RELATION, DirectiveConstants.RELATION_NAME, null)!! to it.name } - .toMap() - - fun buildFieldDefinition( - prefix: String, - resultType: GraphQLOutputType, - scalarFields: List, - nullableResult: Boolean, - forceOptionalProvider: (field: GraphQLFieldDefinition) -> Boolean = { false } - ): GraphQLFieldDefinition.Builder { - var type: GraphQLOutputType = resultType - if (!nullableResult) { - type = GraphQLNonNull(type) - } - return GraphQLFieldDefinition.newFieldDefinition() - .name("$prefix${resultType.name()}") - .arguments(getInputValueDefinitions(scalarFields, forceOptionalProvider)) - .type(type.ref() as GraphQLOutputType) - } - - fun getInputValueDefinitions( - relevantFields: List, - forceOptionalProvider: (field: GraphQLFieldDefinition) -> Boolean): List { - return relevantFields.map { field -> - var type = field.type as GraphQLType - type = getInputType(type) - type = if (forceOptionalProvider(field)) { - (type as? GraphQLNonNull)?.wrappedType ?: type - } else { - type - } - input(field.name, type) - } - } - - fun addQueryField(fieldDefinition: GraphQLFieldDefinition) { - addOperation(sourceSchema.queryTypeName(), fieldDefinition) - } - - fun addMutationField(fieldDefinition: GraphQLFieldDefinition) { - addOperation(sourceSchema.mutationTypeName(), fieldDefinition) - } - - /** - * add the given operation to the corresponding rootType - */ - private fun addOperation(rootTypeName: String, fieldDefinition: GraphQLFieldDefinition) { - val rootType = types[rootTypeName] - types[rootTypeName] = if (rootType == null) { - val builder = GraphQLObjectType.newObject() - builder.name(rootTypeName) - .field(fieldDefinition) - .build() - } else { - val existingRootType = (rootType as? GraphQLObjectType - ?: throw IllegalStateException("root type $rootTypeName is not an object type but ${rootType.javaClass}")) - if (existingRootType.getFieldDefinition(fieldDefinition.name) != null) { - return // definition already exists, we don't override it - } - existingRootType - .transform { builder -> builder.field(fieldDefinition) } - } - } - - fun addFilterType(type: GraphQLFieldsContainer, createdTypes: MutableSet = mutableSetOf()): String { - val filterName = if (schemaConfig.useWhereFilter) type.name + "Where" else "_${type.name}Filter" - if (createdTypes.contains(filterName)) { - return filterName - } - val existingFilterType = types[filterName] - if (existingFilterType != null) { - return (existingFilterType as? GraphQLInputType)?.name() - ?: throw IllegalStateException("Filter type $filterName is already defined but not an input type") - } - createdTypes.add(filterName) - val builder = GraphQLInputObjectType.newInputObject() - .name(filterName) - listOf("AND", "OR", "NOT").forEach { - builder.field(GraphQLInputObjectField.newInputObjectField() - .name(it) - .type(GraphQLList(GraphQLNonNull(GraphQLTypeReference(filterName))))) - } - type.fieldDefinitions - .filter { it.dynamicPrefix() == null } // TODO currently we do not support filtering on dynamic properties - .forEach { field -> - val typeDefinition = field.type.inner() - val filterType = when { - typeDefinition.isNeo4jType() -> getInputType(typeDefinition).requiredName() - typeDefinition.isScalar() -> typeDefinition.innerName() - typeDefinition is GraphQLEnumType -> typeDefinition.innerName() - else -> addFilterType(getInnerFieldsContainer(typeDefinition), createdTypes) - } - - if (field.isRelationship()) { - RelationOperator.createRelationFilterFields(type, field, filterType, builder) - } else { - val graphQLType = types[filterType] ?: typeDefinition - FieldOperator.forType(graphQLType) - .forEach { op -> builder.addFilterField(op.fieldName(field.name), op.list, filterType, field.description) } - if (graphQLType.isNeo4jSpatialType()) { - val distanceFilterType = getSpatialDistanceFilter(graphQLType) - FieldOperator.forType(distanceFilterType) - .forEach { op -> builder.addFilterField(op.fieldName(field.name + NEO4j_POINT_DISTANCE_FILTER_SUFFIX), op.list, NEO4j_POINT_DISTANCE_FILTER) } - } - } - - } - types[filterName] = builder.build() - return filterName - } - - private fun getSpatialDistanceFilter(pointType: GraphQLType): GraphQLInputType { - return addInputType(NEO4j_POINT_DISTANCE_FILTER, listOf( - GraphQLFieldDefinition.newFieldDefinition().name("distance").type(GraphQLNonNull(Scalars.GraphQLFloat)).build(), - GraphQLFieldDefinition.newFieldDefinition().name("point").type(GraphQLNonNull(pointType)).build() - )) - } - - fun addOptions(type: GraphQLFieldsContainer): String { - val optionsName = "${type.name}Options" - val optionsType = types[optionsName] - if (optionsType != null) { - return (optionsType as? GraphQLInputType)?.requiredName() - ?: throw IllegalStateException("Ordering type $type.name is already defined but not an input type") - } - val sortTypeName = addSortInputType(type) - val optionsTypeBuilder = GraphQLInputObjectType.newInputObject().name(optionsName) - if (sortTypeName != null) { - optionsTypeBuilder.field(GraphQLInputObjectField.newInputObjectField() - .name(ProjectionBase.SORT) - .type(GraphQLList(GraphQLNonNull(GraphQLTypeReference(sortTypeName)))) - .description("Specify one or more $sortTypeName objects to sort ${English.plural(type.name)} by. The sorts will be applied in the order in which they are arranged in the array.") - .build()) - } - optionsTypeBuilder.field(GraphQLInputObjectField.newInputObjectField() - .name(ProjectionBase.LIMIT) - .type(Scalars.GraphQLInt) - .description("Defines the maximum amount of records returned") - .build()) - .field(GraphQLInputObjectField.newInputObjectField() - .name(ProjectionBase.SKIP) - .type(Scalars.GraphQLInt) - .description("Defines the amount of records to be skipped") - .build()) - .build() - types[optionsName] = optionsTypeBuilder.build() - return optionsName - } - - private fun addSortInputType(type: GraphQLFieldsContainer): String? { - val sortTypeName = "${type.name}Sort" - val sortType = types[sortTypeName] - if (sortType != null) { - return (sortType as? GraphQLInputType)?.requiredName() - ?: throw IllegalStateException("Ordering type $type.name is already defined but not an input type") - } - val relevantFields = type.relevantFields() - if (relevantFields.isEmpty()) { - return null - } - val builder = GraphQLInputObjectType.newInputObject() - .name(sortTypeName) - .description("Fields to sort ${type.name}s by. The order in which sorts are applied is not guaranteed when specifying many fields in one MovieSort object.") - for (relevantField in relevantFields) { - builder.field(GraphQLInputObjectField.newInputObjectField() - .name(relevantField.name) - .type(GraphQLTypeReference("SortDirection")) - .build()) - } - types[sortTypeName] = builder.build() - return sortTypeName - } - - fun addOrdering(type: GraphQLFieldsContainer): String? { - val orderingName = "_${type.name}Ordering" - var existingOrderingType = types[orderingName] - if (existingOrderingType != null) { - return (existingOrderingType as? GraphQLInputType)?.requiredName() - ?: throw IllegalStateException("Ordering type $type.name is already defined but not an input type") - } - val sortingFields = type.fieldDefinitions - .filter { it.type.isScalar() || it.isNeo4jType() } - .sortedByDescending { it.isID() } - if (sortingFields.isEmpty()) { - return null - } - existingOrderingType = GraphQLEnumType.newEnum() - .name(orderingName) - .values(sortingFields.flatMap { fd -> - listOf("_asc", "_desc") - .map { - GraphQLEnumValueDefinition - .newEnumValueDefinition() - .name(fd.name + it) - .value(fd.name + it) - .build() - } - }) - .build() - types[orderingName] = existingOrderingType - return orderingName - } - - fun addInputType(inputName: String, relevantFields: List): GraphQLInputType { - var inputType = types[inputName] - if (inputType != null) { - return inputType as? GraphQLInputType - ?: throw IllegalStateException("Filter type $inputName is already defined but not an input type") - } - inputType = getInputType(inputName, relevantFields) - types[inputName] = inputType - return inputType - } - - fun getTypeForRelation(nameOfRelation: String): GraphQLObjectType? { - return typesForRelation[nameOfRelation]?.let { types[it] } as? GraphQLObjectType - } - - private fun getInputType(inputName: String, relevantFields: List): GraphQLInputObjectType { - return GraphQLInputObjectType.newInputObject() - .name(inputName) - .fields(getInputValueDefinitions(relevantFields)) - .build() - } - - private fun getInputValueDefinitions(relevantFields: List): List { - return relevantFields.map { - // just make evrything optional - val type = (it.type as? GraphQLNonNull)?.wrappedType ?: it.type - GraphQLInputObjectField - .newInputObjectField() - .name(it.name) - .description(it.description) - .type(getInputType(type).ref() as GraphQLInputType) - .build() - } - } - - private fun getInnerFieldsContainer(type: GraphQLType): GraphQLFieldsContainer { - var innerType = type.inner() - if (innerType is GraphQLTypeReference) { - innerType = types[innerType.name] - ?: throw IllegalArgumentException("${innerType.name} is unknown") - } - return innerType as? GraphQLFieldsContainer - ?: throw IllegalArgumentException("${innerType.name()} is neither an object nor an interface") - } - - private fun getInputType(type: GraphQLType): GraphQLInputType { - val inner = type.inner() - if (inner is GraphQLInputType) { - return type as GraphQLInputType - } - if (inner.isNeo4jType()) { - return neo4jTypeDefinitions - .find { it.typeDefinition == inner.name() } - ?.let { types[it.inputDefinition] } as? GraphQLInputType - ?: throw IllegalArgumentException("Cannot find input type for ${inner.name()}") - } - return type as? GraphQLInputType - ?: throw IllegalArgumentException("${type.name()} is not allowed for input") - } - -} diff --git a/core/src/main/kotlin/org/neo4j/graphql/ExtensionFunctions.kt b/core/src/main/kotlin/org/neo4j/graphql/ExtensionFunctions.kt index 720451b6..10514b70 100644 --- a/core/src/main/kotlin/org/neo4j/graphql/ExtensionFunctions.kt +++ b/core/src/main/kotlin/org/neo4j/graphql/ExtensionFunctions.kt @@ -1,23 +1,14 @@ package org.neo4j.graphql +import graphql.language.Description import graphql.language.VariableReference -import graphql.schema.GraphQLArgument -import graphql.schema.GraphQLInputType -import graphql.schema.GraphQLType import org.neo4j.cypherdsl.core.* +import java.util.* fun Iterable.joinNonEmpty(separator: CharSequence = ", ", prefix: CharSequence = "", postfix: CharSequence = "", limit: Int = -1, truncated: CharSequence = "...", transform: ((T) -> CharSequence)? = null): String { return if (iterator().hasNext()) joinTo(StringBuilder(), separator, prefix, postfix, limit, truncated, transform).toString() else "" } -fun input(name: String, type: GraphQLType): GraphQLArgument { - return GraphQLArgument - .newArgument() - .name(name) - .type((type.ref() as? GraphQLInputType) - ?: throw IllegalArgumentException("${type.innerName()} is not allowed for input")).build() -} - fun queryParameter(value: Any?, vararg parts: String?): Parameter { val name = when (value) { is VariableReference -> value.name @@ -36,3 +27,7 @@ fun PropertyContainer.id(): FunctionInvocation = when (this) { } fun String.toCamelCase(): String = Regex("[\\W_]([a-z])").replace(this) { it.groupValues[1].toUpperCase() } + +fun Optional.unwrap(): T? = orElse(null) + +fun String.asDescription() = Description(this, null, this.contains("\n")) diff --git a/core/src/main/kotlin/org/neo4j/graphql/GraphQLExtensions.kt b/core/src/main/kotlin/org/neo4j/graphql/GraphQLExtensions.kt index c7267d50..0b90c8c0 100644 --- a/core/src/main/kotlin/org/neo4j/graphql/GraphQLExtensions.kt +++ b/core/src/main/kotlin/org/neo4j/graphql/GraphQLExtensions.kt @@ -2,9 +2,9 @@ package org.neo4j.graphql import graphql.Scalars import graphql.language.* +import graphql.language.TypeDefinition import graphql.schema.* -import org.neo4j.cypherdsl.core.Node -import org.neo4j.cypherdsl.core.Relationship +import graphql.schema.idl.TypeDefinitionRegistry import org.neo4j.cypherdsl.core.SymbolicName import org.neo4j.graphql.DirectiveConstants.Companion.CYPHER import org.neo4j.graphql.DirectiveConstants.Companion.CYPHER_PASS_THROUGH @@ -13,21 +13,25 @@ import org.neo4j.graphql.DirectiveConstants.Companion.DYNAMIC import org.neo4j.graphql.DirectiveConstants.Companion.DYNAMIC_PREFIX import org.neo4j.graphql.DirectiveConstants.Companion.PROPERTY import org.neo4j.graphql.DirectiveConstants.Companion.PROPERTY_NAME -import org.neo4j.graphql.DirectiveConstants.Companion.RELATION_DIRECTION -import org.neo4j.graphql.DirectiveConstants.Companion.RELATION_FROM import org.neo4j.graphql.DirectiveConstants.Companion.RELATION_NAME import org.neo4j.graphql.DirectiveConstants.Companion.RELATION_TO import org.neo4j.graphql.handler.projection.ProjectionBase import java.math.BigDecimal import java.math.BigInteger -fun Type>.name(): String? = if (this.inner() is TypeName) (this.inner() as TypeName).name else null -fun Type>.inner(): Type> = when (this) { +fun Type<*>.name(): String? = if (this.inner() is TypeName) (this.inner() as TypeName).name else null +fun Type<*>.inner(): Type<*> = when (this) { is ListType -> this.type.inner() is NonNullType -> this.type.inner() else -> this } +fun Type<*>.isList(): Boolean = when (this) { + is ListType -> true + is NonNullType -> type.isList() + else -> false +} + fun GraphQLType.inner(): GraphQLType = when (this) { is GraphQLList -> this.wrappedType.inner() is GraphQLNonNull -> this.wrappedType.inner() @@ -38,19 +42,20 @@ fun GraphQLType.name(): String? = (this as? GraphQLNamedType)?.name fun GraphQLType.requiredName(): String = requireNotNull(name()) { "name is required but cannot be determined for " + this.javaClass } fun GraphQLType.isList() = this is GraphQLList || (this is GraphQLNonNull && this.wrappedType is GraphQLList) -fun GraphQLType.isScalar() = this.inner().let { it is GraphQLScalarType || it.innerName().startsWith("_Neo4j") } fun GraphQLType.isNeo4jType() = this.innerName().startsWith("_Neo4j") + fun GraphQLType.isNeo4jSpatialType() = this.innerName().startsWith("_Neo4jPoint") +fun TypeDefinition<*>.isNeo4jSpatialType() = this.name.startsWith("_Neo4jPoint") + fun GraphQLFieldDefinition.isNeo4jType(): Boolean = this.type.isNeo4jType() fun GraphQLFieldDefinition.isRelationship() = !type.isNeo4jType() && this.type.inner().let { it is GraphQLFieldsContainer } -fun GraphQLDirectiveContainer.isRelationType() = getDirective(DirectiveConstants.RELATION) != null fun GraphQLFieldsContainer.isRelationType() = (this as? GraphQLDirectiveContainer)?.getDirective(DirectiveConstants.RELATION) != null -fun GraphQLFieldsContainer.relationshipFor(name: String): RelationshipInfo? { +fun GraphQLFieldsContainer.relationshipFor(name: String): RelationshipInfo? { val field = getFieldDefinition(name) ?: throw IllegalArgumentException("$name is not defined on ${this.name}") - val fieldObjectType = field.type.inner() as? GraphQLFieldsContainer ?: return null + val fieldObjectType = field.type.inner() as? GraphQLImplementingType ?: return null val (relDirective, inverse) = if (isRelationType()) { val typeName = this.name @@ -67,7 +72,7 @@ fun GraphQLFieldsContainer.relationshipFor(name: String): RelationshipInfo? { ?: throw IllegalStateException("Field $field needs an @relation directive") } - val relInfo = relDetails(fieldObjectType, relDirective) + val relInfo = RelationshipInfo.create(fieldObjectType, relDirective) return if (inverse) relInfo.copy(direction = relInfo.direction.invert(), startField = relInfo.endField, endField = relInfo.startField) else relInfo } @@ -92,20 +97,7 @@ fun GraphQLFieldsContainer.label(): String = when { else -> name } - -fun GraphQLFieldsContainer.relevantFields() = fieldDefinitions - .filter { it.type.isScalar() || it.isNeo4jType() } - .sortedByDescending { it.isID() } - -fun GraphQLFieldsContainer.relationship(): RelationshipInfo? { - val relDirective = (this as? GraphQLDirectiveContainer)?.getDirective(DirectiveConstants.RELATION) ?: return null - val relType = relDirective.getArgument(RELATION_NAME, "")!! - val startField = relDirective.getMandatoryArgument(RELATION_FROM) - val endField = relDirective.getMandatoryArgument(RELATION_TO) - val direction = relDirective.getArgument(RELATION_DIRECTION)?.let { RelationDirection.valueOf(it) } - ?: RelationDirection.OUT - return RelationshipInfo(this, relType, direction, startField, endField) -} +fun GraphQLFieldsContainer.relationship(): RelationshipInfo? = RelationshipInfo.create(this) fun GraphQLType.ref(): GraphQLType = when (this) { is GraphQLNonNull -> GraphQLNonNull(this.wrappedType.ref()) @@ -116,63 +108,6 @@ fun GraphQLType.ref(): GraphQLType = when (this) { else -> GraphQLTypeReference(name()) } -fun relDetails(type: GraphQLFieldsContainer, relDirective: GraphQLDirective): RelationshipInfo { - val relType = relDirective.getArgument(RELATION_NAME, "")!! - val direction = relDirective.getArgument(RELATION_DIRECTION, null) - ?.let { RelationDirection.valueOf(it) } - ?: RelationDirection.OUT - - return RelationshipInfo(type, - relType, - direction, - relDirective.getMandatoryArgument(RELATION_FROM), - relDirective.getMandatoryArgument(RELATION_TO) - ) -} - -data class RelationshipInfo( - val type: GraphQLFieldsContainer, - val relType: String, - val direction: RelationDirection, - val startField: String, - val endField: String -) { - data class RelatedField( - val argumentName: String, - val field: GraphQLFieldDefinition, - val declaringType: GraphQLFieldsContainer - ) - - val typeName: String get() = this.type.name - - fun getStartFieldId() = getRelatedIdField(this.startField) - - fun getEndFieldId() = getRelatedIdField(this.endField) - - private fun getRelatedIdField(relFieldName: String?): RelatedField? { - if (relFieldName == null) return null - val relFieldDefinition = type.getFieldDefinition(relFieldName) - ?: throw IllegalArgumentException("field $relFieldName does not exists on ${type.innerName()}") - val relType = relFieldDefinition.type.inner() as? GraphQLFieldsContainer - ?: throw IllegalArgumentException("type ${relFieldDefinition.type.innerName()} not found") - return relType.fieldDefinitions.filter { it.isID() } - .map { - // TODO b/c we need to stay backwards kompatible this is not caml case but with underscore - //val filedName = normalizeName(relFieldName, it.name) - val filedName = "${relFieldName}_${it.name}" - RelatedField(filedName, it, relType) - } - .firstOrNull() - } - - fun createRelation(start: Node, end: Node): Relationship = - when (this.direction) { - RelationDirection.IN -> start.relationshipFrom(end, this.relType) - RelationDirection.OUT -> start.relationshipTo(end, this.relType) - RelationDirection.BOTH -> start.relationshipBetween(end, this.relType) - } -} - fun Field.aliasOrName(): String = (this.alias ?: this.name) fun Field.contextualize(variable: String) = variable + this.aliasOrName().capitalize() fun Field.contextualize(variable: SymbolicName) = variable.value + this.aliasOrName().capitalize() @@ -189,6 +124,23 @@ fun GraphQLType.getInnerFieldsContainer() = inner() as? GraphQLFieldsContainer fun GraphQLDirectiveContainer.getDirectiveArgument(directiveName: String, argumentName: String, defaultValue: T?): T? = getDirective(directiveName)?.getArgument(argumentName, defaultValue) ?: defaultValue +@Suppress("UNCHECKED_CAST") +fun DirectivesContainer<*>.getDirectiveArgument(typeRegistry: TypeDefinitionRegistry, directiveName: String, argumentName: String, defaultValue: T? = null): T? { + return (getDirective(directiveName) ?: return defaultValue) + .getArgument(argumentName)?.value?.toJavaValue() as T? + ?: typeRegistry.getDirectiveDefinition(directiveName) + ?.unwrap() + ?.inputValueDefinitions + ?.find { inputValueDefinition -> inputValueDefinition.name == argumentName } + ?.defaultValue?.toJavaValue() as T? + ?: defaultValue +} + +@Suppress("UNCHECKED_CAST") +fun DirectivesContainer<*>.getMandatoryDirectiveArgument(typeRegistry: TypeDefinitionRegistry, directiveName: String, argumentName: String, defaultValue: T? = null): T = + getDirectiveArgument(typeRegistry, directiveName, argumentName, defaultValue) + ?: throw IllegalStateException("No default value for @${directiveName}::$argumentName") + fun GraphQLDirective.getMandatoryArgument(argumentName: String, defaultValue: T? = null): T = this.getArgument(argumentName, defaultValue) ?: throw IllegalStateException(argumentName + " is required for @${this.name}") @@ -201,10 +153,12 @@ fun GraphQLDirective.getArgument(argumentName: String, defaultValue: T? = nu ?: throw IllegalStateException("No default value for @${this.name}::$argumentName") } -fun GraphQLFieldDefinition.cypherDirective() :CypherDirective?= getDirective(CYPHER)?.let { CypherDirective( - it.getMandatoryArgument(CYPHER_STATEMENT), - it.getMandatoryArgument(CYPHER_PASS_THROUGH, false) -) } +fun GraphQLFieldDefinition.cypherDirective(): CypherDirective? = getDirective(CYPHER)?.let { + CypherDirective( + it.getMandatoryArgument(CYPHER_STATEMENT), + it.getMandatoryArgument(CYPHER_PASS_THROUGH, false) + ) +} data class CypherDirective(val statement: String, val passThrough: Boolean) @@ -222,7 +176,7 @@ fun Value<*>.toJavaValue(): Any? = when (this) { is IntValue -> this.value.longValueExact() is VariableReference -> this is ArrayValue -> this.values.map { it.toJavaValue() }.toList() - is ObjectValue -> this.objectFields.map { it.name to it.value.toJavaValue() }.toMap() + is ObjectValue -> this.objectFields.associate { it.name to it.value.toJavaValue() } else -> throw IllegalStateException("Unhandled value $this") } @@ -230,22 +184,25 @@ fun GraphQLFieldDefinition.isID() = this.type.inner() == Scalars.GraphQLID fun GraphQLFieldDefinition.isNativeId() = this.name == ProjectionBase.NATIVE_ID fun GraphQLFieldsContainer.getIdField() = this.fieldDefinitions.find { it.isID() } -fun GraphQLInputObjectType.Builder.addFilterField(fieldName: String, isList: Boolean, filterType: String, description: String? = null) { - val wrappedType: GraphQLInputType = when { - isList -> GraphQLList(GraphQLTypeReference(filterType)) - else -> GraphQLTypeReference(filterType) +fun InputObjectTypeDefinition.Builder.addFilterField(fieldName: String, isList: Boolean, filterType: String, description: Description? = null) { + val wrappedType: Type<*> = when { + isList -> ListType(TypeName(filterType)) + else -> TypeName(filterType) } - val inputField = GraphQLInputObjectField.newInputObjectField() + val inputField = InputValueDefinition.newInputValueDefinition() .name(fieldName) - .description(description) .type(wrappedType) - this.field(inputField) -} + if (description != null) { + inputField.description(description) + } + this.inputValueDefinition(inputField.build()) +} -fun GraphQLSchema.queryTypeName() = this.queryType?.name ?: "Query" -fun GraphQLSchema.mutationTypeName() = this.mutationType?.name ?: "Mutation" -fun GraphQLSchema.subscriptionTypeName() = this.subscriptionType?.name ?: "Subscription" +fun TypeDefinitionRegistry.queryTypeName() = this.getOperationType("query") ?: "Query" +fun TypeDefinitionRegistry.mutationTypeName() = this.getOperationType("mutation") ?: "Mutation" +fun TypeDefinitionRegistry.subscriptionTypeName() = this.getOperationType("subscription") ?: "Subscription" +fun TypeDefinitionRegistry.getOperationType(name: String) = this.schemaDefinition().unwrap()?.operationTypeDefinitions?.firstOrNull { it.name == name }?.typeName?.name fun Any?.asGraphQLValue(): Value<*> = when (this) { null -> NullValue.newNullValue().build() @@ -261,3 +218,13 @@ fun Any?.asGraphQLValue(): Value<*> = when (this) { is String -> StringValue.newStringValue(this).build() else -> throw IllegalStateException("Cannot convert ${this.javaClass.name} into an graphql type") } + +fun DataFetchingEnvironment.typeAsContainer() = this.fieldDefinition.type.inner() as? GraphQLFieldsContainer + ?: throw IllegalStateException("expect type of field ${this.logField()} to be GraphQLFieldsContainer, but was ${this.fieldDefinition.type.name()}") + +fun DataFetchingEnvironment.logField() = "${this.parentType.name()}.${this.fieldDefinition.name}" + +val TypeInt = TypeName("Int") +val TypeFloat = TypeName("Float") +val TypeBoolean = TypeName("Boolean") +val TypeID = TypeName("ID") diff --git a/core/src/main/kotlin/org/neo4j/graphql/Predicates.kt b/core/src/main/kotlin/org/neo4j/graphql/Predicates.kt index d6d3cf8d..c0bff5e2 100644 --- a/core/src/main/kotlin/org/neo4j/graphql/Predicates.kt +++ b/core/src/main/kotlin/org/neo4j/graphql/Predicates.kt @@ -1,11 +1,13 @@ package org.neo4j.graphql -import graphql.Scalars -import graphql.language.NullValue -import graphql.language.ObjectValue -import graphql.language.Value -import graphql.schema.* -import org.neo4j.cypherdsl.core.* +import graphql.language.* +import graphql.language.TypeDefinition +import graphql.schema.GraphQLFieldDefinition +import graphql.schema.GraphQLFieldsContainer +import org.neo4j.cypherdsl.core.Condition +import org.neo4j.cypherdsl.core.Expression +import org.neo4j.cypherdsl.core.Property +import org.neo4j.cypherdsl.core.PropertyContainer import org.slf4j.LoggerFactory typealias CypherDSL = org.neo4j.cypherdsl.core.Cypher @@ -35,13 +37,13 @@ enum class FieldOperator( C("_contains", "CONTAINS", { lhs, rhs -> lhs.contains(rhs) }), SW("_starts_with", "STARTS WITH", { lhs, rhs -> lhs.startsWith(rhs) }), EW("_ends_with", "ENDS WITH", { lhs, rhs -> lhs.endsWith(rhs) }), - MATCHES("_matches", "=~", {lhs, rhs -> lhs.matches(rhs) }), + MATCHES("_matches", "=~", { lhs, rhs -> lhs.matches(rhs) }), DISTANCE(NEO4j_POINT_DISTANCE_FILTER_SUFFIX, "=", { lhs, rhs -> lhs.isEqualTo(rhs) }, distance = true), DISTANCE_LT(NEO4j_POINT_DISTANCE_FILTER_SUFFIX + "_lt", "<", { lhs, rhs -> lhs.lt(rhs) }, distance = true), - DISTANCE_LTE(NEO4j_POINT_DISTANCE_FILTER_SUFFIX + "_lte", "<=", { lhs, rhs -> lhs.lte(rhs) }, distance = true), - DISTANCE_GT(NEO4j_POINT_DISTANCE_FILTER_SUFFIX + "_gt", ">", { lhs, rhs -> lhs.gt(rhs) }, distance = true), + DISTANCE_LTE(NEO4j_POINT_DISTANCE_FILTER_SUFFIX + "_lte", "<=", { lhs, rhs -> lhs.lte(rhs) }, distance = true), + DISTANCE_GT(NEO4j_POINT_DISTANCE_FILTER_SUFFIX + "_gt", ">", { lhs, rhs -> lhs.gt(rhs) }, distance = true), DISTANCE_GTE(NEO4j_POINT_DISTANCE_FILTER_SUFFIX + "_gte", ">=", { lhs, rhs -> lhs.gte(rhs) }, distance = true); val list = op == "IN" @@ -68,14 +70,14 @@ enum class FieldOperator( private fun resolveNeo4jTypeConditions(variablePrefix: String, queriedField: String, propertyContainer: PropertyContainer, field: GraphQLFieldDefinition, value: ObjectValue, suffix: String?): List { val neo4jTypeConverter = getNeo4jTypeConverter(field) val conditions = mutableListOf() - if (distance){ + if (distance) { val parameter = queryParameter(value, variablePrefix, queriedField, suffix) conditions += (neo4jTypeConverter as Neo4jPointConverter).createDistanceCondition( propertyContainer.property(field.propertyName()), parameter, conditionCreator ) - } else { + } else { value.objectFields.forEachIndexed { index, objectField -> val parameter = queryParameter(value, variablePrefix, queriedField, if (value.objectFields.size > 1) "And${index + 1}" else null, suffix, objectField.name) .withValue(objectField.value.toJavaValue()) @@ -108,18 +110,18 @@ enum class FieldOperator( } } - fun forType(type: GraphQLType): List = + fun forType(type: TypeDefinition<*>, isNeo4jType: Boolean): List = when { - type == Scalars.GraphQLBoolean -> listOf(EQ, NEQ) - type.innerName() == NEO4j_POINT_DISTANCE_FILTER -> listOf(EQ, LT, LTE, GT, GTE) + type.name == TypeBoolean.name -> listOf(EQ, NEQ) + type.name == NEO4j_POINT_DISTANCE_FILTER -> listOf(EQ, LT, LTE, GT, GTE) type.isNeo4jSpatialType() -> listOf(EQ, NEQ) - type.isNeo4jType() -> listOf(EQ, NEQ, IN, NIN) - type is GraphQLFieldsContainer || type is GraphQLInputObjectType -> throw IllegalArgumentException("This operators are not for relations, use the RelationOperator instead") - type is GraphQLEnumType -> listOf(EQ, NEQ, IN, NIN) + isNeo4jType -> listOf(EQ, NEQ, IN, NIN) + type is ImplementingTypeDefinition<*> -> throw IllegalArgumentException("This operators are not for relations, use the RelationOperator instead") + type is EnumTypeDefinition -> listOf(EQ, NEQ, IN, NIN) // todo list types - !type.isScalar() -> listOf(EQ, NEQ, IN, NIN) + type !is ScalarTypeDefinition -> listOf(EQ, NEQ, IN, NIN) else -> listOf(EQ, NEQ, IN, NIN, LT, LTE, GT, GTE) + - if (type.name() == "String" || type.name() == "ID") listOf(C, NC, SW, NSW, EW, NEW, MATCHES) else emptyList() + if (type.name == "String" || type.name == "ID") listOf(C, NC, SW, NSW, EW, NEW, MATCHES) else emptyList() } } @@ -183,11 +185,11 @@ enum class RelationOperator(val suffix: String, val op: String) { companion object { private val LOGGER = LoggerFactory.getLogger(RelationOperator::class.java) - fun createRelationFilterFields(type: GraphQLFieldsContainer, field: GraphQLFieldDefinition, filterType: String, builder: GraphQLInputObjectType.Builder) { + fun createRelationFilterFields(type: TypeDefinition<*>, field: FieldDefinition, filterType: String, builder: InputObjectTypeDefinition.Builder) { val list = field.type.isList() val addFilterField = { op: RelationOperator, description: String -> - builder.addFilterField(op.fieldName(field.name), false, filterType, description) + builder.addFilterField(op.fieldName(field.name), false, filterType, description.asDescription()) } addFilterField(EQ_OR_NOT_EXISTS, "Filters only those `${type.name}` for which ${if (list) "all" else "the"} `${field.name}`-relationship matches this filter. " + diff --git a/core/src/main/kotlin/org/neo4j/graphql/RelationDirection.kt b/core/src/main/kotlin/org/neo4j/graphql/RelationDirection.kt deleted file mode 100644 index d6a10edf..00000000 --- a/core/src/main/kotlin/org/neo4j/graphql/RelationDirection.kt +++ /dev/null @@ -1,14 +0,0 @@ -package org.neo4j.graphql - -enum class RelationDirection { - IN, - OUT, - BOTH; - - fun invert(): RelationDirection = when (this) { - IN -> OUT - OUT -> IN - else -> this - } - -} diff --git a/core/src/main/kotlin/org/neo4j/graphql/RelationshipInfo.kt b/core/src/main/kotlin/org/neo4j/graphql/RelationshipInfo.kt new file mode 100644 index 00000000..e597d90c --- /dev/null +++ b/core/src/main/kotlin/org/neo4j/graphql/RelationshipInfo.kt @@ -0,0 +1,80 @@ +package org.neo4j.graphql + +import graphql.language.ImplementingTypeDefinition +import graphql.schema.GraphQLDirective +import graphql.schema.GraphQLDirectiveContainer +import graphql.schema.GraphQLFieldsContainer +import graphql.schema.idl.TypeDefinitionRegistry +import org.neo4j.cypherdsl.core.Node +import org.neo4j.cypherdsl.core.Relationship +import org.neo4j.cypherdsl.core.SymbolicName + +data class RelationshipInfo( + val type: TYPE, + val typeName: String, + val relType: String, + val direction: RelationDirection, + val startField: String, + val endField: String +) { + + enum class RelationDirection { + IN, + OUT, + BOTH; + + fun invert(): RelationDirection = when (this) { + IN -> OUT + OUT -> IN + else -> this + } + + } + + companion object { + fun create(type: GraphQLFieldsContainer): RelationshipInfo? = (type as? GraphQLDirectiveContainer) + ?.getDirective(DirectiveConstants.RELATION) + ?.let { relDirective -> create(type, relDirective) } + + fun create(type: GraphQLFieldsContainer, relDirective: GraphQLDirective): RelationshipInfo { + val relType = relDirective.getArgument(DirectiveConstants.RELATION_NAME, "")!! + val direction = relDirective.getArgument(DirectiveConstants.RELATION_DIRECTION, null) + ?.let { RelationDirection.valueOf(it) } + ?: RelationDirection.OUT + + return RelationshipInfo( + type, + type.name, + relType, + direction, + relDirective.getMandatoryArgument(DirectiveConstants.RELATION_FROM), + relDirective.getMandatoryArgument(DirectiveConstants.RELATION_TO) + ) + } + + fun create(type: ImplementingTypeDefinition<*>, registry: TypeDefinitionRegistry): RelationshipInfo>? { + val relType = type.getDirectiveArgument(registry, DirectiveConstants.RELATION, DirectiveConstants.RELATION_NAME) + ?: return null + val startField = type.getMandatoryDirectiveArgument(registry, DirectiveConstants.RELATION, DirectiveConstants.RELATION_FROM) + val endField = type.getMandatoryDirectiveArgument(registry, DirectiveConstants.RELATION, DirectiveConstants.RELATION_TO) + val direction = type.getDirectiveArgument(registry, DirectiveConstants.RELATION, DirectiveConstants.RELATION_DIRECTION) + ?.let { RelationDirection.valueOf(it) } + ?: RelationDirection.OUT + return RelationshipInfo(type, type.name, relType, direction, startField, endField) + } + } + + fun createRelation(start: Node, end: Node, addType: Boolean = true, variable: SymbolicName? = null): Relationship { + val labels = if (addType) { + arrayOf(this.relType) + } else { + emptyArray() + } + return when (this.direction) { + RelationDirection.IN -> start.relationshipFrom(end, *labels) + RelationDirection.OUT -> start.relationshipTo(end, *labels) + RelationDirection.BOTH -> start.relationshipBetween(end, *labels) + } + .let { if (variable != null) it.named(variable) else it } + } +} diff --git a/core/src/main/kotlin/org/neo4j/graphql/SchemaBuilder.kt b/core/src/main/kotlin/org/neo4j/graphql/SchemaBuilder.kt index 87a48638..073be9bd 100644 --- a/core/src/main/kotlin/org/neo4j/graphql/SchemaBuilder.kt +++ b/core/src/main/kotlin/org/neo4j/graphql/SchemaBuilder.kt @@ -1,6 +1,5 @@ package org.neo4j.graphql -import graphql.Scalars import graphql.language.* import graphql.schema.* import graphql.schema.idl.RuntimeWiring @@ -16,37 +15,119 @@ import org.neo4j.graphql.handler.relation.CreateRelationTypeHandler import org.neo4j.graphql.handler.relation.DeleteRelationHandler /** - * Contains factory methods to generate an augmented graphql schema + * A class for augmenting a type definition registry and generate the corresponding data fetcher. + * There are factory methods, that can be used to simplify augmenting a schema. + * + * + * Generating the schema is done by invoking the following methods: + * 1. [augmentTypes] + * 2. [registerScalars] + * 3. [registerTypeNameResolver] + * 4. [registerDataFetcher] + * + * Each of these steps can be called manually to enhance an existing [TypeDefinitionRegistry] */ -object SchemaBuilder { +class SchemaBuilder( + val typeDefinitionRegistry: TypeDefinitionRegistry, + val schemaConfig: SchemaConfig = SchemaConfig() +) { + + companion object { + /** + * @param sdl the schema to augment + * @param config defines how the schema should get augmented + * @param dataFetchingInterceptor since this library registers dataFetcher for its augmented methods, these data + * fetchers may be called by other resolver. This interceptor will let you convert a cypher query into real data. + */ + @JvmStatic + @JvmOverloads + fun buildSchema(sdl: String, config: SchemaConfig = SchemaConfig(), dataFetchingInterceptor: DataFetchingInterceptor? = null): GraphQLSchema { + val schemaParser = SchemaParser() + val typeDefinitionRegistry = schemaParser.parse(sdl) + return buildSchema(typeDefinitionRegistry, config, dataFetchingInterceptor) + } + + /** + * @param typeDefinitionRegistry a registry containing all the types, that should be augmented + * @param config defines how the schema should get augmented + * @param dataFetchingInterceptor since this library registers dataFetcher for its augmented methods, these data + * fetchers may be called by other resolver. This interceptor will let you convert a cypher query into real data. + */ + @JvmStatic + @JvmOverloads + fun buildSchema(typeDefinitionRegistry: TypeDefinitionRegistry, config: SchemaConfig = SchemaConfig(), dataFetchingInterceptor: DataFetchingInterceptor? = null): GraphQLSchema { + + val builder = RuntimeWiring.newRuntimeWiring() + val codeRegistryBuilder = GraphQLCodeRegistry.newCodeRegistry() + + val schemaBuilder = SchemaBuilder(typeDefinitionRegistry, config) + schemaBuilder.augmentTypes() + schemaBuilder.registerScalars(builder) + schemaBuilder.registerTypeNameResolver(builder) + schemaBuilder.registerDataFetcher(codeRegistryBuilder, dataFetchingInterceptor) + + return SchemaGenerator().makeExecutableSchema( + typeDefinitionRegistry, + builder.codeRegistry(codeRegistryBuilder).build() + ) + } + } + + private val handler: List + private val neo4jTypeDefinitionRegistry: TypeDefinitionRegistry + + init { + neo4jTypeDefinitionRegistry = getNeo4jEnhancements() + ensureRootQueryTypeExists(typeDefinitionRegistry) + handler = mutableListOf( + CypherDirectiveHandler.Factory(schemaConfig, typeDefinitionRegistry, neo4jTypeDefinitionRegistry), + AugmentFieldHandler(schemaConfig, typeDefinitionRegistry, neo4jTypeDefinitionRegistry) + ) + if (schemaConfig.query.enabled) { + handler.add(QueryHandler.Factory(schemaConfig, typeDefinitionRegistry, neo4jTypeDefinitionRegistry)) + } + if (schemaConfig.mutation.enabled) { + handler += listOf( + MergeOrUpdateHandler.Factory(schemaConfig, typeDefinitionRegistry, neo4jTypeDefinitionRegistry), + DeleteHandler.Factory(schemaConfig, typeDefinitionRegistry, neo4jTypeDefinitionRegistry), + CreateTypeHandler.Factory(schemaConfig, typeDefinitionRegistry, neo4jTypeDefinitionRegistry), + DeleteRelationHandler.Factory(schemaConfig, typeDefinitionRegistry, neo4jTypeDefinitionRegistry), + CreateRelationTypeHandler.Factory(schemaConfig, typeDefinitionRegistry, neo4jTypeDefinitionRegistry), + CreateRelationHandler.Factory(schemaConfig, typeDefinitionRegistry, neo4jTypeDefinitionRegistry) + ) + } + } + /** - * @param sdl the schema to augment - * @param config defines how the schema should get augmented - * @param dataFetchingInterceptor since this library registers dataFetcher for its augmented methods, these data - * fetchers may be called by other resolver. This interceptor will let you convert a cypher query into real data. + * Generated additionally query and mutation fields according to the types present in the [typeDefinitionRegistry]. + * This method will also augment relation fields, so filtering and sorting is available for them */ - @JvmStatic - @JvmOverloads - fun buildSchema(sdl: String, config: SchemaConfig = SchemaConfig(), dataFetchingInterceptor: DataFetchingInterceptor? = null): GraphQLSchema { - val schemaParser = SchemaParser() - val typeDefinitionRegistry = schemaParser.parse(sdl) - return buildSchema(typeDefinitionRegistry, config, dataFetchingInterceptor) + fun augmentTypes() { + val queryTypeName = typeDefinitionRegistry.queryTypeName() + val mutationTypeName = typeDefinitionRegistry.mutationTypeName() + val subscriptionTypeName = typeDefinitionRegistry.subscriptionTypeName() + + typeDefinitionRegistry.types().values + .filterIsInstance>() + .filter { it.name != queryTypeName && it.name != mutationTypeName && it.name != subscriptionTypeName } + .forEach { type -> handler.forEach { h -> h.augmentType(type) } } + + // in a second run we enhance all the root fields + typeDefinitionRegistry.types().values + .filterIsInstance>() + .filter { it.name == queryTypeName || it.name == mutationTypeName || it.name == subscriptionTypeName } + .forEach { type -> handler.forEach { h -> h.augmentType(type) } } + + // TODO copy over only the types used in the source schema + typeDefinitionRegistry.merge(neo4jTypeDefinitionRegistry) } /** - * @param typeDefinitionRegistry a registry containing all the types, that should be augmented - * @param config defines how the schema should get augmented - * @param dataFetchingInterceptor since this library registers dataFetcher for its augmented methods, these data - * fetchers may be called by other resolver. This interceptor will let you convert a cypher query into real data. + * Register scalars of this library in the [RuntimeWiring][@param builder] + * @param builder a builder to create a runtime wiring */ - @JvmStatic - @JvmOverloads - fun buildSchema(typeDefinitionRegistry: TypeDefinitionRegistry, config: SchemaConfig = SchemaConfig(), dataFetchingInterceptor: DataFetchingInterceptor? = null): GraphQLSchema { - val enhancedRegistry = typeDefinitionRegistry.merge(getNeo4jEnhancements()) - ensureRootQueryTypeExists(enhancedRegistry) - - val builder = RuntimeWiring.newRuntimeWiring() + fun registerScalars(builder: RuntimeWiring.Builder) { typeDefinitionRegistry.scalars() .filterNot { entry -> GRAPHQL_SPECIFICATION_SCALARS_DEFINITIONS.containsKey(entry.key) } .forEach { (name, definition) -> @@ -62,26 +143,60 @@ object SchemaBuilder { } builder.scalar(scalar) } + } - - enhancedRegistry + /** + * Register type name resolver in the [RuntimeWiring][@param builder] + * @param builder a builder to create a runtime wiring + */ + fun registerTypeNameResolver(builder: RuntimeWiring.Builder) { + typeDefinitionRegistry .getTypes(InterfaceTypeDefinition::class.java) .forEach { typeDefinition -> builder.type(typeDefinition.name) { it.typeResolver { env -> (env.getObject() as? Map) - ?.let { data -> data.get(ProjectionBase.TYPE_NAME) as? String } + ?.let { data -> data[ProjectionBase.TYPE_NAME] as? String } ?.let { typeName -> env.schema.getObjectType(typeName) } } } } - val sourceSchema = SchemaGenerator().makeExecutableSchema(enhancedRegistry, builder.build()) + } - val handler = getHandler(config) + /** + * Register data fetcher in a [GraphQLCodeRegistry][@param codeRegistryBuilder]. + * The data fetcher of this library generate a cypher query and if provided use the dataFetchingInterceptor to run this cypher against a neo4j db. + * @param codeRegistryBuilder a builder to create a code registry + * @param dataFetchingInterceptor a function to convert a cypher string into an object by calling the neo4j db + */ + @JvmOverloads + fun registerDataFetcher( + codeRegistryBuilder: GraphQLCodeRegistry.Builder, + dataFetchingInterceptor: DataFetchingInterceptor?, + typeDefinitionRegistry: TypeDefinitionRegistry = this.typeDefinitionRegistry + ) { + addDataFetcher(typeDefinitionRegistry.queryTypeName(), OperationType.QUERY, dataFetchingInterceptor, codeRegistryBuilder) + addDataFetcher(typeDefinitionRegistry.mutationTypeName(), OperationType.MUTATION, dataFetchingInterceptor, codeRegistryBuilder) + } - var targetSchema = augmentSchema(sourceSchema, handler, config) - targetSchema = addDataFetcher(targetSchema, dataFetchingInterceptor, handler) - return targetSchema + private fun addDataFetcher( + parentType: String, + operationType: OperationType, + dataFetchingInterceptor: DataFetchingInterceptor?, + codeRegistryBuilder: GraphQLCodeRegistry.Builder) { + typeDefinitionRegistry.getType(parentType)?.unwrap() + ?.let { it as? ObjectTypeDefinition } + ?.fieldDefinitions + ?.forEach { field -> + handler.forEach { h -> + h.createDataFetcher(operationType, field)?.let { dataFetcher -> + val interceptedDataFetcher: DataFetcher<*> = dataFetchingInterceptor?.let { + DataFetcher { env -> dataFetchingInterceptor.fetchData(env, dataFetcher) } + } ?: dataFetcher + codeRegistryBuilder.dataFetcher(FieldCoordinates.coordinates(parentType, field.name), interceptedDataFetcher) + } + } + } } private fun ensureRootQueryTypeExists(enhancedRegistry: TypeDefinitionRegistry) { @@ -108,149 +223,9 @@ object SchemaBuilder { }) } - private fun getHandler(schemaConfig: SchemaConfig): List { - val handler = mutableListOf( - CypherDirectiveHandler.Factory(schemaConfig) - ) - if (schemaConfig.query.enabled) { - handler.add(QueryHandler.Factory(schemaConfig)) - } - if (schemaConfig.mutation.enabled) { - handler += listOf( - MergeOrUpdateHandler.Factory(schemaConfig), - DeleteHandler.Factory(schemaConfig), - CreateTypeHandler.Factory(schemaConfig), - DeleteRelationHandler.Factory(schemaConfig), - CreateRelationTypeHandler.Factory(schemaConfig), - CreateRelationHandler.Factory(schemaConfig) - ) - } - return handler - } - - private fun augmentSchema(sourceSchema: GraphQLSchema, handler: List, schemaConfig: SchemaConfig): GraphQLSchema { - val types = sourceSchema.typeMap.toMutableMap() - val env = BuildingEnv(types, sourceSchema, schemaConfig) - val queryTypeName = sourceSchema.queryTypeName() - val mutationTypeName = sourceSchema.mutationTypeName() - val subscriptionTypeName = sourceSchema.subscriptionTypeName() - types.values - .filterIsInstance() - .filter { - !it.name.startsWith("__") - && !it.isNeo4jType() - && it.name != queryTypeName - && it.name != mutationTypeName - && it.name != subscriptionTypeName - } - .forEach { type -> - handler.forEach { h -> h.augmentType(type, env) } - } - - // since new types my be added to `types` we copy the map, to safely modify the entries and later add these - // modified entries back to the `types` - val adjustedTypes = types.toMutableMap() - adjustedTypes.replaceAll { _, sourceType -> - when { - sourceType.name.startsWith("__") -> sourceType - sourceType is GraphQLObjectType -> sourceType.transform { builder -> - builder.clearFields().clearInterfaces() - // to prevent duplicated types in schema - sourceType.interfaces.forEach { builder.withInterface(GraphQLTypeReference(it.name)) } - sourceType.fieldDefinitions.forEach { f -> builder.field(enhanceRelations(f, env)) } - } - sourceType is GraphQLInterfaceType -> sourceType.transform { builder -> - builder.clearFields() - sourceType.fieldDefinitions.forEach { f -> builder.field(enhanceRelations(f, env)) } - } - else -> sourceType - } - } - types.putAll(adjustedTypes) - - return GraphQLSchema - .newSchema(sourceSchema) - .clearAdditionalTypes() - .query(types[queryTypeName] as? GraphQLObjectType) - .mutation(types[mutationTypeName] as? GraphQLObjectType) - .additionalTypes(types.values.toSet()) - .build() - } - - private fun enhanceRelations(fd: GraphQLFieldDefinition, env: BuildingEnv): GraphQLFieldDefinition { - return fd.transform { fieldBuilder -> - // to prevent duplicated types in schema - fieldBuilder.type(fd.type.ref() as GraphQLOutputType) - - if (!fd.isRelationship() || !fd.type.isList()) { - return@transform - } - - val fieldType = fd.type.inner() as? GraphQLFieldsContainer ?: return@transform - - if (env.schemaConfig.queryOptionStyle == SchemaConfig.InputStyle.INPUT_TYPE) { - - val optionsTypeName = env.addOptions(fieldType) - val optionsType = GraphQLTypeReference(optionsTypeName) - fieldBuilder.argument(input(ProjectionBase.OPTIONS, optionsType)) - - } else { - - if (fd.getArgument(ProjectionBase.FIRST) == null) { - fieldBuilder.argument { a -> a.name(ProjectionBase.FIRST).type(Scalars.GraphQLInt) } - } - if (fd.getArgument(ProjectionBase.OFFSET) == null) { - fieldBuilder.argument { a -> a.name(ProjectionBase.OFFSET).type(Scalars.GraphQLInt) } - } - if (fd.getArgument(ProjectionBase.ORDER_BY) == null) { - env.addOrdering(fieldType)?.let { orderingTypeName -> - val orderType = GraphQLList(GraphQLNonNull(GraphQLTypeReference(orderingTypeName))) - fieldBuilder.argument { a -> a.name(ProjectionBase.ORDER_BY).type(orderType) } - - } - } - - } - - val filterFieldName = if (env.schemaConfig.useWhereFilter) ProjectionBase.WHERE else ProjectionBase.FILTER - if (env.schemaConfig.query.enabled && !env.schemaConfig.query.exclude.contains(fieldType.name) && fd.getArgument(filterFieldName) == null) { - val filterTypeName = env.addFilterType(fieldType) - fieldBuilder.argument(input(filterFieldName, GraphQLTypeReference(filterTypeName))) - } - } - } - - private fun addDataFetcher(sourceSchema: GraphQLSchema, dataFetchingInterceptor: DataFetchingInterceptor?, handler: List): GraphQLSchema { - val codeRegistryBuilder = GraphQLCodeRegistry.newCodeRegistry(sourceSchema.codeRegistry) - addDataFetcher(sourceSchema.queryType, OperationType.QUERY, dataFetchingInterceptor, handler, codeRegistryBuilder) - addDataFetcher(sourceSchema.mutationType, OperationType.MUTATION, dataFetchingInterceptor, handler, codeRegistryBuilder) - return sourceSchema.transform { it.codeRegistry(codeRegistryBuilder.build()) } - } - - private fun addDataFetcher( - rootType: GraphQLObjectType?, - operationType: OperationType, - dataFetchingInterceptor: DataFetchingInterceptor?, - handler: List, - codeRegistryBuilder: GraphQLCodeRegistry.Builder) { - if (rootType == null) return - rootType.fieldDefinitions.forEach { field -> - handler.forEach { h -> - h.createDataFetcher(operationType, field)?.let { dataFetcher -> - val df: DataFetcher<*> = dataFetchingInterceptor?.let { - DataFetcher { env -> - dataFetchingInterceptor.fetchData(env, dataFetcher) - } - } ?: dataFetcher - codeRegistryBuilder.dataFetcher(rootType, field, df) - } - } - } - } - private fun getNeo4jEnhancements(): TypeDefinitionRegistry { - val directivesSdl = javaClass.getResource("/neo4j_types.graphql").readText() + - javaClass.getResource("/lib_directives.graphql").readText() + val directivesSdl = javaClass.getResource("/neo4j_types.graphql")?.readText() + + javaClass.getResource("/lib_directives.graphql")?.readText() val typeDefinitionRegistry = SchemaParser().parse(directivesSdl) neo4jTypeDefinitions .forEach { diff --git a/core/src/main/kotlin/org/neo4j/graphql/Translator.kt b/core/src/main/kotlin/org/neo4j/graphql/Translator.kt index c1c442ea..26a57523 100644 --- a/core/src/main/kotlin/org/neo4j/graphql/Translator.kt +++ b/core/src/main/kotlin/org/neo4j/graphql/Translator.kt @@ -1,7 +1,10 @@ package org.neo4j.graphql import graphql.execution.MergedField -import graphql.language.* +import graphql.language.Document +import graphql.language.Field +import graphql.language.FragmentDefinition +import graphql.language.OperationDefinition import graphql.language.OperationDefinition.Operation.MUTATION import graphql.language.OperationDefinition.Operation.QUERY import graphql.parser.Parser @@ -19,7 +22,7 @@ class Translator(val schema: GraphQLSchema) { @Throws(OptimizedQueryException::class) fun translate(query: String, params: Map = emptyMap(), ctx: QueryContext = QueryContext()): List { val ast = parse(query) // todo preparsedDocumentProvider - val fragments = ast.definitions.filterIsInstance().map { it.name to it }.toMap() + val fragments = ast.definitions.filterIsInstance().associateBy { it.name } return ast.definitions.filterIsInstance() .filter { it.operation == QUERY || it.operation == MUTATION } // todo variableDefinitions, directives, name .flatMap { operationDefinition -> @@ -67,6 +70,7 @@ class Translator(val schema: GraphQLSchema) { return dataFetcher.get(newDataFetchingEnvironment() .mergedField(MergedField.newMergedField(field).build()) + .parentType(operationObjectType) .graphQLSchema(schema) .fragmentsByName(fragments) .context(ctx) diff --git a/core/src/main/kotlin/org/neo4j/graphql/handler/AugmentFieldHandler.kt b/core/src/main/kotlin/org/neo4j/graphql/handler/AugmentFieldHandler.kt new file mode 100644 index 00000000..6cfa77ba --- /dev/null +++ b/core/src/main/kotlin/org/neo4j/graphql/handler/AugmentFieldHandler.kt @@ -0,0 +1,71 @@ +package org.neo4j.graphql.handler + +import graphql.language.* +import graphql.schema.DataFetcher +import graphql.schema.idl.TypeDefinitionRegistry +import org.neo4j.graphql.* +import org.neo4j.graphql.handler.projection.ProjectionBase + +/** + * This class augments existing fields on a type and adds filtering and sorting to these fields. + */ +class AugmentFieldHandler( + schemaConfig: SchemaConfig, + typeDefinitionRegistry: TypeDefinitionRegistry, + neo4jTypeDefinitionRegistry: TypeDefinitionRegistry +) : AugmentationHandler(schemaConfig, typeDefinitionRegistry, neo4jTypeDefinitionRegistry) { + + override fun augmentType(type: ImplementingTypeDefinition<*>) { + val enhanceRelations = { type.fieldDefinitions.map { fieldDef -> fieldDef.transform { augmentRelation(it, fieldDef) } } } + + val rewritten = when (type) { + is ObjectTypeDefinition -> type.transform { it.fieldDefinitions(enhanceRelations()) } + is InterfaceTypeDefinition -> type.transform { it.definitions(enhanceRelations()) } + else -> return + } + + typeDefinitionRegistry.remove(rewritten.name, rewritten) + typeDefinitionRegistry.add(rewritten) + } + + private fun augmentRelation(fieldBuilder: FieldDefinition.Builder, field: FieldDefinition) { + if (!field.isRelationship() || !field.type.isList()) { + return + } + + val fieldType = field.type.inner().resolve() as? ImplementingTypeDefinition<*> ?: return + + if (schemaConfig.queryOptionStyle == SchemaConfig.InputStyle.INPUT_TYPE) { + + val optionsTypeName = addOptions(fieldType) + if (field.inputValueDefinitions.find { it.name == ProjectionBase.OPTIONS } == null) { + fieldBuilder.inputValueDefinition(input(ProjectionBase.OPTIONS, TypeName(optionsTypeName))) + } + + } else { + + if (field.inputValueDefinitions.find { it.name == ProjectionBase.FIRST } == null) { + fieldBuilder.inputValueDefinition(input(ProjectionBase.FIRST, TypeInt)) + } + if (field.inputValueDefinitions.find { it.name == ProjectionBase.OFFSET } == null) { + fieldBuilder.inputValueDefinition(input(ProjectionBase.OFFSET, TypeInt)) + } + if (field.inputValueDefinitions.find { it.name == ProjectionBase.ORDER_BY } == null) { + addOrdering(fieldType)?.let { orderingTypeName -> + val orderType = ListType(NonNullType(TypeName(orderingTypeName))) + fieldBuilder.inputValueDefinition(input(ProjectionBase.ORDER_BY, orderType)) + } + } + } + + val filterFieldName = if (schemaConfig.useWhereFilter) ProjectionBase.WHERE else ProjectionBase.FILTER + if (schemaConfig.query.enabled && !schemaConfig.query.exclude.contains(fieldType.name) && field.inputValueDefinitions.find { it.name == filterFieldName } == null) { + val filterTypeName = addFilterType(fieldType) + fieldBuilder.inputValueDefinition(input(filterFieldName, TypeName(filterTypeName))) + } + + } + + override fun createDataFetcher(operationType: OperationType, fieldDefinition: FieldDefinition): DataFetcher? = null +} + diff --git a/core/src/main/kotlin/org/neo4j/graphql/handler/BaseDataFetcher.kt b/core/src/main/kotlin/org/neo4j/graphql/handler/BaseDataFetcher.kt index 8a8715bb..58af868d 100644 --- a/core/src/main/kotlin/org/neo4j/graphql/handler/BaseDataFetcher.kt +++ b/core/src/main/kotlin/org/neo4j/graphql/handler/BaseDataFetcher.kt @@ -5,6 +5,7 @@ import graphql.language.VariableReference import graphql.schema.DataFetcher import graphql.schema.DataFetchingEnvironment import graphql.schema.GraphQLFieldDefinition +import graphql.schema.GraphQLType import org.neo4j.cypherdsl.core.Statement import org.neo4j.cypherdsl.core.renderer.Configuration import org.neo4j.cypherdsl.core.renderer.Renderer @@ -16,13 +17,15 @@ import org.neo4j.graphql.handler.projection.ProjectionBase /** * The is a base class for the implementation of graphql data fetcher used in this project */ -abstract class BaseDataFetcher(val fieldDefinition: GraphQLFieldDefinition, schemaConfig: SchemaConfig) : ProjectionBase(schemaConfig), DataFetcher { +abstract class BaseDataFetcher(schemaConfig: SchemaConfig) : ProjectionBase(schemaConfig), DataFetcher { - override fun get(env: DataFetchingEnvironment?): Cypher { - val field = env?.mergedField?.singleField + private var init = false + + override fun get(env: DataFetchingEnvironment): Cypher { + val field = env.mergedField?.singleField ?: throw IllegalAccessException("expect one filed in environment.mergedField") - require(field.name == fieldDefinition.name) { "Handler for ${fieldDefinition.name} cannot handle ${field.name}" } val variable = field.aliasOrName().decapitalize() + prepareDataFetcher(env.fieldDefinition, env.parentType) val statement = generateCypher(variable, field, env) val query = Renderer.getRenderer(Configuration @@ -36,7 +39,21 @@ abstract class BaseDataFetcher(val fieldDefinition: GraphQLFieldDefinition, sche (value as? VariableReference)?.let { env.variables[it.name] } ?: value } - return Cypher(query, params, fieldDefinition.type, variable = field.aliasOrName()) + return Cypher(query, params, env.fieldDefinition.type, variable = field.aliasOrName()) + } + + /** + * called after the schema is generated but before the 1st call + */ + private fun prepareDataFetcher(fieldDefinition: GraphQLFieldDefinition, parentType: GraphQLType) { + if (init) { + return + } + init = true + initDataFetcher(fieldDefinition, parentType) + } + + protected open fun initDataFetcher(fieldDefinition: GraphQLFieldDefinition, parentType: GraphQLType) { } protected abstract fun generateCypher(variable: String, field: Field, env: DataFetchingEnvironment): Statement diff --git a/core/src/main/kotlin/org/neo4j/graphql/handler/BaseDataFetcherForContainer.kt b/core/src/main/kotlin/org/neo4j/graphql/handler/BaseDataFetcherForContainer.kt index c406d4b5..3d0406f1 100644 --- a/core/src/main/kotlin/org/neo4j/graphql/handler/BaseDataFetcherForContainer.kt +++ b/core/src/main/kotlin/org/neo4j/graphql/handler/BaseDataFetcherForContainer.kt @@ -5,6 +5,7 @@ import graphql.language.ArrayValue import graphql.language.ObjectValue import graphql.schema.GraphQLFieldDefinition import graphql.schema.GraphQLFieldsContainer +import graphql.schema.GraphQLType import org.neo4j.cypherdsl.core.* import org.neo4j.cypherdsl.core.Cypher.* import org.neo4j.graphql.* @@ -12,16 +13,15 @@ import org.neo4j.graphql.* /** * This is a base class for all Node or Relation related data fetcher. */ -abstract class BaseDataFetcherForContainer( - val type: GraphQLFieldsContainer, - fieldDefinition: GraphQLFieldDefinition, - schemaConfig: SchemaConfig -) : BaseDataFetcher(fieldDefinition, schemaConfig) { +abstract class BaseDataFetcherForContainer(schemaConfig: SchemaConfig) : BaseDataFetcher(schemaConfig) { + lateinit var type: GraphQLFieldsContainer val propertyFields: MutableMap List?> = mutableMapOf() val defaultFields: MutableMap = mutableMapOf() - init { + override fun initDataFetcher(fieldDefinition: GraphQLFieldDefinition, parentType: GraphQLType) { + type = fieldDefinition.type.inner() as? GraphQLFieldsContainer + ?: throw IllegalStateException("expect type of field ${parentType.name()}.${fieldDefinition.name} to be GraphQLFieldsContainer, but was ${fieldDefinition.type.name()}") fieldDefinition .arguments .filterNot { listOf(FIRST, OFFSET, ORDER_BY, NATIVE_ID, OPTIONS).contains(it.name) } diff --git a/core/src/main/kotlin/org/neo4j/graphql/handler/CreateTypeHandler.kt b/core/src/main/kotlin/org/neo4j/graphql/handler/CreateTypeHandler.kt index cbc8ca32..55f8d474 100644 --- a/core/src/main/kotlin/org/neo4j/graphql/handler/CreateTypeHandler.kt +++ b/core/src/main/kotlin/org/neo4j/graphql/handler/CreateTypeHandler.kt @@ -1,7 +1,13 @@ package org.neo4j.graphql.handler import graphql.language.Field -import graphql.schema.* +import graphql.language.FieldDefinition +import graphql.language.ImplementingTypeDefinition +import graphql.language.InterfaceTypeDefinition +import graphql.schema.DataFetcher +import graphql.schema.DataFetchingEnvironment +import graphql.schema.GraphQLObjectType +import graphql.schema.idl.TypeDefinitionRegistry import org.neo4j.cypherdsl.core.Statement import org.neo4j.cypherdsl.core.StatementBuilder import org.neo4j.graphql.* @@ -10,58 +16,56 @@ import org.neo4j.graphql.* * This class handles all the logic related to the creation of nodes. * This includes the augmentation of the create<Node>-mutator and the related cypher generation */ -class CreateTypeHandler private constructor( - type: GraphQLFieldsContainer, - fieldDefinition: GraphQLFieldDefinition, - schemaConfig: SchemaConfig -) : BaseDataFetcherForContainer(type, fieldDefinition, schemaConfig) { +class CreateTypeHandler private constructor(schemaConfig: SchemaConfig) : BaseDataFetcherForContainer(schemaConfig) { - class Factory(schemaConfig: SchemaConfig) : AugmentationHandler(schemaConfig) { - override fun augmentType(type: GraphQLFieldsContainer, buildingEnv: BuildingEnv) { + class Factory(schemaConfig: SchemaConfig, + typeDefinitionRegistry: TypeDefinitionRegistry, + neo4jTypeDefinitionRegistry: TypeDefinitionRegistry + ) : AugmentationHandler(schemaConfig, typeDefinitionRegistry, neo4jTypeDefinitionRegistry) { + + override fun augmentType(type: ImplementingTypeDefinition<*>) { if (!canHandle(type)) { return } val relevantFields = getRelevantFields(type) - val fieldDefinition = buildingEnv - .buildFieldDefinition("create", type, relevantFields, nullableResult = false) + val fieldDefinition = buildFieldDefinition("create", type, relevantFields, nullableResult = false) .build() - buildingEnv.addMutationField(fieldDefinition) + addMutationField(fieldDefinition) } - override fun createDataFetcher(operationType: OperationType, fieldDefinition: GraphQLFieldDefinition): DataFetcher? { + override fun createDataFetcher(operationType: OperationType, fieldDefinition: FieldDefinition): DataFetcher? { if (operationType != OperationType.MUTATION) { return null } if (fieldDefinition.cypherDirective() != null) { return null } - val type = fieldDefinition.type.inner() as? GraphQLObjectType - ?: return null + val type = fieldDefinition.type.inner().resolve() as? ImplementingTypeDefinition<*> ?: return null if (!canHandle(type)) { return null } - return when { - fieldDefinition.name == "create${type.name}" -> CreateTypeHandler(type, fieldDefinition, schemaConfig) + return when (fieldDefinition.name) { + "create${type.name}" -> CreateTypeHandler(schemaConfig) else -> null } } - private fun getRelevantFields(type: GraphQLFieldsContainer): List { + private fun getRelevantFields(type: ImplementingTypeDefinition<*>): List { return type - .relevantFields() + .getScalarFields() .filter { !it.isNativeId() } } - private fun canHandle(type: GraphQLFieldsContainer): Boolean { + private fun canHandle(type: ImplementingTypeDefinition<*>): Boolean { val typeName = type.name - if (!schemaConfig.mutation.enabled || schemaConfig.mutation.exclude.contains(typeName)) { + if (!schemaConfig.mutation.enabled || schemaConfig.mutation.exclude.contains(typeName) || isRootType(type)) { return false } - if (type !is GraphQLObjectType) { + if (type is InterfaceTypeDefinition) { return false } - if ((type as GraphQLDirectiveContainer).isRelationType()) { + if (type.relationship() != null) { // relations are handled by the CreateRelationTypeHandler return false } diff --git a/core/src/main/kotlin/org/neo4j/graphql/handler/CypherDirectiveHandler.kt b/core/src/main/kotlin/org/neo4j/graphql/handler/CypherDirectiveHandler.kt index 1b72e7fe..923dc66b 100644 --- a/core/src/main/kotlin/org/neo4j/graphql/handler/CypherDirectiveHandler.kt +++ b/core/src/main/kotlin/org/neo4j/graphql/handler/CypherDirectiveHandler.kt @@ -1,10 +1,11 @@ package org.neo4j.graphql.handler import graphql.language.Field +import graphql.language.FieldDefinition import graphql.schema.DataFetcher import graphql.schema.DataFetchingEnvironment -import graphql.schema.GraphQLFieldDefinition import graphql.schema.GraphQLFieldsContainer +import graphql.schema.idl.TypeDefinitionRegistry import org.neo4j.cypherdsl.core.Functions import org.neo4j.cypherdsl.core.Statement import org.neo4j.graphql.* @@ -12,25 +13,25 @@ import org.neo4j.graphql.* /** * This class handles all logic related to custom Cypher queries declared by fields with a @cypher directive */ -class CypherDirectiveHandler( - private val type: GraphQLFieldsContainer?, - private val isQuery: Boolean, - private val cypherDirective: CypherDirective, - fieldDefinition: GraphQLFieldDefinition, - schemaConfig: SchemaConfig) - : BaseDataFetcher(fieldDefinition, schemaConfig) { - - class Factory(schemaConfig: SchemaConfig) : AugmentationHandler(schemaConfig) { - - override fun createDataFetcher(operationType: OperationType, fieldDefinition: GraphQLFieldDefinition): DataFetcher? { - val cypherDirective = fieldDefinition.cypherDirective() ?: return null - val type = fieldDefinition.type.inner() as? GraphQLFieldsContainer +class CypherDirectiveHandler(private val isQuery: Boolean, schemaConfig: SchemaConfig) : BaseDataFetcher(schemaConfig) { + + class Factory(schemaConfig: SchemaConfig, + typeDefinitionRegistry: TypeDefinitionRegistry, + neo4jTypeDefinitionRegistry: TypeDefinitionRegistry + ) : AugmentationHandler(schemaConfig, typeDefinitionRegistry, neo4jTypeDefinitionRegistry) { + + override fun createDataFetcher(operationType: OperationType, fieldDefinition: FieldDefinition): DataFetcher? { + fieldDefinition.cypherDirective() ?: return null val isQuery = operationType == OperationType.QUERY - return CypherDirectiveHandler(type, isQuery, cypherDirective, fieldDefinition, schemaConfig) + return CypherDirectiveHandler(isQuery, schemaConfig) } } override fun generateCypher(variable: String, field: Field, env: DataFetchingEnvironment): Statement { + val fieldDefinition = env.fieldDefinition + val type = fieldDefinition.type.inner() as? GraphQLFieldsContainer + val cypherDirective = fieldDefinition.cypherDirective() + ?: throw IllegalStateException("Expect field ${env.logField()} to have @cypher directive present") val query = if (isQuery) { val nestedQuery = cypherDirective(variable, fieldDefinition, field, cypherDirective) diff --git a/core/src/main/kotlin/org/neo4j/graphql/handler/DeleteHandler.kt b/core/src/main/kotlin/org/neo4j/graphql/handler/DeleteHandler.kt index 02e78182..8761235e 100644 --- a/core/src/main/kotlin/org/neo4j/graphql/handler/DeleteHandler.kt +++ b/core/src/main/kotlin/org/neo4j/graphql/handler/DeleteHandler.kt @@ -1,7 +1,14 @@ package org.neo4j.graphql.handler import graphql.language.Field -import graphql.schema.* +import graphql.language.FieldDefinition +import graphql.language.ImplementingTypeDefinition +import graphql.language.TypeName +import graphql.schema.DataFetcher +import graphql.schema.DataFetchingEnvironment +import graphql.schema.GraphQLFieldDefinition +import graphql.schema.GraphQLType +import graphql.schema.idl.TypeDefinitionRegistry import org.neo4j.cypherdsl.core.Node import org.neo4j.cypherdsl.core.Relationship import org.neo4j.cypherdsl.core.Statement @@ -12,57 +19,62 @@ import org.neo4j.graphql.* * This class handles all the logic related to the deletion of nodes. * This includes the augmentation of the delete<Node>-mutator and the related cypher generation */ -class DeleteHandler private constructor( - type: GraphQLFieldsContainer, - private val idField: GraphQLFieldDefinition, - fieldDefinition: GraphQLFieldDefinition, - schemaConfig: SchemaConfig, - private val isRelation: Boolean = type.isRelationType() -) : BaseDataFetcherForContainer(type, fieldDefinition, schemaConfig) { +class DeleteHandler private constructor(schemaConfig: SchemaConfig) : BaseDataFetcherForContainer(schemaConfig) { - class Factory(schemaConfig: SchemaConfig) : AugmentationHandler(schemaConfig) { - override fun augmentType(type: GraphQLFieldsContainer, buildingEnv: BuildingEnv) { + private lateinit var idField: GraphQLFieldDefinition + private var isRelation: Boolean = false + + class Factory(schemaConfig: SchemaConfig, + typeDefinitionRegistry: TypeDefinitionRegistry, + neo4jTypeDefinitionRegistry: TypeDefinitionRegistry + ) : AugmentationHandler(schemaConfig, typeDefinitionRegistry, neo4jTypeDefinitionRegistry) { + + override fun augmentType(type: ImplementingTypeDefinition<*>) { if (!canHandle(type)) { return } val idField = type.getIdField() ?: return - val fieldDefinition = buildingEnv - .buildFieldDefinition("delete", type, listOf(idField), nullableResult = true) - .description("Deletes ${type.name} and returns the type itself") - .type(type.ref() as GraphQLOutputType) + val fieldDefinition = buildFieldDefinition("delete", type, listOf(idField), nullableResult = true) + .description("Deletes ${type.name} and returns the type itself".asDescription()) + .type(TypeName(type.name)) .build() - buildingEnv.addMutationField(fieldDefinition) + addMutationField(fieldDefinition) } - override fun createDataFetcher(operationType: OperationType, fieldDefinition: GraphQLFieldDefinition): DataFetcher? { + override fun createDataFetcher(operationType: OperationType, fieldDefinition: FieldDefinition): DataFetcher? { if (operationType != OperationType.MUTATION) { return null } if (fieldDefinition.cypherDirective() != null) { return null } - val type = fieldDefinition.type as? GraphQLFieldsContainer - ?: return null + val type = fieldDefinition.type.inner().resolve() as? ImplementingTypeDefinition<*> ?: return null if (!canHandle(type)) { return null } - val idField = type.getIdField() ?: return null + type.getIdField() ?: return null return when (fieldDefinition.name) { - "delete${type.name}" -> DeleteHandler(type, idField, fieldDefinition, schemaConfig) + "delete${type.name}" -> DeleteHandler(schemaConfig) else -> null } } - private fun canHandle(type: GraphQLFieldsContainer): Boolean { + private fun canHandle(type: ImplementingTypeDefinition<*>): Boolean { val typeName = type.name - if (!schemaConfig.mutation.enabled || schemaConfig.mutation.exclude.contains(typeName)) { + if (!schemaConfig.mutation.enabled || schemaConfig.mutation.exclude.contains(typeName) || isRootType(type)) { return false } return type.getIdField() != null } } + override fun initDataFetcher(fieldDefinition: GraphQLFieldDefinition, parentType: GraphQLType) { + super.initDataFetcher(fieldDefinition, parentType) + idField = type.getIdField() ?: throw IllegalStateException("Cannot resolve id field for type ${type.name}") + isRelation = type.isRelationType() + } + override fun generateCypher(variable: String, field: Field, env: DataFetchingEnvironment): Statement { val idArg = field.arguments.first { it.name == idField.name } diff --git a/core/src/main/kotlin/org/neo4j/graphql/handler/MergeOrUpdateHandler.kt b/core/src/main/kotlin/org/neo4j/graphql/handler/MergeOrUpdateHandler.kt index c9496fad..6ceba797 100644 --- a/core/src/main/kotlin/org/neo4j/graphql/handler/MergeOrUpdateHandler.kt +++ b/core/src/main/kotlin/org/neo4j/graphql/handler/MergeOrUpdateHandler.kt @@ -1,10 +1,13 @@ package org.neo4j.graphql.handler import graphql.language.Field +import graphql.language.FieldDefinition +import graphql.language.ImplementingTypeDefinition import graphql.schema.DataFetcher import graphql.schema.DataFetchingEnvironment import graphql.schema.GraphQLFieldDefinition -import graphql.schema.GraphQLFieldsContainer +import graphql.schema.GraphQLType +import graphql.schema.idl.TypeDefinitionRegistry import org.neo4j.cypherdsl.core.Node import org.neo4j.cypherdsl.core.Relationship import org.neo4j.cypherdsl.core.Statement @@ -15,62 +18,59 @@ import org.neo4j.graphql.* * This class handles all the logic related to the updating of nodes. * This includes the augmentation of the update<Node> and merge<Node>-mutator and the related cypher generation */ -class MergeOrUpdateHandler private constructor( - type: GraphQLFieldsContainer, - private val merge: Boolean, - private val idField: GraphQLFieldDefinition, - fieldDefinition: GraphQLFieldDefinition, - schemaConfig: SchemaConfig, - private val isRelation: Boolean = type.isRelationType() -) : BaseDataFetcherForContainer(type, fieldDefinition, schemaConfig) { - - class Factory(schemaConfig: SchemaConfig) : AugmentationHandler(schemaConfig) { - override fun augmentType(type: GraphQLFieldsContainer, buildingEnv: BuildingEnv) { +class MergeOrUpdateHandler private constructor(private val merge: Boolean, schemaConfig: SchemaConfig) : BaseDataFetcherForContainer(schemaConfig) { + + private lateinit var idField: GraphQLFieldDefinition + private var isRelation: Boolean = false + + class Factory(schemaConfig: SchemaConfig, + typeDefinitionRegistry: TypeDefinitionRegistry, + neo4jTypeDefinitionRegistry: TypeDefinitionRegistry + ) : AugmentationHandler(schemaConfig, typeDefinitionRegistry, neo4jTypeDefinitionRegistry) { + + override fun augmentType(type: ImplementingTypeDefinition<*>) { if (!canHandle(type)) { return } - val relevantFields = type.relevantFields() - val mergeField = buildingEnv - .buildFieldDefinition("merge", type, relevantFields, nullableResult = false) + val relevantFields = type.getScalarFields() + val mergeField = buildFieldDefinition("merge", type, relevantFields, nullableResult = false) .build() - buildingEnv.addMutationField(mergeField) + addMutationField(mergeField) - val updateField = buildingEnv - .buildFieldDefinition("update", type, relevantFields, nullableResult = true) + val updateField = buildFieldDefinition("update", type, relevantFields, nullableResult = true) .build() - buildingEnv.addMutationField(updateField) + addMutationField(updateField) } - override fun createDataFetcher(operationType: OperationType, fieldDefinition: GraphQLFieldDefinition): DataFetcher? { + override fun createDataFetcher(operationType: OperationType, fieldDefinition: FieldDefinition): DataFetcher? { if (operationType != OperationType.MUTATION) { return null } if (fieldDefinition.cypherDirective() != null) { return null } - val type = fieldDefinition.type.inner() as? GraphQLFieldsContainer - ?: return null + val type = fieldDefinition.type.inner().resolve() as? ImplementingTypeDefinition<*> ?: return null if (!canHandle(type)) { return null } - val idField = type.getIdField() ?: return null + type.getIdField() ?: return null return when (fieldDefinition.name) { - "merge${type.name}" -> MergeOrUpdateHandler(type, true, idField, fieldDefinition, schemaConfig) - "update${type.name}" -> MergeOrUpdateHandler(type, false, idField, fieldDefinition, schemaConfig) + "merge${type.name}" -> MergeOrUpdateHandler(true, schemaConfig) + "update${type.name}" -> MergeOrUpdateHandler(false, schemaConfig) else -> null } } - private fun canHandle(type: GraphQLFieldsContainer): Boolean { + private fun canHandle(type: ImplementingTypeDefinition<*>): Boolean { val typeName = type.name - if (!schemaConfig.mutation.enabled || schemaConfig.mutation.exclude.contains(typeName)) { + if (!schemaConfig.mutation.enabled || schemaConfig.mutation.exclude.contains(typeName) || isRootType(type)) { return false } if (type.getIdField() == null) { return false } - if (type.relevantFields().none { !it.isID() }) { + if (type.getScalarFields().none { !it.type.inner().isID() }) { // nothing to update (except ID) return false } @@ -78,9 +78,15 @@ class MergeOrUpdateHandler private constructor( } } - init { + override fun initDataFetcher(fieldDefinition: GraphQLFieldDefinition, parentType: GraphQLType) { + super.initDataFetcher(fieldDefinition, parentType) + + idField = type.getIdField() ?: throw IllegalStateException("Cannot resolve id field for type ${type.name}") + isRelation = type.isRelationType() + defaultFields.clear() // for merge or updates we do not reset to defaults propertyFields.remove(idField.name) // id should not be updated + } override fun generateCypher(variable: String, field: Field, env: DataFetchingEnvironment): Statement { diff --git a/core/src/main/kotlin/org/neo4j/graphql/handler/QueryHandler.kt b/core/src/main/kotlin/org/neo4j/graphql/handler/QueryHandler.kt index c0cf25bc..c8800942 100644 --- a/core/src/main/kotlin/org/neo4j/graphql/handler/QueryHandler.kt +++ b/core/src/main/kotlin/org/neo4j/graphql/handler/QueryHandler.kt @@ -1,8 +1,9 @@ package org.neo4j.graphql.handler -import graphql.Scalars -import graphql.language.Field -import graphql.schema.* +import graphql.language.* +import graphql.schema.DataFetcher +import graphql.schema.DataFetchingEnvironment +import graphql.schema.idl.TypeDefinitionRegistry import org.atteo.evo.inflector.English import org.neo4j.cypherdsl.core.Cypher.* import org.neo4j.cypherdsl.core.Statement @@ -13,58 +14,56 @@ import org.neo4j.graphql.handler.filter.OptimizedFilterHandler * This class handles all the logic related to the querying of nodes and relations. * This includes the augmentation of the query-fields and the related cypher generation */ -class QueryHandler private constructor( - type: GraphQLFieldsContainer, - fieldDefinition: GraphQLFieldDefinition, - schemaConfig: SchemaConfig -) : BaseDataFetcherForContainer(type, fieldDefinition, schemaConfig) { - - class Factory(schemaConfig: SchemaConfig) : AugmentationHandler(schemaConfig) { - override fun augmentType(type: GraphQLFieldsContainer, buildingEnv: BuildingEnv) { +class QueryHandler private constructor(schemaConfig: SchemaConfig) : BaseDataFetcherForContainer(schemaConfig) { + + class Factory(schemaConfig: SchemaConfig, + typeDefinitionRegistry: TypeDefinitionRegistry, + neo4jTypeDefinitionRegistry: TypeDefinitionRegistry + ) : AugmentationHandler(schemaConfig, typeDefinitionRegistry, neo4jTypeDefinitionRegistry) { + + override fun augmentType(type: ImplementingTypeDefinition<*>) { if (!canHandle(type)) { return } val typeName = type.name val relevantFields = getRelevantFields(type) - val filterTypeName = buildingEnv.addFilterType(type) + val filterTypeName = addFilterType(type) val arguments = if (schemaConfig.useWhereFilter) { - listOf(input(WHERE, GraphQLTypeReference(filterTypeName))) + listOf(input(WHERE, TypeName(filterTypeName))) } else { - buildingEnv.getInputValueDefinitions(relevantFields, { true }) + - input(FILTER, GraphQLTypeReference(filterTypeName)) + getInputValueDefinitions(relevantFields, { true }) + + input(FILTER, TypeName(filterTypeName)) } var fieldName = if (schemaConfig.capitalizeQueryFields) typeName else typeName.decapitalize() if (schemaConfig.pluralizeFields) { fieldName = English.plural(fieldName) } - val builder = GraphQLFieldDefinition + val builder = FieldDefinition .newFieldDefinition() .name(fieldName) - .arguments(arguments) - .type(GraphQLNonNull(GraphQLList(GraphQLNonNull(GraphQLTypeReference(type.name))))) + .inputValueDefinitions(arguments.toMutableList()) + .type(NonNullType(ListType(NonNullType(TypeName(type.name))))) if (schemaConfig.queryOptionStyle == SchemaConfig.InputStyle.INPUT_TYPE) { - val optionsTypeName = buildingEnv.addOptions(type) - val optionsType = GraphQLTypeReference(optionsTypeName) - builder.argument(input(OPTIONS, optionsType)) + val optionsTypeName = addOptions(type) + builder.inputValueDefinition(input(OPTIONS, TypeName(optionsTypeName))) } else { builder - .argument(input(FIRST, Scalars.GraphQLInt)) - .argument(input(OFFSET, Scalars.GraphQLInt)) + .inputValueDefinition(input(FIRST, TypeInt)) + .inputValueDefinition(input(OFFSET, TypeInt)) - val orderingTypeName = buildingEnv.addOrdering(type) + val orderingTypeName = addOrdering(type) if (orderingTypeName != null) { - val orderType = GraphQLList(GraphQLNonNull(GraphQLTypeReference(orderingTypeName))) - builder.argument(input(ORDER_BY, orderType)) + builder.inputValueDefinition(input(ORDER_BY, ListType(NonNullType(TypeName(orderingTypeName))))) } } val def = builder.build() - buildingEnv.addQueryField(def) + addQueryField(def) } - override fun createDataFetcher(operationType: OperationType, fieldDefinition: GraphQLFieldDefinition): DataFetcher? { + override fun createDataFetcher(operationType: OperationType, fieldDefinition: FieldDefinition): DataFetcher? { if (operationType != OperationType.QUERY) { return null } @@ -72,17 +71,16 @@ class QueryHandler private constructor( if (cypherDirective != null) { return null } - val type = fieldDefinition.type.inner() as? GraphQLFieldsContainer - ?: return null + val type = fieldDefinition.type.inner().resolve() as? ImplementingTypeDefinition<*> ?: return null if (!canHandle(type)) { return null } - return QueryHandler(type, fieldDefinition, schemaConfig) + return QueryHandler(schemaConfig) } - private fun canHandle(type: GraphQLFieldsContainer): Boolean { - val typeName = type.innerName() - if (!schemaConfig.query.enabled || schemaConfig.query.exclude.contains(typeName)) { + private fun canHandle(type: ImplementingTypeDefinition<*>): Boolean { + val typeName = type.name + if (!schemaConfig.query.enabled || schemaConfig.query.exclude.contains(typeName) || isRootType(type)) { return false } if (getRelevantFields(type).isEmpty() && !hasRelationships(type)) { @@ -91,16 +89,18 @@ class QueryHandler private constructor( return true } - private fun hasRelationships(type: GraphQLFieldsContainer): Boolean = type.fieldDefinitions.any { it.isRelationship() } + private fun hasRelationships(type: ImplementingTypeDefinition<*>): Boolean = type.fieldDefinitions.any { it.isRelationship() } - private fun getRelevantFields(type: GraphQLFieldsContainer): List { + private fun getRelevantFields(type: ImplementingTypeDefinition<*>): List { return type - .relevantFields() + .getScalarFields() .filter { it.dynamicPrefix() == null } // TODO currently we do not support filtering on dynamic properties } } override fun generateCypher(variable: String, field: Field, env: DataFetchingEnvironment): Statement { + val fieldDefinition = env.fieldDefinition + val type = env.typeAsContainer() val (propertyContainer, match) = when { type.isRelationType() -> anyNode().relationshipTo(anyNode(), type.label()).named(variable) diff --git a/core/src/main/kotlin/org/neo4j/graphql/handler/projection/ProjectionBase.kt b/core/src/main/kotlin/org/neo4j/graphql/handler/projection/ProjectionBase.kt index 10cc73d1..9ee44e0e 100644 --- a/core/src/main/kotlin/org/neo4j/graphql/handler/projection/ProjectionBase.kt +++ b/core/src/main/kotlin/org/neo4j/graphql/handler/projection/ProjectionBase.kt @@ -373,7 +373,12 @@ open class ProjectionBase( return skipLimit.slice(fieldType.isList(), comprehension) } - private fun relationshipInfoInCorrectDirection(fieldObjectType: GraphQLFieldsContainer, relInfo0: RelationshipInfo, parent: GraphQLFieldsContainer, relDirectiveField: RelationshipInfo?): RelationshipInfo { + private fun relationshipInfoInCorrectDirection( + fieldObjectType: GraphQLFieldsContainer, + relInfo0: RelationshipInfo, + parent: GraphQLFieldsContainer, + relDirectiveField: RelationshipInfo? + ): RelationshipInfo { val startField = fieldObjectType.getFieldDefinition(relInfo0.startField)!! val endField = fieldObjectType.getFieldDefinition(relInfo0.endField)!! val startFieldTypeName = startField.type.innerName() @@ -413,11 +418,7 @@ open class ProjectionBase( relInfo.endField -> Triple(anyNode(), node, node) else -> throw IllegalArgumentException("type ${parent.name} does not have a matching field with name ${fieldDefinition.name}") } - val rel = when (relInfo.direction) { - RelationDirection.IN -> start.relationshipFrom(end).named(variable) - RelationDirection.OUT -> start.relationshipTo(end).named(variable) - RelationDirection.BOTH -> start.relationshipBetween(end).named(variable) - } + val rel = relInfo.createRelation(start, end, false,variable) return head(CypherDSL.listBasedOn(rel).returning(target.project(projectFields(target, field, fieldDefinition.type as GraphQLFieldsContainer, env)))) } @@ -426,8 +427,8 @@ open class ProjectionBase( val nodeType = fieldType.getInnerFieldsContainer() // todo combine both nestings if rel-entity - val relDirectiveObject = (nodeType as? GraphQLDirectiveContainer)?.getDirective(DirectiveConstants.RELATION)?.let { relDetails(nodeType, it) } - val relDirectiveField = fieldDefinition.getDirective(DirectiveConstants.RELATION)?.let { relDetails(nodeType, it) } + val relDirectiveObject = (nodeType as? GraphQLDirectiveContainer)?.getDirective(DirectiveConstants.RELATION)?.let { RelationshipInfo.create(nodeType, it) } + val relDirectiveField = fieldDefinition.getDirective(DirectiveConstants.RELATION)?.let { RelationshipInfo.create(nodeType, it) } val (relInfo0, isRelFromType) = relDirectiveObject?.let { it to true } diff --git a/core/src/main/kotlin/org/neo4j/graphql/handler/relation/BaseRelationHandler.kt b/core/src/main/kotlin/org/neo4j/graphql/handler/relation/BaseRelationHandler.kt index 6810b183..ea81fdfa 100644 --- a/core/src/main/kotlin/org/neo4j/graphql/handler/relation/BaseRelationHandler.kt +++ b/core/src/main/kotlin/org/neo4j/graphql/handler/relation/BaseRelationHandler.kt @@ -1,8 +1,11 @@ package org.neo4j.graphql.handler.relation -import graphql.Scalars -import graphql.language.Argument -import graphql.schema.* +import graphql.language.* +import graphql.schema.DataFetcher +import graphql.schema.GraphQLFieldDefinition +import graphql.schema.GraphQLFieldsContainer +import graphql.schema.GraphQLType +import graphql.schema.idl.TypeDefinitionRegistry import org.neo4j.cypherdsl.core.Condition import org.neo4j.cypherdsl.core.Node import org.neo4j.graphql.* @@ -11,57 +14,52 @@ import org.neo4j.graphql.handler.BaseDataFetcherForContainer /** * This is a base class for all handler acting on relations / edges */ -abstract class BaseRelationHandler( - type: GraphQLFieldsContainer, - val relation: RelationshipInfo, - private val startId: RelationshipInfo.RelatedField, - private val endId: RelationshipInfo.RelatedField, - fieldDefinition: GraphQLFieldDefinition, - schemaConfig: SchemaConfig) - : BaseDataFetcherForContainer(type, fieldDefinition, schemaConfig) { - - init { - propertyFields.remove(startId.argumentName) - propertyFields.remove(endId.argumentName) - } +abstract class BaseRelationHandler(val prefix: String, schemaConfig: SchemaConfig) : BaseDataFetcherForContainer(schemaConfig) { + + lateinit var relation: RelationshipInfo + lateinit var startId: RelatedField + lateinit var endId: RelatedField + + data class RelatedField( + val argumentName: String, + val field: GraphQLFieldDefinition, + val declaringType: GraphQLFieldsContainer + ) - abstract class BaseRelationFactory(val prefix: String, schemaConfig: SchemaConfig) : AugmentationHandler(schemaConfig) { + abstract class BaseRelationFactory( + val prefix: String, + schemaConfig: SchemaConfig, + typeDefinitionRegistry: TypeDefinitionRegistry, + neo4jTypeDefinitionRegistry: TypeDefinitionRegistry + ) : AugmentationHandler(schemaConfig, typeDefinitionRegistry, neo4jTypeDefinitionRegistry) { protected fun buildFieldDefinition( - source: GraphQLFieldsContainer, - targetField: GraphQLFieldDefinition, + source: ImplementingTypeDefinition<*>, + targetField: FieldDefinition, nullableResult: Boolean - ): GraphQLFieldDefinition.Builder? { + ): FieldDefinition.Builder? { - val targetType = targetField.type.getInnerFieldsContainer() - val sourceIdField = source.getIdField() - val targetIdField = targetType.getIdField() - if (sourceIdField == null || targetIdField == null) { - return null - } + val (sourceIdField, _) = getRelationFields(source, targetField) ?: return null val targetFieldName = targetField.name.capitalize() - val idType = GraphQLNonNull(Scalars.GraphQLID) - val targetIDType = if (targetField.type.isList()) GraphQLNonNull(GraphQLList(idType)) else idType + val idType = NonNullType(TypeID) + val targetIDType = if (targetField.type.isList()) NonNullType(ListType(idType)) else idType - var type: GraphQLOutputType = source + var type: Type<*> = TypeName(source.name) if (!nullableResult) { - type = GraphQLNonNull(type) + type = NonNullType(type) } - return GraphQLFieldDefinition.newFieldDefinition() + return FieldDefinition.newFieldDefinition() .name("$prefix${source.name}$targetFieldName") - .argument(input(sourceIdField.name, idType)) - .argument(input(targetField.name, targetIDType)) - .type(type.ref() as GraphQLOutputType) + .inputValueDefinition(input(sourceIdField.name, idType)) + .inputValueDefinition(input(targetField.name, targetIDType)) + .type(type) } - protected fun canHandleType(type: GraphQLFieldsContainer): Boolean { - if (type !is GraphQLObjectType) { - return false - } - if (!schemaConfig.mutation.enabled || schemaConfig.mutation.exclude.contains(type.name)) { + protected fun canHandleType(type: ImplementingTypeDefinition<*>): Boolean { + if (!schemaConfig.mutation.enabled || schemaConfig.mutation.exclude.contains(type.name) || isRootType(type)) { // TODO we do not mutate the node but the relation, I think this check should be different return false } @@ -71,9 +69,9 @@ abstract class BaseRelationHandler( return true } - protected fun canHandleField(targetField: GraphQLFieldDefinition): Boolean { - val type = targetField.type.inner() as? GraphQLObjectType ?: return false - if (targetField.getDirective(DirectiveConstants.RELATION) == null) { + protected fun canHandleField(targetField: FieldDefinition): Boolean { + val type = targetField.type.inner().resolve() as? ImplementingTypeDefinition<*> ?: return false + if (!targetField.hasDirective(DirectiveConstants.RELATION)) { return false } if (type.getIdField() == null) { @@ -82,15 +80,14 @@ abstract class BaseRelationHandler( return true } - final override fun createDataFetcher(operationType: OperationType, fieldDefinition: GraphQLFieldDefinition): DataFetcher? { + final override fun createDataFetcher(operationType: OperationType, fieldDefinition: FieldDefinition): DataFetcher? { if (operationType != OperationType.MUTATION) { return null } if (fieldDefinition.cypherDirective() != null) { return null } - val sourceType = fieldDefinition.type.inner() as? GraphQLFieldsContainer - ?: return null + val sourceType = fieldDefinition.type.inner().resolve() as? ImplementingTypeDefinition<*> ?: return null if (!canHandleType(sourceType)) { return null } @@ -107,32 +104,77 @@ abstract class BaseRelationHandler( if (!canHandleField(targetField)) { return null } - val relation = sourceType.relationshipFor(targetField.name) ?: return null + if (!sourceType.hasRelationshipFor(targetField.name)) { + return null + } + + val (_, _) = getRelationFields(sourceType, targetField) ?: return null + return createDataFetcher() + } - val targetType = targetField.type.getInnerFieldsContainer() - val sourceIdField = sourceType.getIdField() + private fun ImplementingTypeDefinition<*>.hasRelationshipFor(name: String): Boolean { + val field = getFieldDefinition(name) + ?: throw IllegalArgumentException("$name is not defined on ${this.name}") + val fieldObjectType = field.type.inner().resolve() as? ImplementingTypeDefinition<*> ?: return false + + val target = getDirectiveArgument(DirectiveConstants.RELATION, DirectiveConstants.RELATION_TO, null) + if (target != null) { + if (fieldObjectType.getFieldDefinition(target)?.name == this.name) return true + } else { + if (fieldObjectType.getDirective(DirectiveConstants.RELATION) != null) return true + if (field.getDirective(DirectiveConstants.RELATION) != null) return true + } + return false + } + + abstract fun createDataFetcher(): DataFetcher? + + private fun getRelationFields(source: ImplementingTypeDefinition<*>, targetField: FieldDefinition): Pair? { + val targetType = targetField.type.inner().resolve() as? ImplementingTypeDefinition<*> ?: return null + val sourceIdField = source.getIdField() val targetIdField = targetType.getIdField() if (sourceIdField == null || targetIdField == null) { return null } - val startIdField = RelationshipInfo.RelatedField(sourceIdField.name, sourceIdField, sourceType) - val endIdField = RelationshipInfo.RelatedField(targetField.name, targetIdField, targetType) - return createDataFetcher(sourceType, relation, startIdField, endIdField, fieldDefinition) + return sourceIdField to targetIdField } + } + + override fun initDataFetcher(fieldDefinition: GraphQLFieldDefinition, parentType: GraphQLType) { + super.initDataFetcher(fieldDefinition, parentType) + + initRelation(fieldDefinition) + + propertyFields.remove(startId.argumentName) + propertyFields.remove(endId.argumentName) + } + + protected open fun initRelation(fieldDefinition: GraphQLFieldDefinition) { + val p = "$prefix${type.name}" + + val targetField = fieldDefinition.name + .removePrefix(p) + .decapitalize() + .let { + type.getFieldDefinition(it) ?: throw IllegalStateException("Cannot find field $it on type ${type.name}") + } + - abstract fun createDataFetcher( - sourceType: GraphQLFieldsContainer, - relation: RelationshipInfo, - startIdField: RelationshipInfo.RelatedField, - endIdField: RelationshipInfo.RelatedField, - fieldDefinition: GraphQLFieldDefinition - ): DataFetcher? + relation = type.relationshipFor(targetField.name) + ?: throw IllegalStateException("Cannot resolve relationship for ${targetField.name} on type ${type.name}") + val targetType = targetField.type.getInnerFieldsContainer() + val sourceIdField = type.getIdField() + ?: throw IllegalStateException("Cannot find id field for type ${type.name}") + val targetIdField = targetType.getIdField() + ?: throw IllegalStateException("Cannot find id field for type ${targetType.name}") + startId = RelatedField(sourceIdField.name, sourceIdField, type) + endId = RelatedField(targetField.name, targetIdField, targetType) } fun getRelationSelect(start: Boolean, arguments: Map): Pair { val relFieldName: String - val idField: RelationshipInfo.RelatedField + val idField: RelatedField if (start) { relFieldName = relation.startField idField = startId diff --git a/core/src/main/kotlin/org/neo4j/graphql/handler/relation/CreateRelationHandler.kt b/core/src/main/kotlin/org/neo4j/graphql/handler/relation/CreateRelationHandler.kt index f649b220..22d2ee8e 100644 --- a/core/src/main/kotlin/org/neo4j/graphql/handler/relation/CreateRelationHandler.kt +++ b/core/src/main/kotlin/org/neo4j/graphql/handler/relation/CreateRelationHandler.kt @@ -1,10 +1,10 @@ package org.neo4j.graphql.handler.relation import graphql.language.Field +import graphql.language.ImplementingTypeDefinition import graphql.schema.DataFetcher import graphql.schema.DataFetchingEnvironment -import graphql.schema.GraphQLFieldDefinition -import graphql.schema.GraphQLFieldsContainer +import graphql.schema.idl.TypeDefinitionRegistry import org.neo4j.cypherdsl.core.Statement import org.neo4j.cypherdsl.core.StatementBuilder.OngoingUpdate import org.neo4j.graphql.* @@ -13,20 +13,25 @@ import org.neo4j.graphql.* * This class handles all the logic related to the creation of relations starting from an existing node. * This includes the augmentation of the add<Edge>-mutator and the related cypher generation */ -class CreateRelationHandler private constructor( - type: GraphQLFieldsContainer, - relation: RelationshipInfo, - startId: RelationshipInfo.RelatedField, - endId: RelationshipInfo.RelatedField, - fieldDefinition: GraphQLFieldDefinition, - schemaConfig: SchemaConfig -) : BaseRelationHandler(type, relation, startId, endId, fieldDefinition, schemaConfig) { - - class Factory(schemaConfig: SchemaConfig) : BaseRelationFactory("add", schemaConfig) { - override fun augmentType(type: GraphQLFieldsContainer, buildingEnv: BuildingEnv) { +class CreateRelationHandler private constructor(schemaConfig: SchemaConfig) : BaseRelationHandler("add", schemaConfig) { + + class Factory(schemaConfig: SchemaConfig, + typeDefinitionRegistry: TypeDefinitionRegistry, + neo4jTypeDefinitionRegistry: TypeDefinitionRegistry + ) : BaseRelationFactory("add", schemaConfig, typeDefinitionRegistry, neo4jTypeDefinitionRegistry) { + + override fun augmentType(type: ImplementingTypeDefinition<*>) { + if (!canHandleType(type)) { return } + + val richRelationTypes = typeDefinitionRegistry.types().values + .filterIsInstance>() + .filter { it.getDirective(DirectiveConstants.RELATION) != null } + .associate { it.getDirectiveArgument(DirectiveConstants.RELATION, DirectiveConstants.RELATION_NAME, null)!! to it.name } + + type.fieldDefinitions .filter { canHandleField(it) } .mapNotNull { targetField -> @@ -35,36 +40,30 @@ class CreateRelationHandler private constructor( val relationType = targetField .getDirectiveArgument(DirectiveConstants.RELATION, DirectiveConstants.RELATION_NAME, null) - ?.let { buildingEnv.getTypeForRelation(it) } + ?.let { it -> (richRelationTypes[it]) } + ?.let { typeDefinitionRegistry.getUnwrappedType(it) as? ImplementingTypeDefinition } relationType ?.fieldDefinitions - ?.filter { it.type.isScalar() && !it.isID() } - ?.forEach { builder.argument(input(it.name, it.type)) } + ?.filter { it.type.inner().isScalar() && !it.type.inner().isID() } + ?.forEach { builder.inputValueDefinition(input(it.name, it.type)) } - buildingEnv.addMutationField(builder.build()) + addMutationField(builder.build()) } } } - override fun createDataFetcher( - sourceType: GraphQLFieldsContainer, - relation: RelationshipInfo, - startIdField: RelationshipInfo.RelatedField, - endIdField: RelationshipInfo.RelatedField, - fieldDefinition: GraphQLFieldDefinition - ): DataFetcher { - return CreateRelationHandler(sourceType, relation, startIdField, endIdField, fieldDefinition, schemaConfig) + override fun createDataFetcher(): DataFetcher { + return CreateRelationHandler(schemaConfig) } - } override fun generateCypher(variable: String, field: Field, env: DataFetchingEnvironment): Statement { val properties = properties(variable, field.arguments) - val arguments = field.arguments.map { it.name to it }.toMap() + val arguments = field.arguments.associateBy { it.name } val (startNode, startWhere) = getRelationSelect(true, arguments) val (endNode, endWhere) = getRelationSelect(false, arguments) diff --git a/core/src/main/kotlin/org/neo4j/graphql/handler/relation/CreateRelationTypeHandler.kt b/core/src/main/kotlin/org/neo4j/graphql/handler/relation/CreateRelationTypeHandler.kt index bbd932ae..541d6b94 100644 --- a/core/src/main/kotlin/org/neo4j/graphql/handler/relation/CreateRelationTypeHandler.kt +++ b/core/src/main/kotlin/org/neo4j/graphql/handler/relation/CreateRelationTypeHandler.kt @@ -1,7 +1,11 @@ package org.neo4j.graphql.handler.relation import graphql.language.Field +import graphql.language.FieldDefinition +import graphql.language.ImplementingTypeDefinition +import graphql.language.InterfaceTypeDefinition import graphql.schema.* +import graphql.schema.idl.TypeDefinitionRegistry import org.neo4j.cypherdsl.core.Cypher.name import org.neo4j.cypherdsl.core.Statement import org.neo4j.cypherdsl.core.StatementBuilder @@ -11,83 +15,71 @@ import org.neo4j.graphql.* * This class handles all the logic related to the creation of relations. * This includes the augmentation of the create<Edge>-mutator and the related cypher generation */ -class CreateRelationTypeHandler private constructor( - type: GraphQLFieldsContainer, - relation: RelationshipInfo, - startId: RelationshipInfo.RelatedField, - endId: RelationshipInfo.RelatedField, - fieldDefinition: GraphQLFieldDefinition, - schemaConfig: SchemaConfig -) : BaseRelationHandler(type, relation, startId, endId, fieldDefinition, schemaConfig) { - - class Factory(schemaConfig: SchemaConfig) : AugmentationHandler(schemaConfig) { - override fun augmentType(type: GraphQLFieldsContainer, buildingEnv: BuildingEnv) { +class CreateRelationTypeHandler private constructor(schemaConfig: SchemaConfig) : BaseRelationHandler("create", schemaConfig) { + + class Factory(schemaConfig: SchemaConfig, + typeDefinitionRegistry: TypeDefinitionRegistry, + neo4jTypeDefinitionRegistry: TypeDefinitionRegistry + ) : AugmentationHandler(schemaConfig, typeDefinitionRegistry, neo4jTypeDefinitionRegistry) { + + override fun augmentType(type: ImplementingTypeDefinition<*>) { if (!canHandle(type)) { return } - val relation = type.relationship()!! - val startIdField = relation.getStartFieldId() - val endIdField = relation.getEndFieldId() + val relation = type.relationship() ?: return + val startIdField = getRelatedIdField(relation, relation.startField) + val endIdField = getRelatedIdField(relation, relation.endField) if (startIdField == null || endIdField == null) { return } - val relevantFields = getRelevantFields(type) val createArgs = getRelevantFields(type) .filter { !it.isNativeId() } .filter { it.name != startIdField.argumentName } .filter { it.name != endIdField.argumentName } - val builder = buildingEnv - .buildFieldDefinition("create", type, relevantFields, nullableResult = false) - .argument(input(startIdField.argumentName, startIdField.field.type)) - .argument(input(endIdField.argumentName, endIdField.field.type)) - - createArgs.forEach { builder.argument(input(it.name, it.type)) } + val builder = + buildFieldDefinition("create", type, createArgs, nullableResult = false) + .inputValueDefinition(input(startIdField.argumentName, startIdField.field.type)) + .inputValueDefinition(input(endIdField.argumentName, endIdField.field.type)) - buildingEnv.addMutationField(builder.build()) + addMutationField(builder.build()) } - override fun createDataFetcher(operationType: OperationType, fieldDefinition: GraphQLFieldDefinition): DataFetcher? { + override fun createDataFetcher(operationType: OperationType, fieldDefinition: FieldDefinition): DataFetcher? { if (operationType != OperationType.MUTATION) { return null } if (fieldDefinition.cypherDirective() != null) { return null } - val type = fieldDefinition.type.inner() as? GraphQLObjectType - ?: return null + val type = fieldDefinition.type.inner().resolve() as? ImplementingTypeDefinition<*> ?: return null if (!canHandle(type)) { return null } if (fieldDefinition.name != "create${type.name}") { return null } - - val relation = type.relationship() ?: return null - val startIdField = relation.getStartFieldId() ?: return null - val endIdField = relation.getEndFieldId() ?: return null - - return CreateRelationTypeHandler(type, relation, startIdField, endIdField, fieldDefinition, schemaConfig) + return CreateRelationTypeHandler(schemaConfig) } - private fun getRelevantFields(type: GraphQLFieldsContainer): List { + private fun getRelevantFields(type: ImplementingTypeDefinition<*>): List { return type - .relevantFields() + .getScalarFields() .filter { !it.isNativeId() } } - private fun canHandle(type: GraphQLFieldsContainer): Boolean { + private fun canHandle(type: ImplementingTypeDefinition<*>): Boolean { val typeName = type.name if (!schemaConfig.mutation.enabled || schemaConfig.mutation.exclude.contains(typeName)) { return false } - if (type !is GraphQLObjectType) { + if (type is InterfaceTypeDefinition) { return false } val relation = type.relationship() ?: return false - val startIdField = relation.getStartFieldId() - val endIdField = relation.getEndFieldId() + val startIdField = getRelatedIdField(relation, relation.startField) + val endIdField = getRelatedIdField(relation, relation.endField) if (startIdField == null || endIdField == null) { return false } @@ -99,12 +91,49 @@ class CreateRelationTypeHandler private constructor( return true } + + data class RelatedField( + val argumentName: String, + val field: FieldDefinition, + ) + + private fun getRelatedIdField(info: RelationshipInfo>, relFieldName: String?): RelatedField? { + if (relFieldName == null) return null + val relFieldDefinition = info.type.getFieldDefinition(relFieldName) + ?: throw IllegalArgumentException("field $relFieldName does not exists on ${info.typeName}") + + val relType = relFieldDefinition.type.inner().resolve() as? ImplementingTypeDefinition<*> + ?: throw IllegalArgumentException("type ${relFieldDefinition.type.name()} not found") + return relType.fieldDefinitions.filter { it.type.inner().isID() } + .map { RelatedField(normalizeFieldName(relFieldName, it.name), it) } + .firstOrNull() + } + + } + + private fun getRelatedIdField(info: RelationshipInfo, relFieldName: String): RelatedField { + val relFieldDefinition = info.type.getFieldDefinition(relFieldName) + ?: throw IllegalArgumentException("field $relFieldName does not exists on ${info.typeName}") + + val relType = relFieldDefinition.type.inner() as? GraphQLImplementingType + ?: throw IllegalArgumentException("type ${relFieldDefinition.type.name()} not found") + return relType.fieldDefinitions.filter { it.isID() } + .map { RelatedField(normalizeFieldName(relFieldName, it.name), it, relType) } + .firstOrNull() + ?: throw IllegalStateException("Cannot find id field for type ${info.typeName}") + } + + override fun initRelation(fieldDefinition: GraphQLFieldDefinition) { + relation = type.relationship() + ?: throw IllegalStateException("Cannot resolve relationship for type ${type.name}") + startId = getRelatedIdField(relation, relation.startField) + endId = getRelatedIdField(relation, relation.endField) } override fun generateCypher(variable: String, field: Field, env: DataFetchingEnvironment): Statement { val properties = properties(variable, field.arguments) - val arguments = field.arguments.map { it.name to it }.toMap() + val arguments = field.arguments.associateBy { it.name } val (startNode, startWhere) = getRelationSelect(true, arguments) val (endNode, endWhere) = getRelationSelect(false, arguments) val relName = name(variable) @@ -118,4 +147,12 @@ class CreateRelationTypeHandler private constructor( .returning(relName.project(mapProjection).`as`(field.aliasOrName())) .build() } + + companion object { + private fun normalizeFieldName(relFieldName: String?, name: String): String { + // TODO b/c we need to stay backwards compatible this is not caml case but with underscore + //val filedName = normalizeName(relFieldName, name) + return "${relFieldName}_${name}" + } + } } diff --git a/core/src/main/kotlin/org/neo4j/graphql/handler/relation/DeleteRelationHandler.kt b/core/src/main/kotlin/org/neo4j/graphql/handler/relation/DeleteRelationHandler.kt index 6dd47689..e5787f05 100644 --- a/core/src/main/kotlin/org/neo4j/graphql/handler/relation/DeleteRelationHandler.kt +++ b/core/src/main/kotlin/org/neo4j/graphql/handler/relation/DeleteRelationHandler.kt @@ -1,29 +1,29 @@ package org.neo4j.graphql.handler.relation import graphql.language.Field +import graphql.language.ImplementingTypeDefinition import graphql.schema.DataFetcher import graphql.schema.DataFetchingEnvironment -import graphql.schema.GraphQLFieldDefinition -import graphql.schema.GraphQLFieldsContainer +import graphql.schema.idl.TypeDefinitionRegistry import org.neo4j.cypherdsl.core.Statement import org.neo4j.cypherdsl.core.StatementBuilder -import org.neo4j.graphql.* +import org.neo4j.graphql.Cypher +import org.neo4j.graphql.SchemaConfig +import org.neo4j.graphql.aliasOrName /** * This class handles all the logic related to the deletion of relations starting from an existing node. * This includes the augmentation of the delete<Edge>-mutator and the related cypher generation */ -class DeleteRelationHandler private constructor( - type: GraphQLFieldsContainer, - relation: RelationshipInfo, - startId: RelationshipInfo.RelatedField, - endId: RelationshipInfo.RelatedField, - fieldDefinition: GraphQLFieldDefinition, - schemaConfig: SchemaConfig -) : BaseRelationHandler(type, relation, startId, endId, fieldDefinition, schemaConfig) { +class DeleteRelationHandler private constructor(schemaConfig: SchemaConfig) : BaseRelationHandler("delete", schemaConfig) { + + class Factory(schemaConfig: SchemaConfig, + typeDefinitionRegistry: TypeDefinitionRegistry, + neo4jTypeDefinitionRegistry: TypeDefinitionRegistry + ) : BaseRelationFactory("delete", schemaConfig, typeDefinitionRegistry, neo4jTypeDefinitionRegistry) { + + override fun augmentType(type: ImplementingTypeDefinition<*>) { - class Factory(schemaConfig: SchemaConfig) : BaseRelationFactory("delete", schemaConfig) { - override fun augmentType(type: GraphQLFieldsContainer, buildingEnv: BuildingEnv) { if (!canHandleType(type)) { return } @@ -31,24 +31,16 @@ class DeleteRelationHandler private constructor( .filter { canHandleField(it) } .mapNotNull { targetField -> buildFieldDefinition(type, targetField, true) - ?.let { builder -> buildingEnv.addMutationField(builder.build()) } + ?.let { builder -> addMutationField(builder.build()) } } } - override fun createDataFetcher( - sourceType: GraphQLFieldsContainer, - relation: RelationshipInfo, - startIdField: RelationshipInfo.RelatedField, - endIdField: RelationshipInfo.RelatedField, - fieldDefinition: GraphQLFieldDefinition - ): DataFetcher { - return DeleteRelationHandler(sourceType, relation, startIdField, endIdField, fieldDefinition, schemaConfig) - } + override fun createDataFetcher(): DataFetcher = DeleteRelationHandler(schemaConfig) } override fun generateCypher(variable: String, field: Field, env: DataFetchingEnvironment): Statement { - val arguments = field.arguments.map { it.name to it }.toMap() + val arguments = field.arguments.associateBy { it.name } val (startNode, startWhere) = getRelationSelect(true, arguments) val (endNode, endWhere) = getRelationSelect(false, arguments) val relName = org.neo4j.cypherdsl.core.Cypher.name("r") diff --git a/core/src/test/kotlin/org/neo4j/graphql/utils/GraphQLSchemaTestSuite.kt b/core/src/test/kotlin/org/neo4j/graphql/utils/GraphQLSchemaTestSuite.kt index b47efed4..308dbf90 100644 --- a/core/src/test/kotlin/org/neo4j/graphql/utils/GraphQLSchemaTestSuite.kt +++ b/core/src/test/kotlin/org/neo4j/graphql/utils/GraphQLSchemaTestSuite.kt @@ -11,6 +11,7 @@ import graphql.schema.idl.SchemaGenerator import graphql.schema.idl.SchemaParser import graphql.schema.idl.SchemaPrinter import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.Assertions import org.junit.jupiter.api.Assumptions import org.junit.jupiter.api.DynamicNode import org.junit.jupiter.api.DynamicTest @@ -60,6 +61,9 @@ class GraphQLSchemaTestSuite(fileName: String) : AsciiDocTestSuite( if (ignore) { Assumptions.assumeFalse(true, e.message) } else { + if (augmentedSchema == null) { + Assertions.fail(e) + } val actualSchema = SCHEMA_PRINTER.print(augmentedSchema) targetSchemaBlock.adjustedCode = actualSchema + "\n" + // this is added since the SCHEMA_PRINTER is not able to print global directives diff --git a/core/src/test/resources/augmentation-tests.adoc b/core/src/test/resources/augmentation-tests.adoc index 475dc525..96cced25 100644 --- a/core/src/test/resources/augmentation-tests.adoc +++ b/core/src/test/resources/augmentation-tests.adoc @@ -894,8 +894,8 @@ input _Neo4jLocalTimeInput { } input _Neo4jPointDistanceFilter { - distance: Float - point: _Neo4jPointInput + distance: Float! + point: _Neo4jPointInput! } input _Neo4jPointInput { diff --git a/examples/graphql-spring-boot/src/main/kotlin/org/neo4j/graphql/examples/graphqlspringboot/config/Neo4jConfiguration.kt b/examples/graphql-spring-boot/src/main/kotlin/org/neo4j/graphql/examples/graphqlspringboot/config/Neo4jConfiguration.kt index e41cbde9..b9991f1d 100644 --- a/examples/graphql-spring-boot/src/main/kotlin/org/neo4j/graphql/examples/graphqlspringboot/config/Neo4jConfiguration.kt +++ b/examples/graphql-spring-boot/src/main/kotlin/org/neo4j/graphql/examples/graphqlspringboot/config/Neo4jConfiguration.kt @@ -1,6 +1,5 @@ package org.neo4j.graphql.examples.graphqlspringboot.config -import graphql.language.VariableReference import graphql.schema.* import org.neo4j.driver.Driver import org.neo4j.graphql.Cypher