diff --git a/core/src/main/kotlin/org/neo4j/graphql/AugmentationHandler.kt b/core/src/main/kotlin/org/neo4j/graphql/AugmentationHandler.kt index 08afd4f2..f875b2c7 100644 --- a/core/src/main/kotlin/org/neo4j/graphql/AugmentationHandler.kt +++ b/core/src/main/kotlin/org/neo4j/graphql/AugmentationHandler.kt @@ -1,17 +1,331 @@ package org.neo4j.graphql +import graphql.language.* +import graphql.language.TypeDefinition import graphql.schema.DataFetcher -import graphql.schema.GraphQLFieldDefinition -import graphql.schema.GraphQLFieldsContainer +import graphql.schema.idl.ScalarInfo +import graphql.schema.idl.TypeDefinitionRegistry +import org.atteo.evo.inflector.English +import org.neo4j.graphql.handler.projection.ProjectionBase -abstract class AugmentationHandler(val schemaConfig: SchemaConfig) { +/** + * A base class for augmenting a TypeDefinitionRegistry. There a re 2 steps in augmenting the types: + * 1. augmenting the type by creating the relevant query / mutation fields and adding filtering and sorting to the relation fields + * 2. generating a data fetcher based on a field definition. The field may be an augmented field (from step 1) + * but can also be a user specified query / mutation field + */ +abstract class AugmentationHandler( + val schemaConfig: SchemaConfig, + val typeDefinitionRegistry: TypeDefinitionRegistry, + val neo4jTypeDefinitionRegistry: TypeDefinitionRegistry +) { enum class OperationType { QUERY, MUTATION } - open fun augmentType(type: GraphQLFieldsContainer, buildingEnv: BuildingEnv) {} + /** + * The 1st step in enhancing a schema. This method creates relevant query / mutation fields and / or adds filtering and sorting to the relation fields of the given type + * @param type the type for which the schema should be enhanced / augmented + */ + open fun augmentType(type: ImplementingTypeDefinition<*>) {} - abstract fun createDataFetcher(operationType: OperationType, fieldDefinition: GraphQLFieldDefinition): DataFetcher? + /** + * The 2nd step is creating a data fetcher based on a field definition. The field may be an augmented field (from step 1) + * but can also be a user specified query / mutation field + * @param operationType the type of the field + * @param fieldDefinition the filed to create the data fetcher for + * @return a data fetcher for the field or null if not applicable + */ + abstract fun createDataFetcher(operationType: OperationType, fieldDefinition: FieldDefinition): DataFetcher? + protected fun buildFieldDefinition( + prefix: String, + resultType: ImplementingTypeDefinition<*>, + scalarFields: List, + nullableResult: Boolean, + forceOptionalProvider: (field: FieldDefinition) -> Boolean = { false } + ): FieldDefinition.Builder { + var type: Type<*> = TypeName(resultType.name) + if (!nullableResult) { + type = NonNullType(type) + } + return FieldDefinition.newFieldDefinition() + .name("$prefix${resultType.name}") + .inputValueDefinitions(getInputValueDefinitions(scalarFields, forceOptionalProvider)) + .type(type) + } + + protected fun getInputValueDefinitions( + relevantFields: List, + forceOptionalProvider: (field: FieldDefinition) -> Boolean): List { + return relevantFields.map { field -> + var type = getInputType(field.type) + type = if (forceOptionalProvider(field)) { + (type as? NonNullType)?.type ?: type + } else { + type + } + input(field.name, type) + } + } + + protected fun addQueryField(fieldDefinition: FieldDefinition) { + addOperation(typeDefinitionRegistry.queryTypeName(), fieldDefinition) + } + + protected fun addMutationField(fieldDefinition: FieldDefinition) { + addOperation(typeDefinitionRegistry.mutationTypeName(), fieldDefinition) + } + + protected fun isRootType(type: ImplementingTypeDefinition<*>): Boolean { + return type.name == typeDefinitionRegistry.queryTypeName() + || type.name == typeDefinitionRegistry.mutationTypeName() + || type.name == typeDefinitionRegistry.subscriptionTypeName() + } + + /** + * add the given operation to the corresponding rootType + */ + private fun addOperation(rootTypeName: String, fieldDefinition: FieldDefinition) { + val rootType = typeDefinitionRegistry.getType(rootTypeName)?.unwrap() + if (rootType == null) { + typeDefinitionRegistry.add(ObjectTypeDefinition.newObjectTypeDefinition() + .name(rootTypeName) + .fieldDefinition(fieldDefinition) + .build()) + } else { + val existingRootType = (rootType as? ObjectTypeDefinition + ?: throw IllegalStateException("root type $rootTypeName is not an object type but ${rootType.javaClass}")) + if (existingRootType.fieldDefinitions.find { it.name == fieldDefinition.name } != null) { + return // definition already exists, we don't override it + } + typeDefinitionRegistry.remove(rootType) + typeDefinitionRegistry.add(rootType.transform { it.fieldDefinition(fieldDefinition) }) + } + } + + protected fun addFilterType(type: ImplementingTypeDefinition<*>, createdTypes: MutableSet = mutableSetOf()): String { + val filterName = if (schemaConfig.useWhereFilter) type.name + "Where" else "_${type.name}Filter" + if (createdTypes.contains(filterName)) { + return filterName + } + val existingFilterType = typeDefinitionRegistry.getType(filterName).unwrap() + if (existingFilterType != null) { + return (existingFilterType as? InputObjectTypeDefinition)?.name + ?: throw IllegalStateException("Filter type $filterName is already defined but not an input type") + } + createdTypes.add(filterName) + val builder = InputObjectTypeDefinition.newInputObjectDefinition() + .name(filterName) + listOf("AND", "OR", "NOT").forEach { + builder.inputValueDefinition(InputValueDefinition.newInputValueDefinition() + .name(it) + .type(ListType(NonNullType(TypeName(filterName)))) + .build()) + } + type.fieldDefinitions + .filter { it.dynamicPrefix() == null } // TODO currently we do not support filtering on dynamic properties + .forEach { field -> + val typeDefinition = field.type.resolve() + ?: throw IllegalArgumentException("type ${field.type.name()} cannot be resolved") + val filterType = when { + field.type.inner().isNeo4jType() -> getInputType(field.type).name()!! + typeDefinition is ScalarTypeDefinition -> typeDefinition.name + typeDefinition is EnumTypeDefinition -> typeDefinition.name + typeDefinition is ImplementingTypeDefinition -> { + when { + field.type.inner().isNeo4jType() -> typeDefinition.name + else -> addFilterType(typeDefinition, createdTypes) + } + } + else -> throw IllegalArgumentException("${field.type.name()} is neither an object nor an interface") + } + + if (field.isRelationship()) { + RelationOperator.createRelationFilterFields(type, field, filterType, builder) + } else { + FieldOperator.forType(typeDefinition, field.type.inner().isNeo4jType()) + .forEach { op -> builder.addFilterField(op.fieldName(field.name), op.list, filterType, field.description) } + if (typeDefinition.isNeo4jSpatialType() == true) { + val distanceFilterType = getSpatialDistanceFilter(neo4jTypeDefinitionRegistry.getUnwrappedType(filterType) as TypeDefinition<*>) + FieldOperator.forType(distanceFilterType, true) + .forEach { op -> builder.addFilterField(op.fieldName(field.name + NEO4j_POINT_DISTANCE_FILTER_SUFFIX), op.list, NEO4j_POINT_DISTANCE_FILTER) } + } + } + + } + typeDefinitionRegistry.add(builder.build()) + return filterName + } + + private fun getSpatialDistanceFilter(pointType: TypeDefinition<*>): InputObjectTypeDefinition { + return addInputType(NEO4j_POINT_DISTANCE_FILTER, listOf( + input("distance", NonNullType(TypeFloat)), + input("point", NonNullType(TypeName(pointType.name))) + )) + } + + protected fun addOptions(type: ImplementingTypeDefinition<*>): String { + val optionsName = "${type.name}Options" + val optionsType = typeDefinitionRegistry.getType(optionsName)?.unwrap() + if (optionsType != null) { + return (optionsType as? InputObjectTypeDefinition)?.name + ?: throw IllegalStateException("Ordering type $type.name is already defined but not an input type") + } + val sortTypeName = addSortInputType(type) + val optionsTypeBuilder = InputObjectTypeDefinition.newInputObjectDefinition().name(optionsName) + if (sortTypeName != null) { + optionsTypeBuilder.inputValueDefinition(input( + ProjectionBase.SORT, + ListType(NonNullType(TypeName(sortTypeName))), + "Specify one or more $sortTypeName objects to sort ${English.plural(type.name)} by. The sorts will be applied in the order in which they are arranged in the array.") + ) + } + optionsTypeBuilder + .inputValueDefinition(input(ProjectionBase.LIMIT, TypeInt, "Defines the maximum amount of records returned")) + .inputValueDefinition(input(ProjectionBase.SKIP, TypeInt, "Defines the amount of records to be skipped")) + .build() + typeDefinitionRegistry.add(optionsTypeBuilder.build()) + return optionsName + } + + private fun addSortInputType(type: ImplementingTypeDefinition<*>): String? { + val sortTypeName = "${type.name}Sort" + val sortType = typeDefinitionRegistry.getType(sortTypeName)?.unwrap() + if (sortType != null) { + return (sortType as? InputObjectTypeDefinition)?.name + ?: throw IllegalStateException("Ordering type $type.name is already defined but not an input type") + } + val relevantFields = type.getScalarFields() + if (relevantFields.isEmpty()) { + return null + } + val builder = InputObjectTypeDefinition.newInputObjectDefinition() + .name(sortTypeName) + .description("Fields to sort ${type.name}s by. The order in which sorts are applied is not guaranteed when specifying many fields in one MovieSort object.".asDescription()) + for (relevantField in relevantFields) { + builder.inputValueDefinition(input(relevantField.name, TypeName("SortDirection"))) + } + typeDefinitionRegistry.add(builder.build()) + return sortTypeName + } + + protected fun addOrdering(type: ImplementingTypeDefinition<*>): String? { + val orderingName = "_${type.name}Ordering" + var existingOrderingType = typeDefinitionRegistry.getType(orderingName)?.unwrap() + if (existingOrderingType != null) { + return (existingOrderingType as? EnumTypeDefinition)?.name + ?: throw IllegalStateException("Ordering type $type.name is already defined but not an input type") + } + val sortingFields = type.getScalarFields() + if (sortingFields.isEmpty()) { + return null + } + existingOrderingType = EnumTypeDefinition.newEnumTypeDefinition() + .name(orderingName) + .enumValueDefinitions(sortingFields.flatMap { fd -> + listOf("_asc", "_desc") + .map { + EnumValueDefinition + .newEnumValueDefinition() + .name(fd.name + it) + .build() + } + }) + .build() + typeDefinitionRegistry.add(existingOrderingType) + return orderingName + } + + private fun addInputType(inputName: String, relevantFields: List): InputObjectTypeDefinition { + var inputType = typeDefinitionRegistry.getType(inputName)?.unwrap() + if (inputType != null) { + return inputType as? InputObjectTypeDefinition + ?: throw IllegalStateException("Filter type $inputName is already defined but not an input type") + } + inputType = getInputType(inputName, relevantFields) + typeDefinitionRegistry.add(inputType) + return inputType + } + + private fun getInputType(inputName: String, relevantFields: List): InputObjectTypeDefinition { + return InputObjectTypeDefinition.newInputObjectDefinition() + .name(inputName) + .inputValueDefinitions(relevantFields) + .build() + } + + private fun getInputType(type: Type<*>): Type<*> { + if (type.inner().isNeo4jType()) { + return neo4jTypeDefinitions + .find { it.typeDefinition == type.name() } + ?.let { TypeName(it.inputDefinition) } + ?: throw IllegalArgumentException("Cannot find input type for ${type.name()}") + } + return type + } + + private fun getTypeFromAnyRegistry(name: String?): TypeDefinition<*>? = typeDefinitionRegistry.getUnwrappedType(name) + ?: neo4jTypeDefinitionRegistry.getUnwrappedType(name) + + fun ImplementingTypeDefinition<*>.relationship(): RelationshipInfo>? = RelationshipInfo.create(this, neo4jTypeDefinitionRegistry) + + fun ImplementingTypeDefinition<*>.getScalarFields(): List = fieldDefinitions + .filter { it.type.inner().isScalar() || it.type.inner().isNeo4jType() } + .sortedByDescending { it.type.inner().isID() } + + fun ImplementingTypeDefinition<*>.getFieldDefinition(name: String) = this.fieldDefinitions.find { it.name == name } + fun ImplementingTypeDefinition<*>.getIdField() = this.fieldDefinitions.find { it.type.inner().isID() } + + fun Type<*>.resolve(): TypeDefinition<*>? = getTypeFromAnyRegistry(name()) + fun Type<*>.isScalar(): Boolean = resolve() is ScalarTypeDefinition + fun Type<*>.isNeo4jType(): Boolean = name() + ?.takeIf { + !ScalarInfo.GRAPHQL_SPECIFICATION_SCALARS_DEFINITIONS.containsKey(it) + && it.startsWith("_Neo4j") // TODO remove this check by refactoring neo4j input types + } + ?.let { neo4jTypeDefinitionRegistry.getUnwrappedType(it) } != null + + fun Type<*>.isID(): Boolean = name() == "ID" + + + fun FieldDefinition.isNativeId(): Boolean = name == ProjectionBase.NATIVE_ID + fun FieldDefinition.dynamicPrefix(): String? = + getDirectiveArgument(DirectiveConstants.DYNAMIC, DirectiveConstants.DYNAMIC_PREFIX, null) + fun FieldDefinition.isRelationship(): Boolean = + !type.inner().isNeo4jType() && type.resolve() is ImplementingTypeDefinition<*> + + + fun TypeDefinitionRegistry.getUnwrappedType(name: String?): TypeDefinition>? = getType(name)?.unwrap() + + fun DirectivesContainer<*>.cypherDirective(): CypherDirective? = if (hasDirective(DirectiveConstants.CYPHER)) { + CypherDirective( + getMandatoryDirectiveArgument(DirectiveConstants.CYPHER, DirectiveConstants.CYPHER_STATEMENT), + getMandatoryDirectiveArgument(DirectiveConstants.CYPHER, DirectiveConstants.CYPHER_PASS_THROUGH, false) + ) + } else { + null + } + + fun DirectivesContainer<*>.hasDirective(name: String): Boolean = getDirective(name) != null + + fun DirectivesContainer<*>.getDirectiveArgument(directiveName: String, argumentName: String, defaultValue: T? = null): T? = + getDirectiveArgument(neo4jTypeDefinitionRegistry, directiveName, argumentName, defaultValue) + + private fun DirectivesContainer<*>.getMandatoryDirectiveArgument(directiveName: String, argumentName: String, defaultValue: T? = null): T = + getDirectiveArgument(directiveName, argumentName, defaultValue) + ?: throw IllegalStateException("No default value for @${directiveName}::$argumentName") + + fun input(name: String, type: Type<*>, description: String? = null): InputValueDefinition { + val input = InputValueDefinition + .newInputValueDefinition() + .name(name) + .type(type) + if (description != null) { + input.description(description.asDescription()) + } + return input + .build() + } } diff --git a/core/src/main/kotlin/org/neo4j/graphql/BuildingEnv.kt b/core/src/main/kotlin/org/neo4j/graphql/BuildingEnv.kt deleted file mode 100644 index e978a515..00000000 --- a/core/src/main/kotlin/org/neo4j/graphql/BuildingEnv.kt +++ /dev/null @@ -1,280 +0,0 @@ -package org.neo4j.graphql - -import graphql.Scalars -import graphql.schema.* -import org.atteo.evo.inflector.English -import org.neo4j.graphql.handler.projection.ProjectionBase - -class BuildingEnv( - val types: MutableMap, - private val sourceSchema: GraphQLSchema, - val schemaConfig: SchemaConfig -) { - - private val typesForRelation = types.values - .filterIsInstance() - .filter { it.getDirective(DirectiveConstants.RELATION) != null } - .map { it.getDirectiveArgument(DirectiveConstants.RELATION, DirectiveConstants.RELATION_NAME, null)!! to it.name } - .toMap() - - fun buildFieldDefinition( - prefix: String, - resultType: GraphQLOutputType, - scalarFields: List, - nullableResult: Boolean, - forceOptionalProvider: (field: GraphQLFieldDefinition) -> Boolean = { false } - ): GraphQLFieldDefinition.Builder { - var type: GraphQLOutputType = resultType - if (!nullableResult) { - type = GraphQLNonNull(type) - } - return GraphQLFieldDefinition.newFieldDefinition() - .name("$prefix${resultType.name()}") - .arguments(getInputValueDefinitions(scalarFields, forceOptionalProvider)) - .type(type.ref() as GraphQLOutputType) - } - - fun getInputValueDefinitions( - relevantFields: List, - forceOptionalProvider: (field: GraphQLFieldDefinition) -> Boolean): List { - return relevantFields.map { field -> - var type = field.type as GraphQLType - type = getInputType(type) - type = if (forceOptionalProvider(field)) { - (type as? GraphQLNonNull)?.wrappedType ?: type - } else { - type - } - input(field.name, type) - } - } - - fun addQueryField(fieldDefinition: GraphQLFieldDefinition) { - addOperation(sourceSchema.queryTypeName(), fieldDefinition) - } - - fun addMutationField(fieldDefinition: GraphQLFieldDefinition) { - addOperation(sourceSchema.mutationTypeName(), fieldDefinition) - } - - /** - * add the given operation to the corresponding rootType - */ - private fun addOperation(rootTypeName: String, fieldDefinition: GraphQLFieldDefinition) { - val rootType = types[rootTypeName] - types[rootTypeName] = if (rootType == null) { - val builder = GraphQLObjectType.newObject() - builder.name(rootTypeName) - .field(fieldDefinition) - .build() - } else { - val existingRootType = (rootType as? GraphQLObjectType - ?: throw IllegalStateException("root type $rootTypeName is not an object type but ${rootType.javaClass}")) - if (existingRootType.getFieldDefinition(fieldDefinition.name) != null) { - return // definition already exists, we don't override it - } - existingRootType - .transform { builder -> builder.field(fieldDefinition) } - } - } - - fun addFilterType(type: GraphQLFieldsContainer, createdTypes: MutableSet = mutableSetOf()): String { - val filterName = if (schemaConfig.useWhereFilter) type.name + "Where" else "_${type.name}Filter" - if (createdTypes.contains(filterName)) { - return filterName - } - val existingFilterType = types[filterName] - if (existingFilterType != null) { - return (existingFilterType as? GraphQLInputType)?.name() - ?: throw IllegalStateException("Filter type $filterName is already defined but not an input type") - } - createdTypes.add(filterName) - val builder = GraphQLInputObjectType.newInputObject() - .name(filterName) - listOf("AND", "OR", "NOT").forEach { - builder.field(GraphQLInputObjectField.newInputObjectField() - .name(it) - .type(GraphQLList(GraphQLNonNull(GraphQLTypeReference(filterName))))) - } - type.fieldDefinitions - .filter { it.dynamicPrefix() == null } // TODO currently we do not support filtering on dynamic properties - .forEach { field -> - val typeDefinition = field.type.inner() - val filterType = when { - typeDefinition.isNeo4jType() -> getInputType(typeDefinition).requiredName() - typeDefinition.isScalar() -> typeDefinition.innerName() - typeDefinition is GraphQLEnumType -> typeDefinition.innerName() - else -> addFilterType(getInnerFieldsContainer(typeDefinition), createdTypes) - } - - if (field.isRelationship()) { - RelationOperator.createRelationFilterFields(type, field, filterType, builder) - } else { - val graphQLType = types[filterType] ?: typeDefinition - FieldOperator.forType(graphQLType) - .forEach { op -> builder.addFilterField(op.fieldName(field.name), op.list, filterType, field.description) } - if (graphQLType.isNeo4jSpatialType()) { - val distanceFilterType = getSpatialDistanceFilter(graphQLType) - FieldOperator.forType(distanceFilterType) - .forEach { op -> builder.addFilterField(op.fieldName(field.name + NEO4j_POINT_DISTANCE_FILTER_SUFFIX), op.list, NEO4j_POINT_DISTANCE_FILTER) } - } - } - - } - types[filterName] = builder.build() - return filterName - } - - private fun getSpatialDistanceFilter(pointType: GraphQLType): GraphQLInputType { - return addInputType(NEO4j_POINT_DISTANCE_FILTER, listOf( - GraphQLFieldDefinition.newFieldDefinition().name("distance").type(GraphQLNonNull(Scalars.GraphQLFloat)).build(), - GraphQLFieldDefinition.newFieldDefinition().name("point").type(GraphQLNonNull(pointType)).build() - )) - } - - fun addOptions(type: GraphQLFieldsContainer): String { - val optionsName = "${type.name}Options" - val optionsType = types[optionsName] - if (optionsType != null) { - return (optionsType as? GraphQLInputType)?.requiredName() - ?: throw IllegalStateException("Ordering type $type.name is already defined but not an input type") - } - val sortTypeName = addSortInputType(type) - val optionsTypeBuilder = GraphQLInputObjectType.newInputObject().name(optionsName) - if (sortTypeName != null) { - optionsTypeBuilder.field(GraphQLInputObjectField.newInputObjectField() - .name(ProjectionBase.SORT) - .type(GraphQLList(GraphQLNonNull(GraphQLTypeReference(sortTypeName)))) - .description("Specify one or more $sortTypeName objects to sort ${English.plural(type.name)} by. The sorts will be applied in the order in which they are arranged in the array.") - .build()) - } - optionsTypeBuilder.field(GraphQLInputObjectField.newInputObjectField() - .name(ProjectionBase.LIMIT) - .type(Scalars.GraphQLInt) - .description("Defines the maximum amount of records returned") - .build()) - .field(GraphQLInputObjectField.newInputObjectField() - .name(ProjectionBase.SKIP) - .type(Scalars.GraphQLInt) - .description("Defines the amount of records to be skipped") - .build()) - .build() - types[optionsName] = optionsTypeBuilder.build() - return optionsName - } - - private fun addSortInputType(type: GraphQLFieldsContainer): String? { - val sortTypeName = "${type.name}Sort" - val sortType = types[sortTypeName] - if (sortType != null) { - return (sortType as? GraphQLInputType)?.requiredName() - ?: throw IllegalStateException("Ordering type $type.name is already defined but not an input type") - } - val relevantFields = type.relevantFields() - if (relevantFields.isEmpty()) { - return null - } - val builder = GraphQLInputObjectType.newInputObject() - .name(sortTypeName) - .description("Fields to sort ${type.name}s by. The order in which sorts are applied is not guaranteed when specifying many fields in one MovieSort object.") - for (relevantField in relevantFields) { - builder.field(GraphQLInputObjectField.newInputObjectField() - .name(relevantField.name) - .type(GraphQLTypeReference("SortDirection")) - .build()) - } - types[sortTypeName] = builder.build() - return sortTypeName - } - - fun addOrdering(type: GraphQLFieldsContainer): String? { - val orderingName = "_${type.name}Ordering" - var existingOrderingType = types[orderingName] - if (existingOrderingType != null) { - return (existingOrderingType as? GraphQLInputType)?.requiredName() - ?: throw IllegalStateException("Ordering type $type.name is already defined but not an input type") - } - val sortingFields = type.fieldDefinitions - .filter { it.type.isScalar() || it.isNeo4jType() } - .sortedByDescending { it.isID() } - if (sortingFields.isEmpty()) { - return null - } - existingOrderingType = GraphQLEnumType.newEnum() - .name(orderingName) - .values(sortingFields.flatMap { fd -> - listOf("_asc", "_desc") - .map { - GraphQLEnumValueDefinition - .newEnumValueDefinition() - .name(fd.name + it) - .value(fd.name + it) - .build() - } - }) - .build() - types[orderingName] = existingOrderingType - return orderingName - } - - fun addInputType(inputName: String, relevantFields: List): GraphQLInputType { - var inputType = types[inputName] - if (inputType != null) { - return inputType as? GraphQLInputType - ?: throw IllegalStateException("Filter type $inputName is already defined but not an input type") - } - inputType = getInputType(inputName, relevantFields) - types[inputName] = inputType - return inputType - } - - fun getTypeForRelation(nameOfRelation: String): GraphQLObjectType? { - return typesForRelation[nameOfRelation]?.let { types[it] } as? GraphQLObjectType - } - - private fun getInputType(inputName: String, relevantFields: List): GraphQLInputObjectType { - return GraphQLInputObjectType.newInputObject() - .name(inputName) - .fields(getInputValueDefinitions(relevantFields)) - .build() - } - - private fun getInputValueDefinitions(relevantFields: List): List { - return relevantFields.map { - // just make evrything optional - val type = (it.type as? GraphQLNonNull)?.wrappedType ?: it.type - GraphQLInputObjectField - .newInputObjectField() - .name(it.name) - .description(it.description) - .type(getInputType(type).ref() as GraphQLInputType) - .build() - } - } - - private fun getInnerFieldsContainer(type: GraphQLType): GraphQLFieldsContainer { - var innerType = type.inner() - if (innerType is GraphQLTypeReference) { - innerType = types[innerType.name] - ?: throw IllegalArgumentException("${innerType.name} is unknown") - } - return innerType as? GraphQLFieldsContainer - ?: throw IllegalArgumentException("${innerType.name()} is neither an object nor an interface") - } - - private fun getInputType(type: GraphQLType): GraphQLInputType { - val inner = type.inner() - if (inner is GraphQLInputType) { - return type as GraphQLInputType - } - if (inner.isNeo4jType()) { - return neo4jTypeDefinitions - .find { it.typeDefinition == inner.name() } - ?.let { types[it.inputDefinition] } as? GraphQLInputType - ?: throw IllegalArgumentException("Cannot find input type for ${inner.name()}") - } - return type as? GraphQLInputType - ?: throw IllegalArgumentException("${type.name()} is not allowed for input") - } - -} diff --git a/core/src/main/kotlin/org/neo4j/graphql/ExtensionFunctions.kt b/core/src/main/kotlin/org/neo4j/graphql/ExtensionFunctions.kt index 720451b6..10514b70 100644 --- a/core/src/main/kotlin/org/neo4j/graphql/ExtensionFunctions.kt +++ b/core/src/main/kotlin/org/neo4j/graphql/ExtensionFunctions.kt @@ -1,23 +1,14 @@ package org.neo4j.graphql +import graphql.language.Description import graphql.language.VariableReference -import graphql.schema.GraphQLArgument -import graphql.schema.GraphQLInputType -import graphql.schema.GraphQLType import org.neo4j.cypherdsl.core.* +import java.util.* fun Iterable.joinNonEmpty(separator: CharSequence = ", ", prefix: CharSequence = "", postfix: CharSequence = "", limit: Int = -1, truncated: CharSequence = "...", transform: ((T) -> CharSequence)? = null): String { return if (iterator().hasNext()) joinTo(StringBuilder(), separator, prefix, postfix, limit, truncated, transform).toString() else "" } -fun input(name: String, type: GraphQLType): GraphQLArgument { - return GraphQLArgument - .newArgument() - .name(name) - .type((type.ref() as? GraphQLInputType) - ?: throw IllegalArgumentException("${type.innerName()} is not allowed for input")).build() -} - fun queryParameter(value: Any?, vararg parts: String?): Parameter { val name = when (value) { is VariableReference -> value.name @@ -36,3 +27,7 @@ fun PropertyContainer.id(): FunctionInvocation = when (this) { } fun String.toCamelCase(): String = Regex("[\\W_]([a-z])").replace(this) { it.groupValues[1].toUpperCase() } + +fun Optional.unwrap(): T? = orElse(null) + +fun String.asDescription() = Description(this, null, this.contains("\n")) diff --git a/core/src/main/kotlin/org/neo4j/graphql/GraphQLExtensions.kt b/core/src/main/kotlin/org/neo4j/graphql/GraphQLExtensions.kt index c7267d50..0b90c8c0 100644 --- a/core/src/main/kotlin/org/neo4j/graphql/GraphQLExtensions.kt +++ b/core/src/main/kotlin/org/neo4j/graphql/GraphQLExtensions.kt @@ -2,9 +2,9 @@ package org.neo4j.graphql import graphql.Scalars import graphql.language.* +import graphql.language.TypeDefinition import graphql.schema.* -import org.neo4j.cypherdsl.core.Node -import org.neo4j.cypherdsl.core.Relationship +import graphql.schema.idl.TypeDefinitionRegistry import org.neo4j.cypherdsl.core.SymbolicName import org.neo4j.graphql.DirectiveConstants.Companion.CYPHER import org.neo4j.graphql.DirectiveConstants.Companion.CYPHER_PASS_THROUGH @@ -13,21 +13,25 @@ import org.neo4j.graphql.DirectiveConstants.Companion.DYNAMIC import org.neo4j.graphql.DirectiveConstants.Companion.DYNAMIC_PREFIX import org.neo4j.graphql.DirectiveConstants.Companion.PROPERTY import org.neo4j.graphql.DirectiveConstants.Companion.PROPERTY_NAME -import org.neo4j.graphql.DirectiveConstants.Companion.RELATION_DIRECTION -import org.neo4j.graphql.DirectiveConstants.Companion.RELATION_FROM import org.neo4j.graphql.DirectiveConstants.Companion.RELATION_NAME import org.neo4j.graphql.DirectiveConstants.Companion.RELATION_TO import org.neo4j.graphql.handler.projection.ProjectionBase import java.math.BigDecimal import java.math.BigInteger -fun Type>.name(): String? = if (this.inner() is TypeName) (this.inner() as TypeName).name else null -fun Type>.inner(): Type> = when (this) { +fun Type<*>.name(): String? = if (this.inner() is TypeName) (this.inner() as TypeName).name else null +fun Type<*>.inner(): Type<*> = when (this) { is ListType -> this.type.inner() is NonNullType -> this.type.inner() else -> this } +fun Type<*>.isList(): Boolean = when (this) { + is ListType -> true + is NonNullType -> type.isList() + else -> false +} + fun GraphQLType.inner(): GraphQLType = when (this) { is GraphQLList -> this.wrappedType.inner() is GraphQLNonNull -> this.wrappedType.inner() @@ -38,19 +42,20 @@ fun GraphQLType.name(): String? = (this as? GraphQLNamedType)?.name fun GraphQLType.requiredName(): String = requireNotNull(name()) { "name is required but cannot be determined for " + this.javaClass } fun GraphQLType.isList() = this is GraphQLList || (this is GraphQLNonNull && this.wrappedType is GraphQLList) -fun GraphQLType.isScalar() = this.inner().let { it is GraphQLScalarType || it.innerName().startsWith("_Neo4j") } fun GraphQLType.isNeo4jType() = this.innerName().startsWith("_Neo4j") + fun GraphQLType.isNeo4jSpatialType() = this.innerName().startsWith("_Neo4jPoint") +fun TypeDefinition<*>.isNeo4jSpatialType() = this.name.startsWith("_Neo4jPoint") + fun GraphQLFieldDefinition.isNeo4jType(): Boolean = this.type.isNeo4jType() fun GraphQLFieldDefinition.isRelationship() = !type.isNeo4jType() && this.type.inner().let { it is GraphQLFieldsContainer } -fun GraphQLDirectiveContainer.isRelationType() = getDirective(DirectiveConstants.RELATION) != null fun GraphQLFieldsContainer.isRelationType() = (this as? GraphQLDirectiveContainer)?.getDirective(DirectiveConstants.RELATION) != null -fun GraphQLFieldsContainer.relationshipFor(name: String): RelationshipInfo? { +fun GraphQLFieldsContainer.relationshipFor(name: String): RelationshipInfo? { val field = getFieldDefinition(name) ?: throw IllegalArgumentException("$name is not defined on ${this.name}") - val fieldObjectType = field.type.inner() as? GraphQLFieldsContainer ?: return null + val fieldObjectType = field.type.inner() as? GraphQLImplementingType ?: return null val (relDirective, inverse) = if (isRelationType()) { val typeName = this.name @@ -67,7 +72,7 @@ fun GraphQLFieldsContainer.relationshipFor(name: String): RelationshipInfo? { ?: throw IllegalStateException("Field $field needs an @relation directive") } - val relInfo = relDetails(fieldObjectType, relDirective) + val relInfo = RelationshipInfo.create(fieldObjectType, relDirective) return if (inverse) relInfo.copy(direction = relInfo.direction.invert(), startField = relInfo.endField, endField = relInfo.startField) else relInfo } @@ -92,20 +97,7 @@ fun GraphQLFieldsContainer.label(): String = when { else -> name } - -fun GraphQLFieldsContainer.relevantFields() = fieldDefinitions - .filter { it.type.isScalar() || it.isNeo4jType() } - .sortedByDescending { it.isID() } - -fun GraphQLFieldsContainer.relationship(): RelationshipInfo? { - val relDirective = (this as? GraphQLDirectiveContainer)?.getDirective(DirectiveConstants.RELATION) ?: return null - val relType = relDirective.getArgument(RELATION_NAME, "")!! - val startField = relDirective.getMandatoryArgument(RELATION_FROM) - val endField = relDirective.getMandatoryArgument(RELATION_TO) - val direction = relDirective.getArgument(RELATION_DIRECTION)?.let { RelationDirection.valueOf(it) } - ?: RelationDirection.OUT - return RelationshipInfo(this, relType, direction, startField, endField) -} +fun GraphQLFieldsContainer.relationship(): RelationshipInfo? = RelationshipInfo.create(this) fun GraphQLType.ref(): GraphQLType = when (this) { is GraphQLNonNull -> GraphQLNonNull(this.wrappedType.ref()) @@ -116,63 +108,6 @@ fun GraphQLType.ref(): GraphQLType = when (this) { else -> GraphQLTypeReference(name()) } -fun relDetails(type: GraphQLFieldsContainer, relDirective: GraphQLDirective): RelationshipInfo { - val relType = relDirective.getArgument(RELATION_NAME, "")!! - val direction = relDirective.getArgument(RELATION_DIRECTION, null) - ?.let { RelationDirection.valueOf(it) } - ?: RelationDirection.OUT - - return RelationshipInfo(type, - relType, - direction, - relDirective.getMandatoryArgument(RELATION_FROM), - relDirective.getMandatoryArgument(RELATION_TO) - ) -} - -data class RelationshipInfo( - val type: GraphQLFieldsContainer, - val relType: String, - val direction: RelationDirection, - val startField: String, - val endField: String -) { - data class RelatedField( - val argumentName: String, - val field: GraphQLFieldDefinition, - val declaringType: GraphQLFieldsContainer - ) - - val typeName: String get() = this.type.name - - fun getStartFieldId() = getRelatedIdField(this.startField) - - fun getEndFieldId() = getRelatedIdField(this.endField) - - private fun getRelatedIdField(relFieldName: String?): RelatedField? { - if (relFieldName == null) return null - val relFieldDefinition = type.getFieldDefinition(relFieldName) - ?: throw IllegalArgumentException("field $relFieldName does not exists on ${type.innerName()}") - val relType = relFieldDefinition.type.inner() as? GraphQLFieldsContainer - ?: throw IllegalArgumentException("type ${relFieldDefinition.type.innerName()} not found") - return relType.fieldDefinitions.filter { it.isID() } - .map { - // TODO b/c we need to stay backwards kompatible this is not caml case but with underscore - //val filedName = normalizeName(relFieldName, it.name) - val filedName = "${relFieldName}_${it.name}" - RelatedField(filedName, it, relType) - } - .firstOrNull() - } - - fun createRelation(start: Node, end: Node): Relationship = - when (this.direction) { - RelationDirection.IN -> start.relationshipFrom(end, this.relType) - RelationDirection.OUT -> start.relationshipTo(end, this.relType) - RelationDirection.BOTH -> start.relationshipBetween(end, this.relType) - } -} - fun Field.aliasOrName(): String = (this.alias ?: this.name) fun Field.contextualize(variable: String) = variable + this.aliasOrName().capitalize() fun Field.contextualize(variable: SymbolicName) = variable.value + this.aliasOrName().capitalize() @@ -189,6 +124,23 @@ fun GraphQLType.getInnerFieldsContainer() = inner() as? GraphQLFieldsContainer fun GraphQLDirectiveContainer.getDirectiveArgument(directiveName: String, argumentName: String, defaultValue: T?): T? = getDirective(directiveName)?.getArgument(argumentName, defaultValue) ?: defaultValue +@Suppress("UNCHECKED_CAST") +fun DirectivesContainer<*>.getDirectiveArgument(typeRegistry: TypeDefinitionRegistry, directiveName: String, argumentName: String, defaultValue: T? = null): T? { + return (getDirective(directiveName) ?: return defaultValue) + .getArgument(argumentName)?.value?.toJavaValue() as T? + ?: typeRegistry.getDirectiveDefinition(directiveName) + ?.unwrap() + ?.inputValueDefinitions + ?.find { inputValueDefinition -> inputValueDefinition.name == argumentName } + ?.defaultValue?.toJavaValue() as T? + ?: defaultValue +} + +@Suppress("UNCHECKED_CAST") +fun DirectivesContainer<*>.getMandatoryDirectiveArgument(typeRegistry: TypeDefinitionRegistry, directiveName: String, argumentName: String, defaultValue: T? = null): T = + getDirectiveArgument(typeRegistry, directiveName, argumentName, defaultValue) + ?: throw IllegalStateException("No default value for @${directiveName}::$argumentName") + fun GraphQLDirective.getMandatoryArgument(argumentName: String, defaultValue: T? = null): T = this.getArgument(argumentName, defaultValue) ?: throw IllegalStateException(argumentName + " is required for @${this.name}") @@ -201,10 +153,12 @@ fun GraphQLDirective.getArgument(argumentName: String, defaultValue: T? = nu ?: throw IllegalStateException("No default value for @${this.name}::$argumentName") } -fun GraphQLFieldDefinition.cypherDirective() :CypherDirective?= getDirective(CYPHER)?.let { CypherDirective( - it.getMandatoryArgument(CYPHER_STATEMENT), - it.getMandatoryArgument(CYPHER_PASS_THROUGH, false) -) } +fun GraphQLFieldDefinition.cypherDirective(): CypherDirective? = getDirective(CYPHER)?.let { + CypherDirective( + it.getMandatoryArgument(CYPHER_STATEMENT), + it.getMandatoryArgument(CYPHER_PASS_THROUGH, false) + ) +} data class CypherDirective(val statement: String, val passThrough: Boolean) @@ -222,7 +176,7 @@ fun Value<*>.toJavaValue(): Any? = when (this) { is IntValue -> this.value.longValueExact() is VariableReference -> this is ArrayValue -> this.values.map { it.toJavaValue() }.toList() - is ObjectValue -> this.objectFields.map { it.name to it.value.toJavaValue() }.toMap() + is ObjectValue -> this.objectFields.associate { it.name to it.value.toJavaValue() } else -> throw IllegalStateException("Unhandled value $this") } @@ -230,22 +184,25 @@ fun GraphQLFieldDefinition.isID() = this.type.inner() == Scalars.GraphQLID fun GraphQLFieldDefinition.isNativeId() = this.name == ProjectionBase.NATIVE_ID fun GraphQLFieldsContainer.getIdField() = this.fieldDefinitions.find { it.isID() } -fun GraphQLInputObjectType.Builder.addFilterField(fieldName: String, isList: Boolean, filterType: String, description: String? = null) { - val wrappedType: GraphQLInputType = when { - isList -> GraphQLList(GraphQLTypeReference(filterType)) - else -> GraphQLTypeReference(filterType) +fun InputObjectTypeDefinition.Builder.addFilterField(fieldName: String, isList: Boolean, filterType: String, description: Description? = null) { + val wrappedType: Type<*> = when { + isList -> ListType(TypeName(filterType)) + else -> TypeName(filterType) } - val inputField = GraphQLInputObjectField.newInputObjectField() + val inputField = InputValueDefinition.newInputValueDefinition() .name(fieldName) - .description(description) .type(wrappedType) - this.field(inputField) -} + if (description != null) { + inputField.description(description) + } + this.inputValueDefinition(inputField.build()) +} -fun GraphQLSchema.queryTypeName() = this.queryType?.name ?: "Query" -fun GraphQLSchema.mutationTypeName() = this.mutationType?.name ?: "Mutation" -fun GraphQLSchema.subscriptionTypeName() = this.subscriptionType?.name ?: "Subscription" +fun TypeDefinitionRegistry.queryTypeName() = this.getOperationType("query") ?: "Query" +fun TypeDefinitionRegistry.mutationTypeName() = this.getOperationType("mutation") ?: "Mutation" +fun TypeDefinitionRegistry.subscriptionTypeName() = this.getOperationType("subscription") ?: "Subscription" +fun TypeDefinitionRegistry.getOperationType(name: String) = this.schemaDefinition().unwrap()?.operationTypeDefinitions?.firstOrNull { it.name == name }?.typeName?.name fun Any?.asGraphQLValue(): Value<*> = when (this) { null -> NullValue.newNullValue().build() @@ -261,3 +218,13 @@ fun Any?.asGraphQLValue(): Value<*> = when (this) { is String -> StringValue.newStringValue(this).build() else -> throw IllegalStateException("Cannot convert ${this.javaClass.name} into an graphql type") } + +fun DataFetchingEnvironment.typeAsContainer() = this.fieldDefinition.type.inner() as? GraphQLFieldsContainer + ?: throw IllegalStateException("expect type of field ${this.logField()} to be GraphQLFieldsContainer, but was ${this.fieldDefinition.type.name()}") + +fun DataFetchingEnvironment.logField() = "${this.parentType.name()}.${this.fieldDefinition.name}" + +val TypeInt = TypeName("Int") +val TypeFloat = TypeName("Float") +val TypeBoolean = TypeName("Boolean") +val TypeID = TypeName("ID") diff --git a/core/src/main/kotlin/org/neo4j/graphql/Predicates.kt b/core/src/main/kotlin/org/neo4j/graphql/Predicates.kt index d6d3cf8d..c0bff5e2 100644 --- a/core/src/main/kotlin/org/neo4j/graphql/Predicates.kt +++ b/core/src/main/kotlin/org/neo4j/graphql/Predicates.kt @@ -1,11 +1,13 @@ package org.neo4j.graphql -import graphql.Scalars -import graphql.language.NullValue -import graphql.language.ObjectValue -import graphql.language.Value -import graphql.schema.* -import org.neo4j.cypherdsl.core.* +import graphql.language.* +import graphql.language.TypeDefinition +import graphql.schema.GraphQLFieldDefinition +import graphql.schema.GraphQLFieldsContainer +import org.neo4j.cypherdsl.core.Condition +import org.neo4j.cypherdsl.core.Expression +import org.neo4j.cypherdsl.core.Property +import org.neo4j.cypherdsl.core.PropertyContainer import org.slf4j.LoggerFactory typealias CypherDSL = org.neo4j.cypherdsl.core.Cypher @@ -35,13 +37,13 @@ enum class FieldOperator( C("_contains", "CONTAINS", { lhs, rhs -> lhs.contains(rhs) }), SW("_starts_with", "STARTS WITH", { lhs, rhs -> lhs.startsWith(rhs) }), EW("_ends_with", "ENDS WITH", { lhs, rhs -> lhs.endsWith(rhs) }), - MATCHES("_matches", "=~", {lhs, rhs -> lhs.matches(rhs) }), + MATCHES("_matches", "=~", { lhs, rhs -> lhs.matches(rhs) }), DISTANCE(NEO4j_POINT_DISTANCE_FILTER_SUFFIX, "=", { lhs, rhs -> lhs.isEqualTo(rhs) }, distance = true), DISTANCE_LT(NEO4j_POINT_DISTANCE_FILTER_SUFFIX + "_lt", "<", { lhs, rhs -> lhs.lt(rhs) }, distance = true), - DISTANCE_LTE(NEO4j_POINT_DISTANCE_FILTER_SUFFIX + "_lte", "<=", { lhs, rhs -> lhs.lte(rhs) }, distance = true), - DISTANCE_GT(NEO4j_POINT_DISTANCE_FILTER_SUFFIX + "_gt", ">", { lhs, rhs -> lhs.gt(rhs) }, distance = true), + DISTANCE_LTE(NEO4j_POINT_DISTANCE_FILTER_SUFFIX + "_lte", "<=", { lhs, rhs -> lhs.lte(rhs) }, distance = true), + DISTANCE_GT(NEO4j_POINT_DISTANCE_FILTER_SUFFIX + "_gt", ">", { lhs, rhs -> lhs.gt(rhs) }, distance = true), DISTANCE_GTE(NEO4j_POINT_DISTANCE_FILTER_SUFFIX + "_gte", ">=", { lhs, rhs -> lhs.gte(rhs) }, distance = true); val list = op == "IN" @@ -68,14 +70,14 @@ enum class FieldOperator( private fun resolveNeo4jTypeConditions(variablePrefix: String, queriedField: String, propertyContainer: PropertyContainer, field: GraphQLFieldDefinition, value: ObjectValue, suffix: String?): List { val neo4jTypeConverter = getNeo4jTypeConverter(field) val conditions = mutableListOf() - if (distance){ + if (distance) { val parameter = queryParameter(value, variablePrefix, queriedField, suffix) conditions += (neo4jTypeConverter as Neo4jPointConverter).createDistanceCondition( propertyContainer.property(field.propertyName()), parameter, conditionCreator ) - } else { + } else { value.objectFields.forEachIndexed { index, objectField -> val parameter = queryParameter(value, variablePrefix, queriedField, if (value.objectFields.size > 1) "And${index + 1}" else null, suffix, objectField.name) .withValue(objectField.value.toJavaValue()) @@ -108,18 +110,18 @@ enum class FieldOperator( } } - fun forType(type: GraphQLType): List = + fun forType(type: TypeDefinition<*>, isNeo4jType: Boolean): List = when { - type == Scalars.GraphQLBoolean -> listOf(EQ, NEQ) - type.innerName() == NEO4j_POINT_DISTANCE_FILTER -> listOf(EQ, LT, LTE, GT, GTE) + type.name == TypeBoolean.name -> listOf(EQ, NEQ) + type.name == NEO4j_POINT_DISTANCE_FILTER -> listOf(EQ, LT, LTE, GT, GTE) type.isNeo4jSpatialType() -> listOf(EQ, NEQ) - type.isNeo4jType() -> listOf(EQ, NEQ, IN, NIN) - type is GraphQLFieldsContainer || type is GraphQLInputObjectType -> throw IllegalArgumentException("This operators are not for relations, use the RelationOperator instead") - type is GraphQLEnumType -> listOf(EQ, NEQ, IN, NIN) + isNeo4jType -> listOf(EQ, NEQ, IN, NIN) + type is ImplementingTypeDefinition<*> -> throw IllegalArgumentException("This operators are not for relations, use the RelationOperator instead") + type is EnumTypeDefinition -> listOf(EQ, NEQ, IN, NIN) // todo list types - !type.isScalar() -> listOf(EQ, NEQ, IN, NIN) + type !is ScalarTypeDefinition -> listOf(EQ, NEQ, IN, NIN) else -> listOf(EQ, NEQ, IN, NIN, LT, LTE, GT, GTE) + - if (type.name() == "String" || type.name() == "ID") listOf(C, NC, SW, NSW, EW, NEW, MATCHES) else emptyList() + if (type.name == "String" || type.name == "ID") listOf(C, NC, SW, NSW, EW, NEW, MATCHES) else emptyList() } } @@ -183,11 +185,11 @@ enum class RelationOperator(val suffix: String, val op: String) { companion object { private val LOGGER = LoggerFactory.getLogger(RelationOperator::class.java) - fun createRelationFilterFields(type: GraphQLFieldsContainer, field: GraphQLFieldDefinition, filterType: String, builder: GraphQLInputObjectType.Builder) { + fun createRelationFilterFields(type: TypeDefinition<*>, field: FieldDefinition, filterType: String, builder: InputObjectTypeDefinition.Builder) { val list = field.type.isList() val addFilterField = { op: RelationOperator, description: String -> - builder.addFilterField(op.fieldName(field.name), false, filterType, description) + builder.addFilterField(op.fieldName(field.name), false, filterType, description.asDescription()) } addFilterField(EQ_OR_NOT_EXISTS, "Filters only those `${type.name}` for which ${if (list) "all" else "the"} `${field.name}`-relationship matches this filter. " + diff --git a/core/src/main/kotlin/org/neo4j/graphql/RelationDirection.kt b/core/src/main/kotlin/org/neo4j/graphql/RelationDirection.kt deleted file mode 100644 index d6a10edf..00000000 --- a/core/src/main/kotlin/org/neo4j/graphql/RelationDirection.kt +++ /dev/null @@ -1,14 +0,0 @@ -package org.neo4j.graphql - -enum class RelationDirection { - IN, - OUT, - BOTH; - - fun invert(): RelationDirection = when (this) { - IN -> OUT - OUT -> IN - else -> this - } - -} diff --git a/core/src/main/kotlin/org/neo4j/graphql/RelationshipInfo.kt b/core/src/main/kotlin/org/neo4j/graphql/RelationshipInfo.kt new file mode 100644 index 00000000..e597d90c --- /dev/null +++ b/core/src/main/kotlin/org/neo4j/graphql/RelationshipInfo.kt @@ -0,0 +1,80 @@ +package org.neo4j.graphql + +import graphql.language.ImplementingTypeDefinition +import graphql.schema.GraphQLDirective +import graphql.schema.GraphQLDirectiveContainer +import graphql.schema.GraphQLFieldsContainer +import graphql.schema.idl.TypeDefinitionRegistry +import org.neo4j.cypherdsl.core.Node +import org.neo4j.cypherdsl.core.Relationship +import org.neo4j.cypherdsl.core.SymbolicName + +data class RelationshipInfo( + val type: TYPE, + val typeName: String, + val relType: String, + val direction: RelationDirection, + val startField: String, + val endField: String +) { + + enum class RelationDirection { + IN, + OUT, + BOTH; + + fun invert(): RelationDirection = when (this) { + IN -> OUT + OUT -> IN + else -> this + } + + } + + companion object { + fun create(type: GraphQLFieldsContainer): RelationshipInfo? = (type as? GraphQLDirectiveContainer) + ?.getDirective(DirectiveConstants.RELATION) + ?.let { relDirective -> create(type, relDirective) } + + fun create(type: GraphQLFieldsContainer, relDirective: GraphQLDirective): RelationshipInfo { + val relType = relDirective.getArgument(DirectiveConstants.RELATION_NAME, "")!! + val direction = relDirective.getArgument(DirectiveConstants.RELATION_DIRECTION, null) + ?.let { RelationDirection.valueOf(it) } + ?: RelationDirection.OUT + + return RelationshipInfo( + type, + type.name, + relType, + direction, + relDirective.getMandatoryArgument(DirectiveConstants.RELATION_FROM), + relDirective.getMandatoryArgument(DirectiveConstants.RELATION_TO) + ) + } + + fun create(type: ImplementingTypeDefinition<*>, registry: TypeDefinitionRegistry): RelationshipInfo>? { + val relType = type.getDirectiveArgument(registry, DirectiveConstants.RELATION, DirectiveConstants.RELATION_NAME) + ?: return null + val startField = type.getMandatoryDirectiveArgument(registry, DirectiveConstants.RELATION, DirectiveConstants.RELATION_FROM) + val endField = type.getMandatoryDirectiveArgument(registry, DirectiveConstants.RELATION, DirectiveConstants.RELATION_TO) + val direction = type.getDirectiveArgument(registry, DirectiveConstants.RELATION, DirectiveConstants.RELATION_DIRECTION) + ?.let { RelationDirection.valueOf(it) } + ?: RelationDirection.OUT + return RelationshipInfo(type, type.name, relType, direction, startField, endField) + } + } + + fun createRelation(start: Node, end: Node, addType: Boolean = true, variable: SymbolicName? = null): Relationship { + val labels = if (addType) { + arrayOf(this.relType) + } else { + emptyArray() + } + return when (this.direction) { + RelationDirection.IN -> start.relationshipFrom(end, *labels) + RelationDirection.OUT -> start.relationshipTo(end, *labels) + RelationDirection.BOTH -> start.relationshipBetween(end, *labels) + } + .let { if (variable != null) it.named(variable) else it } + } +} diff --git a/core/src/main/kotlin/org/neo4j/graphql/SchemaBuilder.kt b/core/src/main/kotlin/org/neo4j/graphql/SchemaBuilder.kt index 87a48638..073be9bd 100644 --- a/core/src/main/kotlin/org/neo4j/graphql/SchemaBuilder.kt +++ b/core/src/main/kotlin/org/neo4j/graphql/SchemaBuilder.kt @@ -1,6 +1,5 @@ package org.neo4j.graphql -import graphql.Scalars import graphql.language.* import graphql.schema.* import graphql.schema.idl.RuntimeWiring @@ -16,37 +15,119 @@ import org.neo4j.graphql.handler.relation.CreateRelationTypeHandler import org.neo4j.graphql.handler.relation.DeleteRelationHandler /** - * Contains factory methods to generate an augmented graphql schema + * A class for augmenting a type definition registry and generate the corresponding data fetcher. + * There are factory methods, that can be used to simplify augmenting a schema. + * + * + * Generating the schema is done by invoking the following methods: + * 1. [augmentTypes] + * 2. [registerScalars] + * 3. [registerTypeNameResolver] + * 4. [registerDataFetcher] + * + * Each of these steps can be called manually to enhance an existing [TypeDefinitionRegistry] */ -object SchemaBuilder { +class SchemaBuilder( + val typeDefinitionRegistry: TypeDefinitionRegistry, + val schemaConfig: SchemaConfig = SchemaConfig() +) { + + companion object { + /** + * @param sdl the schema to augment + * @param config defines how the schema should get augmented + * @param dataFetchingInterceptor since this library registers dataFetcher for its augmented methods, these data + * fetchers may be called by other resolver. This interceptor will let you convert a cypher query into real data. + */ + @JvmStatic + @JvmOverloads + fun buildSchema(sdl: String, config: SchemaConfig = SchemaConfig(), dataFetchingInterceptor: DataFetchingInterceptor? = null): GraphQLSchema { + val schemaParser = SchemaParser() + val typeDefinitionRegistry = schemaParser.parse(sdl) + return buildSchema(typeDefinitionRegistry, config, dataFetchingInterceptor) + } + + /** + * @param typeDefinitionRegistry a registry containing all the types, that should be augmented + * @param config defines how the schema should get augmented + * @param dataFetchingInterceptor since this library registers dataFetcher for its augmented methods, these data + * fetchers may be called by other resolver. This interceptor will let you convert a cypher query into real data. + */ + @JvmStatic + @JvmOverloads + fun buildSchema(typeDefinitionRegistry: TypeDefinitionRegistry, config: SchemaConfig = SchemaConfig(), dataFetchingInterceptor: DataFetchingInterceptor? = null): GraphQLSchema { + + val builder = RuntimeWiring.newRuntimeWiring() + val codeRegistryBuilder = GraphQLCodeRegistry.newCodeRegistry() + + val schemaBuilder = SchemaBuilder(typeDefinitionRegistry, config) + schemaBuilder.augmentTypes() + schemaBuilder.registerScalars(builder) + schemaBuilder.registerTypeNameResolver(builder) + schemaBuilder.registerDataFetcher(codeRegistryBuilder, dataFetchingInterceptor) + + return SchemaGenerator().makeExecutableSchema( + typeDefinitionRegistry, + builder.codeRegistry(codeRegistryBuilder).build() + ) + } + } + + private val handler: List + private val neo4jTypeDefinitionRegistry: TypeDefinitionRegistry + + init { + neo4jTypeDefinitionRegistry = getNeo4jEnhancements() + ensureRootQueryTypeExists(typeDefinitionRegistry) + handler = mutableListOf( + CypherDirectiveHandler.Factory(schemaConfig, typeDefinitionRegistry, neo4jTypeDefinitionRegistry), + AugmentFieldHandler(schemaConfig, typeDefinitionRegistry, neo4jTypeDefinitionRegistry) + ) + if (schemaConfig.query.enabled) { + handler.add(QueryHandler.Factory(schemaConfig, typeDefinitionRegistry, neo4jTypeDefinitionRegistry)) + } + if (schemaConfig.mutation.enabled) { + handler += listOf( + MergeOrUpdateHandler.Factory(schemaConfig, typeDefinitionRegistry, neo4jTypeDefinitionRegistry), + DeleteHandler.Factory(schemaConfig, typeDefinitionRegistry, neo4jTypeDefinitionRegistry), + CreateTypeHandler.Factory(schemaConfig, typeDefinitionRegistry, neo4jTypeDefinitionRegistry), + DeleteRelationHandler.Factory(schemaConfig, typeDefinitionRegistry, neo4jTypeDefinitionRegistry), + CreateRelationTypeHandler.Factory(schemaConfig, typeDefinitionRegistry, neo4jTypeDefinitionRegistry), + CreateRelationHandler.Factory(schemaConfig, typeDefinitionRegistry, neo4jTypeDefinitionRegistry) + ) + } + } + /** - * @param sdl the schema to augment - * @param config defines how the schema should get augmented - * @param dataFetchingInterceptor since this library registers dataFetcher for its augmented methods, these data - * fetchers may be called by other resolver. This interceptor will let you convert a cypher query into real data. + * Generated additionally query and mutation fields according to the types present in the [typeDefinitionRegistry]. + * This method will also augment relation fields, so filtering and sorting is available for them */ - @JvmStatic - @JvmOverloads - fun buildSchema(sdl: String, config: SchemaConfig = SchemaConfig(), dataFetchingInterceptor: DataFetchingInterceptor? = null): GraphQLSchema { - val schemaParser = SchemaParser() - val typeDefinitionRegistry = schemaParser.parse(sdl) - return buildSchema(typeDefinitionRegistry, config, dataFetchingInterceptor) + fun augmentTypes() { + val queryTypeName = typeDefinitionRegistry.queryTypeName() + val mutationTypeName = typeDefinitionRegistry.mutationTypeName() + val subscriptionTypeName = typeDefinitionRegistry.subscriptionTypeName() + + typeDefinitionRegistry.types().values + .filterIsInstance>() + .filter { it.name != queryTypeName && it.name != mutationTypeName && it.name != subscriptionTypeName } + .forEach { type -> handler.forEach { h -> h.augmentType(type) } } + + // in a second run we enhance all the root fields + typeDefinitionRegistry.types().values + .filterIsInstance>() + .filter { it.name == queryTypeName || it.name == mutationTypeName || it.name == subscriptionTypeName } + .forEach { type -> handler.forEach { h -> h.augmentType(type) } } + + // TODO copy over only the types used in the source schema + typeDefinitionRegistry.merge(neo4jTypeDefinitionRegistry) } /** - * @param typeDefinitionRegistry a registry containing all the types, that should be augmented - * @param config defines how the schema should get augmented - * @param dataFetchingInterceptor since this library registers dataFetcher for its augmented methods, these data - * fetchers may be called by other resolver. This interceptor will let you convert a cypher query into real data. + * Register scalars of this library in the [RuntimeWiring][@param builder] + * @param builder a builder to create a runtime wiring */ - @JvmStatic - @JvmOverloads - fun buildSchema(typeDefinitionRegistry: TypeDefinitionRegistry, config: SchemaConfig = SchemaConfig(), dataFetchingInterceptor: DataFetchingInterceptor? = null): GraphQLSchema { - val enhancedRegistry = typeDefinitionRegistry.merge(getNeo4jEnhancements()) - ensureRootQueryTypeExists(enhancedRegistry) - - val builder = RuntimeWiring.newRuntimeWiring() + fun registerScalars(builder: RuntimeWiring.Builder) { typeDefinitionRegistry.scalars() .filterNot { entry -> GRAPHQL_SPECIFICATION_SCALARS_DEFINITIONS.containsKey(entry.key) } .forEach { (name, definition) -> @@ -62,26 +143,60 @@ object SchemaBuilder { } builder.scalar(scalar) } + } - - enhancedRegistry + /** + * Register type name resolver in the [RuntimeWiring][@param builder] + * @param builder a builder to create a runtime wiring + */ + fun registerTypeNameResolver(builder: RuntimeWiring.Builder) { + typeDefinitionRegistry .getTypes(InterfaceTypeDefinition::class.java) .forEach { typeDefinition -> builder.type(typeDefinition.name) { it.typeResolver { env -> (env.getObject() as? Map) - ?.let { data -> data.get(ProjectionBase.TYPE_NAME) as? String } + ?.let { data -> data[ProjectionBase.TYPE_NAME] as? String } ?.let { typeName -> env.schema.getObjectType(typeName) } } } } - val sourceSchema = SchemaGenerator().makeExecutableSchema(enhancedRegistry, builder.build()) + } - val handler = getHandler(config) + /** + * Register data fetcher in a [GraphQLCodeRegistry][@param codeRegistryBuilder]. + * The data fetcher of this library generate a cypher query and if provided use the dataFetchingInterceptor to run this cypher against a neo4j db. + * @param codeRegistryBuilder a builder to create a code registry + * @param dataFetchingInterceptor a function to convert a cypher string into an object by calling the neo4j db + */ + @JvmOverloads + fun registerDataFetcher( + codeRegistryBuilder: GraphQLCodeRegistry.Builder, + dataFetchingInterceptor: DataFetchingInterceptor?, + typeDefinitionRegistry: TypeDefinitionRegistry = this.typeDefinitionRegistry + ) { + addDataFetcher(typeDefinitionRegistry.queryTypeName(), OperationType.QUERY, dataFetchingInterceptor, codeRegistryBuilder) + addDataFetcher(typeDefinitionRegistry.mutationTypeName(), OperationType.MUTATION, dataFetchingInterceptor, codeRegistryBuilder) + } - var targetSchema = augmentSchema(sourceSchema, handler, config) - targetSchema = addDataFetcher(targetSchema, dataFetchingInterceptor, handler) - return targetSchema + private fun addDataFetcher( + parentType: String, + operationType: OperationType, + dataFetchingInterceptor: DataFetchingInterceptor?, + codeRegistryBuilder: GraphQLCodeRegistry.Builder) { + typeDefinitionRegistry.getType(parentType)?.unwrap() + ?.let { it as? ObjectTypeDefinition } + ?.fieldDefinitions + ?.forEach { field -> + handler.forEach { h -> + h.createDataFetcher(operationType, field)?.let { dataFetcher -> + val interceptedDataFetcher: DataFetcher<*> = dataFetchingInterceptor?.let { + DataFetcher { env -> dataFetchingInterceptor.fetchData(env, dataFetcher) } + } ?: dataFetcher + codeRegistryBuilder.dataFetcher(FieldCoordinates.coordinates(parentType, field.name), interceptedDataFetcher) + } + } + } } private fun ensureRootQueryTypeExists(enhancedRegistry: TypeDefinitionRegistry) { @@ -108,149 +223,9 @@ object SchemaBuilder { }) } - private fun getHandler(schemaConfig: SchemaConfig): List { - val handler = mutableListOf( - CypherDirectiveHandler.Factory(schemaConfig) - ) - if (schemaConfig.query.enabled) { - handler.add(QueryHandler.Factory(schemaConfig)) - } - if (schemaConfig.mutation.enabled) { - handler += listOf( - MergeOrUpdateHandler.Factory(schemaConfig), - DeleteHandler.Factory(schemaConfig), - CreateTypeHandler.Factory(schemaConfig), - DeleteRelationHandler.Factory(schemaConfig), - CreateRelationTypeHandler.Factory(schemaConfig), - CreateRelationHandler.Factory(schemaConfig) - ) - } - return handler - } - - private fun augmentSchema(sourceSchema: GraphQLSchema, handler: List, schemaConfig: SchemaConfig): GraphQLSchema { - val types = sourceSchema.typeMap.toMutableMap() - val env = BuildingEnv(types, sourceSchema, schemaConfig) - val queryTypeName = sourceSchema.queryTypeName() - val mutationTypeName = sourceSchema.mutationTypeName() - val subscriptionTypeName = sourceSchema.subscriptionTypeName() - types.values - .filterIsInstance() - .filter { - !it.name.startsWith("__") - && !it.isNeo4jType() - && it.name != queryTypeName - && it.name != mutationTypeName - && it.name != subscriptionTypeName - } - .forEach { type -> - handler.forEach { h -> h.augmentType(type, env) } - } - - // since new types my be added to `types` we copy the map, to safely modify the entries and later add these - // modified entries back to the `types` - val adjustedTypes = types.toMutableMap() - adjustedTypes.replaceAll { _, sourceType -> - when { - sourceType.name.startsWith("__") -> sourceType - sourceType is GraphQLObjectType -> sourceType.transform { builder -> - builder.clearFields().clearInterfaces() - // to prevent duplicated types in schema - sourceType.interfaces.forEach { builder.withInterface(GraphQLTypeReference(it.name)) } - sourceType.fieldDefinitions.forEach { f -> builder.field(enhanceRelations(f, env)) } - } - sourceType is GraphQLInterfaceType -> sourceType.transform { builder -> - builder.clearFields() - sourceType.fieldDefinitions.forEach { f -> builder.field(enhanceRelations(f, env)) } - } - else -> sourceType - } - } - types.putAll(adjustedTypes) - - return GraphQLSchema - .newSchema(sourceSchema) - .clearAdditionalTypes() - .query(types[queryTypeName] as? GraphQLObjectType) - .mutation(types[mutationTypeName] as? GraphQLObjectType) - .additionalTypes(types.values.toSet()) - .build() - } - - private fun enhanceRelations(fd: GraphQLFieldDefinition, env: BuildingEnv): GraphQLFieldDefinition { - return fd.transform { fieldBuilder -> - // to prevent duplicated types in schema - fieldBuilder.type(fd.type.ref() as GraphQLOutputType) - - if (!fd.isRelationship() || !fd.type.isList()) { - return@transform - } - - val fieldType = fd.type.inner() as? GraphQLFieldsContainer ?: return@transform - - if (env.schemaConfig.queryOptionStyle == SchemaConfig.InputStyle.INPUT_TYPE) { - - val optionsTypeName = env.addOptions(fieldType) - val optionsType = GraphQLTypeReference(optionsTypeName) - fieldBuilder.argument(input(ProjectionBase.OPTIONS, optionsType)) - - } else { - - if (fd.getArgument(ProjectionBase.FIRST) == null) { - fieldBuilder.argument { a -> a.name(ProjectionBase.FIRST).type(Scalars.GraphQLInt) } - } - if (fd.getArgument(ProjectionBase.OFFSET) == null) { - fieldBuilder.argument { a -> a.name(ProjectionBase.OFFSET).type(Scalars.GraphQLInt) } - } - if (fd.getArgument(ProjectionBase.ORDER_BY) == null) { - env.addOrdering(fieldType)?.let { orderingTypeName -> - val orderType = GraphQLList(GraphQLNonNull(GraphQLTypeReference(orderingTypeName))) - fieldBuilder.argument { a -> a.name(ProjectionBase.ORDER_BY).type(orderType) } - - } - } - - } - - val filterFieldName = if (env.schemaConfig.useWhereFilter) ProjectionBase.WHERE else ProjectionBase.FILTER - if (env.schemaConfig.query.enabled && !env.schemaConfig.query.exclude.contains(fieldType.name) && fd.getArgument(filterFieldName) == null) { - val filterTypeName = env.addFilterType(fieldType) - fieldBuilder.argument(input(filterFieldName, GraphQLTypeReference(filterTypeName))) - } - } - } - - private fun addDataFetcher(sourceSchema: GraphQLSchema, dataFetchingInterceptor: DataFetchingInterceptor?, handler: List): GraphQLSchema { - val codeRegistryBuilder = GraphQLCodeRegistry.newCodeRegistry(sourceSchema.codeRegistry) - addDataFetcher(sourceSchema.queryType, OperationType.QUERY, dataFetchingInterceptor, handler, codeRegistryBuilder) - addDataFetcher(sourceSchema.mutationType, OperationType.MUTATION, dataFetchingInterceptor, handler, codeRegistryBuilder) - return sourceSchema.transform { it.codeRegistry(codeRegistryBuilder.build()) } - } - - private fun addDataFetcher( - rootType: GraphQLObjectType?, - operationType: OperationType, - dataFetchingInterceptor: DataFetchingInterceptor?, - handler: List, - codeRegistryBuilder: GraphQLCodeRegistry.Builder) { - if (rootType == null) return - rootType.fieldDefinitions.forEach { field -> - handler.forEach { h -> - h.createDataFetcher(operationType, field)?.let { dataFetcher -> - val df: DataFetcher<*> = dataFetchingInterceptor?.let { - DataFetcher { env -> - dataFetchingInterceptor.fetchData(env, dataFetcher) - } - } ?: dataFetcher - codeRegistryBuilder.dataFetcher(rootType, field, df) - } - } - } - } - private fun getNeo4jEnhancements(): TypeDefinitionRegistry { - val directivesSdl = javaClass.getResource("/neo4j_types.graphql").readText() + - javaClass.getResource("/lib_directives.graphql").readText() + val directivesSdl = javaClass.getResource("/neo4j_types.graphql")?.readText() + + javaClass.getResource("/lib_directives.graphql")?.readText() val typeDefinitionRegistry = SchemaParser().parse(directivesSdl) neo4jTypeDefinitions .forEach { diff --git a/core/src/main/kotlin/org/neo4j/graphql/Translator.kt b/core/src/main/kotlin/org/neo4j/graphql/Translator.kt index c1c442ea..26a57523 100644 --- a/core/src/main/kotlin/org/neo4j/graphql/Translator.kt +++ b/core/src/main/kotlin/org/neo4j/graphql/Translator.kt @@ -1,7 +1,10 @@ package org.neo4j.graphql import graphql.execution.MergedField -import graphql.language.* +import graphql.language.Document +import graphql.language.Field +import graphql.language.FragmentDefinition +import graphql.language.OperationDefinition import graphql.language.OperationDefinition.Operation.MUTATION import graphql.language.OperationDefinition.Operation.QUERY import graphql.parser.Parser @@ -19,7 +22,7 @@ class Translator(val schema: GraphQLSchema) { @Throws(OptimizedQueryException::class) fun translate(query: String, params: Map = emptyMap(), ctx: QueryContext = QueryContext()): List { val ast = parse(query) // todo preparsedDocumentProvider - val fragments = ast.definitions.filterIsInstance().map { it.name to it }.toMap() + val fragments = ast.definitions.filterIsInstance().associateBy { it.name } return ast.definitions.filterIsInstance() .filter { it.operation == QUERY || it.operation == MUTATION } // todo variableDefinitions, directives, name .flatMap { operationDefinition -> @@ -67,6 +70,7 @@ class Translator(val schema: GraphQLSchema) { return dataFetcher.get(newDataFetchingEnvironment() .mergedField(MergedField.newMergedField(field).build()) + .parentType(operationObjectType) .graphQLSchema(schema) .fragmentsByName(fragments) .context(ctx) diff --git a/core/src/main/kotlin/org/neo4j/graphql/handler/AugmentFieldHandler.kt b/core/src/main/kotlin/org/neo4j/graphql/handler/AugmentFieldHandler.kt new file mode 100644 index 00000000..6cfa77ba --- /dev/null +++ b/core/src/main/kotlin/org/neo4j/graphql/handler/AugmentFieldHandler.kt @@ -0,0 +1,71 @@ +package org.neo4j.graphql.handler + +import graphql.language.* +import graphql.schema.DataFetcher +import graphql.schema.idl.TypeDefinitionRegistry +import org.neo4j.graphql.* +import org.neo4j.graphql.handler.projection.ProjectionBase + +/** + * This class augments existing fields on a type and adds filtering and sorting to these fields. + */ +class AugmentFieldHandler( + schemaConfig: SchemaConfig, + typeDefinitionRegistry: TypeDefinitionRegistry, + neo4jTypeDefinitionRegistry: TypeDefinitionRegistry +) : AugmentationHandler(schemaConfig, typeDefinitionRegistry, neo4jTypeDefinitionRegistry) { + + override fun augmentType(type: ImplementingTypeDefinition<*>) { + val enhanceRelations = { type.fieldDefinitions.map { fieldDef -> fieldDef.transform { augmentRelation(it, fieldDef) } } } + + val rewritten = when (type) { + is ObjectTypeDefinition -> type.transform { it.fieldDefinitions(enhanceRelations()) } + is InterfaceTypeDefinition -> type.transform { it.definitions(enhanceRelations()) } + else -> return + } + + typeDefinitionRegistry.remove(rewritten.name, rewritten) + typeDefinitionRegistry.add(rewritten) + } + + private fun augmentRelation(fieldBuilder: FieldDefinition.Builder, field: FieldDefinition) { + if (!field.isRelationship() || !field.type.isList()) { + return + } + + val fieldType = field.type.inner().resolve() as? ImplementingTypeDefinition<*> ?: return + + if (schemaConfig.queryOptionStyle == SchemaConfig.InputStyle.INPUT_TYPE) { + + val optionsTypeName = addOptions(fieldType) + if (field.inputValueDefinitions.find { it.name == ProjectionBase.OPTIONS } == null) { + fieldBuilder.inputValueDefinition(input(ProjectionBase.OPTIONS, TypeName(optionsTypeName))) + } + + } else { + + if (field.inputValueDefinitions.find { it.name == ProjectionBase.FIRST } == null) { + fieldBuilder.inputValueDefinition(input(ProjectionBase.FIRST, TypeInt)) + } + if (field.inputValueDefinitions.find { it.name == ProjectionBase.OFFSET } == null) { + fieldBuilder.inputValueDefinition(input(ProjectionBase.OFFSET, TypeInt)) + } + if (field.inputValueDefinitions.find { it.name == ProjectionBase.ORDER_BY } == null) { + addOrdering(fieldType)?.let { orderingTypeName -> + val orderType = ListType(NonNullType(TypeName(orderingTypeName))) + fieldBuilder.inputValueDefinition(input(ProjectionBase.ORDER_BY, orderType)) + } + } + } + + val filterFieldName = if (schemaConfig.useWhereFilter) ProjectionBase.WHERE else ProjectionBase.FILTER + if (schemaConfig.query.enabled && !schemaConfig.query.exclude.contains(fieldType.name) && field.inputValueDefinitions.find { it.name == filterFieldName } == null) { + val filterTypeName = addFilterType(fieldType) + fieldBuilder.inputValueDefinition(input(filterFieldName, TypeName(filterTypeName))) + } + + } + + override fun createDataFetcher(operationType: OperationType, fieldDefinition: FieldDefinition): DataFetcher? = null +} + diff --git a/core/src/main/kotlin/org/neo4j/graphql/handler/BaseDataFetcher.kt b/core/src/main/kotlin/org/neo4j/graphql/handler/BaseDataFetcher.kt index 8a8715bb..58af868d 100644 --- a/core/src/main/kotlin/org/neo4j/graphql/handler/BaseDataFetcher.kt +++ b/core/src/main/kotlin/org/neo4j/graphql/handler/BaseDataFetcher.kt @@ -5,6 +5,7 @@ import graphql.language.VariableReference import graphql.schema.DataFetcher import graphql.schema.DataFetchingEnvironment import graphql.schema.GraphQLFieldDefinition +import graphql.schema.GraphQLType import org.neo4j.cypherdsl.core.Statement import org.neo4j.cypherdsl.core.renderer.Configuration import org.neo4j.cypherdsl.core.renderer.Renderer @@ -16,13 +17,15 @@ import org.neo4j.graphql.handler.projection.ProjectionBase /** * The is a base class for the implementation of graphql data fetcher used in this project */ -abstract class BaseDataFetcher(val fieldDefinition: GraphQLFieldDefinition, schemaConfig: SchemaConfig) : ProjectionBase(schemaConfig), DataFetcher { +abstract class BaseDataFetcher(schemaConfig: SchemaConfig) : ProjectionBase(schemaConfig), DataFetcher { - override fun get(env: DataFetchingEnvironment?): Cypher { - val field = env?.mergedField?.singleField + private var init = false + + override fun get(env: DataFetchingEnvironment): Cypher { + val field = env.mergedField?.singleField ?: throw IllegalAccessException("expect one filed in environment.mergedField") - require(field.name == fieldDefinition.name) { "Handler for ${fieldDefinition.name} cannot handle ${field.name}" } val variable = field.aliasOrName().decapitalize() + prepareDataFetcher(env.fieldDefinition, env.parentType) val statement = generateCypher(variable, field, env) val query = Renderer.getRenderer(Configuration @@ -36,7 +39,21 @@ abstract class BaseDataFetcher(val fieldDefinition: GraphQLFieldDefinition, sche (value as? VariableReference)?.let { env.variables[it.name] } ?: value } - return Cypher(query, params, fieldDefinition.type, variable = field.aliasOrName()) + return Cypher(query, params, env.fieldDefinition.type, variable = field.aliasOrName()) + } + + /** + * called after the schema is generated but before the 1st call + */ + private fun prepareDataFetcher(fieldDefinition: GraphQLFieldDefinition, parentType: GraphQLType) { + if (init) { + return + } + init = true + initDataFetcher(fieldDefinition, parentType) + } + + protected open fun initDataFetcher(fieldDefinition: GraphQLFieldDefinition, parentType: GraphQLType) { } protected abstract fun generateCypher(variable: String, field: Field, env: DataFetchingEnvironment): Statement diff --git a/core/src/main/kotlin/org/neo4j/graphql/handler/BaseDataFetcherForContainer.kt b/core/src/main/kotlin/org/neo4j/graphql/handler/BaseDataFetcherForContainer.kt index c406d4b5..3d0406f1 100644 --- a/core/src/main/kotlin/org/neo4j/graphql/handler/BaseDataFetcherForContainer.kt +++ b/core/src/main/kotlin/org/neo4j/graphql/handler/BaseDataFetcherForContainer.kt @@ -5,6 +5,7 @@ import graphql.language.ArrayValue import graphql.language.ObjectValue import graphql.schema.GraphQLFieldDefinition import graphql.schema.GraphQLFieldsContainer +import graphql.schema.GraphQLType import org.neo4j.cypherdsl.core.* import org.neo4j.cypherdsl.core.Cypher.* import org.neo4j.graphql.* @@ -12,16 +13,15 @@ import org.neo4j.graphql.* /** * This is a base class for all Node or Relation related data fetcher. */ -abstract class BaseDataFetcherForContainer( - val type: GraphQLFieldsContainer, - fieldDefinition: GraphQLFieldDefinition, - schemaConfig: SchemaConfig -) : BaseDataFetcher(fieldDefinition, schemaConfig) { +abstract class BaseDataFetcherForContainer(schemaConfig: SchemaConfig) : BaseDataFetcher(schemaConfig) { + lateinit var type: GraphQLFieldsContainer val propertyFields: MutableMap List?> = mutableMapOf() val defaultFields: MutableMap = mutableMapOf() - init { + override fun initDataFetcher(fieldDefinition: GraphQLFieldDefinition, parentType: GraphQLType) { + type = fieldDefinition.type.inner() as? GraphQLFieldsContainer + ?: throw IllegalStateException("expect type of field ${parentType.name()}.${fieldDefinition.name} to be GraphQLFieldsContainer, but was ${fieldDefinition.type.name()}") fieldDefinition .arguments .filterNot { listOf(FIRST, OFFSET, ORDER_BY, NATIVE_ID, OPTIONS).contains(it.name) } diff --git a/core/src/main/kotlin/org/neo4j/graphql/handler/CreateTypeHandler.kt b/core/src/main/kotlin/org/neo4j/graphql/handler/CreateTypeHandler.kt index cbc8ca32..55f8d474 100644 --- a/core/src/main/kotlin/org/neo4j/graphql/handler/CreateTypeHandler.kt +++ b/core/src/main/kotlin/org/neo4j/graphql/handler/CreateTypeHandler.kt @@ -1,7 +1,13 @@ package org.neo4j.graphql.handler import graphql.language.Field -import graphql.schema.* +import graphql.language.FieldDefinition +import graphql.language.ImplementingTypeDefinition +import graphql.language.InterfaceTypeDefinition +import graphql.schema.DataFetcher +import graphql.schema.DataFetchingEnvironment +import graphql.schema.GraphQLObjectType +import graphql.schema.idl.TypeDefinitionRegistry import org.neo4j.cypherdsl.core.Statement import org.neo4j.cypherdsl.core.StatementBuilder import org.neo4j.graphql.* @@ -10,58 +16,56 @@ import org.neo4j.graphql.* * This class handles all the logic related to the creation of nodes. * This includes the augmentation of the create<Node>-mutator and the related cypher generation */ -class CreateTypeHandler private constructor( - type: GraphQLFieldsContainer, - fieldDefinition: GraphQLFieldDefinition, - schemaConfig: SchemaConfig -) : BaseDataFetcherForContainer(type, fieldDefinition, schemaConfig) { +class CreateTypeHandler private constructor(schemaConfig: SchemaConfig) : BaseDataFetcherForContainer(schemaConfig) { - class Factory(schemaConfig: SchemaConfig) : AugmentationHandler(schemaConfig) { - override fun augmentType(type: GraphQLFieldsContainer, buildingEnv: BuildingEnv) { + class Factory(schemaConfig: SchemaConfig, + typeDefinitionRegistry: TypeDefinitionRegistry, + neo4jTypeDefinitionRegistry: TypeDefinitionRegistry + ) : AugmentationHandler(schemaConfig, typeDefinitionRegistry, neo4jTypeDefinitionRegistry) { + + override fun augmentType(type: ImplementingTypeDefinition<*>) { if (!canHandle(type)) { return } val relevantFields = getRelevantFields(type) - val fieldDefinition = buildingEnv - .buildFieldDefinition("create", type, relevantFields, nullableResult = false) + val fieldDefinition = buildFieldDefinition("create", type, relevantFields, nullableResult = false) .build() - buildingEnv.addMutationField(fieldDefinition) + addMutationField(fieldDefinition) } - override fun createDataFetcher(operationType: OperationType, fieldDefinition: GraphQLFieldDefinition): DataFetcher? { + override fun createDataFetcher(operationType: OperationType, fieldDefinition: FieldDefinition): DataFetcher? { if (operationType != OperationType.MUTATION) { return null } if (fieldDefinition.cypherDirective() != null) { return null } - val type = fieldDefinition.type.inner() as? GraphQLObjectType - ?: return null + val type = fieldDefinition.type.inner().resolve() as? ImplementingTypeDefinition<*> ?: return null if (!canHandle(type)) { return null } - return when { - fieldDefinition.name == "create${type.name}" -> CreateTypeHandler(type, fieldDefinition, schemaConfig) + return when (fieldDefinition.name) { + "create${type.name}" -> CreateTypeHandler(schemaConfig) else -> null } } - private fun getRelevantFields(type: GraphQLFieldsContainer): List { + private fun getRelevantFields(type: ImplementingTypeDefinition<*>): List { return type - .relevantFields() + .getScalarFields() .filter { !it.isNativeId() } } - private fun canHandle(type: GraphQLFieldsContainer): Boolean { + private fun canHandle(type: ImplementingTypeDefinition<*>): Boolean { val typeName = type.name - if (!schemaConfig.mutation.enabled || schemaConfig.mutation.exclude.contains(typeName)) { + if (!schemaConfig.mutation.enabled || schemaConfig.mutation.exclude.contains(typeName) || isRootType(type)) { return false } - if (type !is GraphQLObjectType) { + if (type is InterfaceTypeDefinition) { return false } - if ((type as GraphQLDirectiveContainer).isRelationType()) { + if (type.relationship() != null) { // relations are handled by the CreateRelationTypeHandler return false } diff --git a/core/src/main/kotlin/org/neo4j/graphql/handler/CypherDirectiveHandler.kt b/core/src/main/kotlin/org/neo4j/graphql/handler/CypherDirectiveHandler.kt index 1b72e7fe..923dc66b 100644 --- a/core/src/main/kotlin/org/neo4j/graphql/handler/CypherDirectiveHandler.kt +++ b/core/src/main/kotlin/org/neo4j/graphql/handler/CypherDirectiveHandler.kt @@ -1,10 +1,11 @@ package org.neo4j.graphql.handler import graphql.language.Field +import graphql.language.FieldDefinition import graphql.schema.DataFetcher import graphql.schema.DataFetchingEnvironment -import graphql.schema.GraphQLFieldDefinition import graphql.schema.GraphQLFieldsContainer +import graphql.schema.idl.TypeDefinitionRegistry import org.neo4j.cypherdsl.core.Functions import org.neo4j.cypherdsl.core.Statement import org.neo4j.graphql.* @@ -12,25 +13,25 @@ import org.neo4j.graphql.* /** * This class handles all logic related to custom Cypher queries declared by fields with a @cypher directive */ -class CypherDirectiveHandler( - private val type: GraphQLFieldsContainer?, - private val isQuery: Boolean, - private val cypherDirective: CypherDirective, - fieldDefinition: GraphQLFieldDefinition, - schemaConfig: SchemaConfig) - : BaseDataFetcher(fieldDefinition, schemaConfig) { - - class Factory(schemaConfig: SchemaConfig) : AugmentationHandler(schemaConfig) { - - override fun createDataFetcher(operationType: OperationType, fieldDefinition: GraphQLFieldDefinition): DataFetcher? { - val cypherDirective = fieldDefinition.cypherDirective() ?: return null - val type = fieldDefinition.type.inner() as? GraphQLFieldsContainer +class CypherDirectiveHandler(private val isQuery: Boolean, schemaConfig: SchemaConfig) : BaseDataFetcher(schemaConfig) { + + class Factory(schemaConfig: SchemaConfig, + typeDefinitionRegistry: TypeDefinitionRegistry, + neo4jTypeDefinitionRegistry: TypeDefinitionRegistry + ) : AugmentationHandler(schemaConfig, typeDefinitionRegistry, neo4jTypeDefinitionRegistry) { + + override fun createDataFetcher(operationType: OperationType, fieldDefinition: FieldDefinition): DataFetcher? { + fieldDefinition.cypherDirective() ?: return null val isQuery = operationType == OperationType.QUERY - return CypherDirectiveHandler(type, isQuery, cypherDirective, fieldDefinition, schemaConfig) + return CypherDirectiveHandler(isQuery, schemaConfig) } } override fun generateCypher(variable: String, field: Field, env: DataFetchingEnvironment): Statement { + val fieldDefinition = env.fieldDefinition + val type = fieldDefinition.type.inner() as? GraphQLFieldsContainer + val cypherDirective = fieldDefinition.cypherDirective() + ?: throw IllegalStateException("Expect field ${env.logField()} to have @cypher directive present") val query = if (isQuery) { val nestedQuery = cypherDirective(variable, fieldDefinition, field, cypherDirective) diff --git a/core/src/main/kotlin/org/neo4j/graphql/handler/DeleteHandler.kt b/core/src/main/kotlin/org/neo4j/graphql/handler/DeleteHandler.kt index 02e78182..8761235e 100644 --- a/core/src/main/kotlin/org/neo4j/graphql/handler/DeleteHandler.kt +++ b/core/src/main/kotlin/org/neo4j/graphql/handler/DeleteHandler.kt @@ -1,7 +1,14 @@ package org.neo4j.graphql.handler import graphql.language.Field -import graphql.schema.* +import graphql.language.FieldDefinition +import graphql.language.ImplementingTypeDefinition +import graphql.language.TypeName +import graphql.schema.DataFetcher +import graphql.schema.DataFetchingEnvironment +import graphql.schema.GraphQLFieldDefinition +import graphql.schema.GraphQLType +import graphql.schema.idl.TypeDefinitionRegistry import org.neo4j.cypherdsl.core.Node import org.neo4j.cypherdsl.core.Relationship import org.neo4j.cypherdsl.core.Statement @@ -12,57 +19,62 @@ import org.neo4j.graphql.* * This class handles all the logic related to the deletion of nodes. * This includes the augmentation of the delete<Node>-mutator and the related cypher generation */ -class DeleteHandler private constructor( - type: GraphQLFieldsContainer, - private val idField: GraphQLFieldDefinition, - fieldDefinition: GraphQLFieldDefinition, - schemaConfig: SchemaConfig, - private val isRelation: Boolean = type.isRelationType() -) : BaseDataFetcherForContainer(type, fieldDefinition, schemaConfig) { +class DeleteHandler private constructor(schemaConfig: SchemaConfig) : BaseDataFetcherForContainer(schemaConfig) { - class Factory(schemaConfig: SchemaConfig) : AugmentationHandler(schemaConfig) { - override fun augmentType(type: GraphQLFieldsContainer, buildingEnv: BuildingEnv) { + private lateinit var idField: GraphQLFieldDefinition + private var isRelation: Boolean = false + + class Factory(schemaConfig: SchemaConfig, + typeDefinitionRegistry: TypeDefinitionRegistry, + neo4jTypeDefinitionRegistry: TypeDefinitionRegistry + ) : AugmentationHandler(schemaConfig, typeDefinitionRegistry, neo4jTypeDefinitionRegistry) { + + override fun augmentType(type: ImplementingTypeDefinition<*>) { if (!canHandle(type)) { return } val idField = type.getIdField() ?: return - val fieldDefinition = buildingEnv - .buildFieldDefinition("delete", type, listOf(idField), nullableResult = true) - .description("Deletes ${type.name} and returns the type itself") - .type(type.ref() as GraphQLOutputType) + val fieldDefinition = buildFieldDefinition("delete", type, listOf(idField), nullableResult = true) + .description("Deletes ${type.name} and returns the type itself".asDescription()) + .type(TypeName(type.name)) .build() - buildingEnv.addMutationField(fieldDefinition) + addMutationField(fieldDefinition) } - override fun createDataFetcher(operationType: OperationType, fieldDefinition: GraphQLFieldDefinition): DataFetcher? { + override fun createDataFetcher(operationType: OperationType, fieldDefinition: FieldDefinition): DataFetcher? { if (operationType != OperationType.MUTATION) { return null } if (fieldDefinition.cypherDirective() != null) { return null } - val type = fieldDefinition.type as? GraphQLFieldsContainer - ?: return null + val type = fieldDefinition.type.inner().resolve() as? ImplementingTypeDefinition<*> ?: return null if (!canHandle(type)) { return null } - val idField = type.getIdField() ?: return null + type.getIdField() ?: return null return when (fieldDefinition.name) { - "delete${type.name}" -> DeleteHandler(type, idField, fieldDefinition, schemaConfig) + "delete${type.name}" -> DeleteHandler(schemaConfig) else -> null } } - private fun canHandle(type: GraphQLFieldsContainer): Boolean { + private fun canHandle(type: ImplementingTypeDefinition<*>): Boolean { val typeName = type.name - if (!schemaConfig.mutation.enabled || schemaConfig.mutation.exclude.contains(typeName)) { + if (!schemaConfig.mutation.enabled || schemaConfig.mutation.exclude.contains(typeName) || isRootType(type)) { return false } return type.getIdField() != null } } + override fun initDataFetcher(fieldDefinition: GraphQLFieldDefinition, parentType: GraphQLType) { + super.initDataFetcher(fieldDefinition, parentType) + idField = type.getIdField() ?: throw IllegalStateException("Cannot resolve id field for type ${type.name}") + isRelation = type.isRelationType() + } + override fun generateCypher(variable: String, field: Field, env: DataFetchingEnvironment): Statement { val idArg = field.arguments.first { it.name == idField.name } diff --git a/core/src/main/kotlin/org/neo4j/graphql/handler/MergeOrUpdateHandler.kt b/core/src/main/kotlin/org/neo4j/graphql/handler/MergeOrUpdateHandler.kt index c9496fad..6ceba797 100644 --- a/core/src/main/kotlin/org/neo4j/graphql/handler/MergeOrUpdateHandler.kt +++ b/core/src/main/kotlin/org/neo4j/graphql/handler/MergeOrUpdateHandler.kt @@ -1,10 +1,13 @@ package org.neo4j.graphql.handler import graphql.language.Field +import graphql.language.FieldDefinition +import graphql.language.ImplementingTypeDefinition import graphql.schema.DataFetcher import graphql.schema.DataFetchingEnvironment import graphql.schema.GraphQLFieldDefinition -import graphql.schema.GraphQLFieldsContainer +import graphql.schema.GraphQLType +import graphql.schema.idl.TypeDefinitionRegistry import org.neo4j.cypherdsl.core.Node import org.neo4j.cypherdsl.core.Relationship import org.neo4j.cypherdsl.core.Statement @@ -15,62 +18,59 @@ import org.neo4j.graphql.* * This class handles all the logic related to the updating of nodes. * This includes the augmentation of the update<Node> and merge<Node>-mutator and the related cypher generation */ -class MergeOrUpdateHandler private constructor( - type: GraphQLFieldsContainer, - private val merge: Boolean, - private val idField: GraphQLFieldDefinition, - fieldDefinition: GraphQLFieldDefinition, - schemaConfig: SchemaConfig, - private val isRelation: Boolean = type.isRelationType() -) : BaseDataFetcherForContainer(type, fieldDefinition, schemaConfig) { - - class Factory(schemaConfig: SchemaConfig) : AugmentationHandler(schemaConfig) { - override fun augmentType(type: GraphQLFieldsContainer, buildingEnv: BuildingEnv) { +class MergeOrUpdateHandler private constructor(private val merge: Boolean, schemaConfig: SchemaConfig) : BaseDataFetcherForContainer(schemaConfig) { + + private lateinit var idField: GraphQLFieldDefinition + private var isRelation: Boolean = false + + class Factory(schemaConfig: SchemaConfig, + typeDefinitionRegistry: TypeDefinitionRegistry, + neo4jTypeDefinitionRegistry: TypeDefinitionRegistry + ) : AugmentationHandler(schemaConfig, typeDefinitionRegistry, neo4jTypeDefinitionRegistry) { + + override fun augmentType(type: ImplementingTypeDefinition<*>) { if (!canHandle(type)) { return } - val relevantFields = type.relevantFields() - val mergeField = buildingEnv - .buildFieldDefinition("merge", type, relevantFields, nullableResult = false) + val relevantFields = type.getScalarFields() + val mergeField = buildFieldDefinition("merge", type, relevantFields, nullableResult = false) .build() - buildingEnv.addMutationField(mergeField) + addMutationField(mergeField) - val updateField = buildingEnv - .buildFieldDefinition("update", type, relevantFields, nullableResult = true) + val updateField = buildFieldDefinition("update", type, relevantFields, nullableResult = true) .build() - buildingEnv.addMutationField(updateField) + addMutationField(updateField) } - override fun createDataFetcher(operationType: OperationType, fieldDefinition: GraphQLFieldDefinition): DataFetcher? { + override fun createDataFetcher(operationType: OperationType, fieldDefinition: FieldDefinition): DataFetcher? { if (operationType != OperationType.MUTATION) { return null } if (fieldDefinition.cypherDirective() != null) { return null } - val type = fieldDefinition.type.inner() as? GraphQLFieldsContainer - ?: return null + val type = fieldDefinition.type.inner().resolve() as? ImplementingTypeDefinition<*> ?: return null if (!canHandle(type)) { return null } - val idField = type.getIdField() ?: return null + type.getIdField() ?: return null return when (fieldDefinition.name) { - "merge${type.name}" -> MergeOrUpdateHandler(type, true, idField, fieldDefinition, schemaConfig) - "update${type.name}" -> MergeOrUpdateHandler(type, false, idField, fieldDefinition, schemaConfig) + "merge${type.name}" -> MergeOrUpdateHandler(true, schemaConfig) + "update${type.name}" -> MergeOrUpdateHandler(false, schemaConfig) else -> null } } - private fun canHandle(type: GraphQLFieldsContainer): Boolean { + private fun canHandle(type: ImplementingTypeDefinition<*>): Boolean { val typeName = type.name - if (!schemaConfig.mutation.enabled || schemaConfig.mutation.exclude.contains(typeName)) { + if (!schemaConfig.mutation.enabled || schemaConfig.mutation.exclude.contains(typeName) || isRootType(type)) { return false } if (type.getIdField() == null) { return false } - if (type.relevantFields().none { !it.isID() }) { + if (type.getScalarFields().none { !it.type.inner().isID() }) { // nothing to update (except ID) return false } @@ -78,9 +78,15 @@ class MergeOrUpdateHandler private constructor( } } - init { + override fun initDataFetcher(fieldDefinition: GraphQLFieldDefinition, parentType: GraphQLType) { + super.initDataFetcher(fieldDefinition, parentType) + + idField = type.getIdField() ?: throw IllegalStateException("Cannot resolve id field for type ${type.name}") + isRelation = type.isRelationType() + defaultFields.clear() // for merge or updates we do not reset to defaults propertyFields.remove(idField.name) // id should not be updated + } override fun generateCypher(variable: String, field: Field, env: DataFetchingEnvironment): Statement { diff --git a/core/src/main/kotlin/org/neo4j/graphql/handler/QueryHandler.kt b/core/src/main/kotlin/org/neo4j/graphql/handler/QueryHandler.kt index c0cf25bc..c8800942 100644 --- a/core/src/main/kotlin/org/neo4j/graphql/handler/QueryHandler.kt +++ b/core/src/main/kotlin/org/neo4j/graphql/handler/QueryHandler.kt @@ -1,8 +1,9 @@ package org.neo4j.graphql.handler -import graphql.Scalars -import graphql.language.Field -import graphql.schema.* +import graphql.language.* +import graphql.schema.DataFetcher +import graphql.schema.DataFetchingEnvironment +import graphql.schema.idl.TypeDefinitionRegistry import org.atteo.evo.inflector.English import org.neo4j.cypherdsl.core.Cypher.* import org.neo4j.cypherdsl.core.Statement @@ -13,58 +14,56 @@ import org.neo4j.graphql.handler.filter.OptimizedFilterHandler * This class handles all the logic related to the querying of nodes and relations. * This includes the augmentation of the query-fields and the related cypher generation */ -class QueryHandler private constructor( - type: GraphQLFieldsContainer, - fieldDefinition: GraphQLFieldDefinition, - schemaConfig: SchemaConfig -) : BaseDataFetcherForContainer(type, fieldDefinition, schemaConfig) { - - class Factory(schemaConfig: SchemaConfig) : AugmentationHandler(schemaConfig) { - override fun augmentType(type: GraphQLFieldsContainer, buildingEnv: BuildingEnv) { +class QueryHandler private constructor(schemaConfig: SchemaConfig) : BaseDataFetcherForContainer(schemaConfig) { + + class Factory(schemaConfig: SchemaConfig, + typeDefinitionRegistry: TypeDefinitionRegistry, + neo4jTypeDefinitionRegistry: TypeDefinitionRegistry + ) : AugmentationHandler(schemaConfig, typeDefinitionRegistry, neo4jTypeDefinitionRegistry) { + + override fun augmentType(type: ImplementingTypeDefinition<*>) { if (!canHandle(type)) { return } val typeName = type.name val relevantFields = getRelevantFields(type) - val filterTypeName = buildingEnv.addFilterType(type) + val filterTypeName = addFilterType(type) val arguments = if (schemaConfig.useWhereFilter) { - listOf(input(WHERE, GraphQLTypeReference(filterTypeName))) + listOf(input(WHERE, TypeName(filterTypeName))) } else { - buildingEnv.getInputValueDefinitions(relevantFields, { true }) + - input(FILTER, GraphQLTypeReference(filterTypeName)) + getInputValueDefinitions(relevantFields, { true }) + + input(FILTER, TypeName(filterTypeName)) } var fieldName = if (schemaConfig.capitalizeQueryFields) typeName else typeName.decapitalize() if (schemaConfig.pluralizeFields) { fieldName = English.plural(fieldName) } - val builder = GraphQLFieldDefinition + val builder = FieldDefinition .newFieldDefinition() .name(fieldName) - .arguments(arguments) - .type(GraphQLNonNull(GraphQLList(GraphQLNonNull(GraphQLTypeReference(type.name))))) + .inputValueDefinitions(arguments.toMutableList()) + .type(NonNullType(ListType(NonNullType(TypeName(type.name))))) if (schemaConfig.queryOptionStyle == SchemaConfig.InputStyle.INPUT_TYPE) { - val optionsTypeName = buildingEnv.addOptions(type) - val optionsType = GraphQLTypeReference(optionsTypeName) - builder.argument(input(OPTIONS, optionsType)) + val optionsTypeName = addOptions(type) + builder.inputValueDefinition(input(OPTIONS, TypeName(optionsTypeName))) } else { builder - .argument(input(FIRST, Scalars.GraphQLInt)) - .argument(input(OFFSET, Scalars.GraphQLInt)) + .inputValueDefinition(input(FIRST, TypeInt)) + .inputValueDefinition(input(OFFSET, TypeInt)) - val orderingTypeName = buildingEnv.addOrdering(type) + val orderingTypeName = addOrdering(type) if (orderingTypeName != null) { - val orderType = GraphQLList(GraphQLNonNull(GraphQLTypeReference(orderingTypeName))) - builder.argument(input(ORDER_BY, orderType)) + builder.inputValueDefinition(input(ORDER_BY, ListType(NonNullType(TypeName(orderingTypeName))))) } } val def = builder.build() - buildingEnv.addQueryField(def) + addQueryField(def) } - override fun createDataFetcher(operationType: OperationType, fieldDefinition: GraphQLFieldDefinition): DataFetcher? { + override fun createDataFetcher(operationType: OperationType, fieldDefinition: FieldDefinition): DataFetcher? { if (operationType != OperationType.QUERY) { return null } @@ -72,17 +71,16 @@ class QueryHandler private constructor( if (cypherDirective != null) { return null } - val type = fieldDefinition.type.inner() as? GraphQLFieldsContainer - ?: return null + val type = fieldDefinition.type.inner().resolve() as? ImplementingTypeDefinition<*> ?: return null if (!canHandle(type)) { return null } - return QueryHandler(type, fieldDefinition, schemaConfig) + return QueryHandler(schemaConfig) } - private fun canHandle(type: GraphQLFieldsContainer): Boolean { - val typeName = type.innerName() - if (!schemaConfig.query.enabled || schemaConfig.query.exclude.contains(typeName)) { + private fun canHandle(type: ImplementingTypeDefinition<*>): Boolean { + val typeName = type.name + if (!schemaConfig.query.enabled || schemaConfig.query.exclude.contains(typeName) || isRootType(type)) { return false } if (getRelevantFields(type).isEmpty() && !hasRelationships(type)) { @@ -91,16 +89,18 @@ class QueryHandler private constructor( return true } - private fun hasRelationships(type: GraphQLFieldsContainer): Boolean = type.fieldDefinitions.any { it.isRelationship() } + private fun hasRelationships(type: ImplementingTypeDefinition<*>): Boolean = type.fieldDefinitions.any { it.isRelationship() } - private fun getRelevantFields(type: GraphQLFieldsContainer): List { + private fun getRelevantFields(type: ImplementingTypeDefinition<*>): List { return type - .relevantFields() + .getScalarFields() .filter { it.dynamicPrefix() == null } // TODO currently we do not support filtering on dynamic properties } } override fun generateCypher(variable: String, field: Field, env: DataFetchingEnvironment): Statement { + val fieldDefinition = env.fieldDefinition + val type = env.typeAsContainer() val (propertyContainer, match) = when { type.isRelationType() -> anyNode().relationshipTo(anyNode(), type.label()).named(variable) diff --git a/core/src/main/kotlin/org/neo4j/graphql/handler/projection/ProjectionBase.kt b/core/src/main/kotlin/org/neo4j/graphql/handler/projection/ProjectionBase.kt index 10cc73d1..9ee44e0e 100644 --- a/core/src/main/kotlin/org/neo4j/graphql/handler/projection/ProjectionBase.kt +++ b/core/src/main/kotlin/org/neo4j/graphql/handler/projection/ProjectionBase.kt @@ -373,7 +373,12 @@ open class ProjectionBase( return skipLimit.slice(fieldType.isList(), comprehension) } - private fun relationshipInfoInCorrectDirection(fieldObjectType: GraphQLFieldsContainer, relInfo0: RelationshipInfo, parent: GraphQLFieldsContainer, relDirectiveField: RelationshipInfo?): RelationshipInfo { + private fun relationshipInfoInCorrectDirection( + fieldObjectType: GraphQLFieldsContainer, + relInfo0: RelationshipInfo, + parent: GraphQLFieldsContainer, + relDirectiveField: RelationshipInfo? + ): RelationshipInfo { val startField = fieldObjectType.getFieldDefinition(relInfo0.startField)!! val endField = fieldObjectType.getFieldDefinition(relInfo0.endField)!! val startFieldTypeName = startField.type.innerName() @@ -413,11 +418,7 @@ open class ProjectionBase( relInfo.endField -> Triple(anyNode(), node, node) else -> throw IllegalArgumentException("type ${parent.name} does not have a matching field with name ${fieldDefinition.name}") } - val rel = when (relInfo.direction) { - RelationDirection.IN -> start.relationshipFrom(end).named(variable) - RelationDirection.OUT -> start.relationshipTo(end).named(variable) - RelationDirection.BOTH -> start.relationshipBetween(end).named(variable) - } + val rel = relInfo.createRelation(start, end, false,variable) return head(CypherDSL.listBasedOn(rel).returning(target.project(projectFields(target, field, fieldDefinition.type as GraphQLFieldsContainer, env)))) } @@ -426,8 +427,8 @@ open class ProjectionBase( val nodeType = fieldType.getInnerFieldsContainer() // todo combine both nestings if rel-entity - val relDirectiveObject = (nodeType as? GraphQLDirectiveContainer)?.getDirective(DirectiveConstants.RELATION)?.let { relDetails(nodeType, it) } - val relDirectiveField = fieldDefinition.getDirective(DirectiveConstants.RELATION)?.let { relDetails(nodeType, it) } + val relDirectiveObject = (nodeType as? GraphQLDirectiveContainer)?.getDirective(DirectiveConstants.RELATION)?.let { RelationshipInfo.create(nodeType, it) } + val relDirectiveField = fieldDefinition.getDirective(DirectiveConstants.RELATION)?.let { RelationshipInfo.create(nodeType, it) } val (relInfo0, isRelFromType) = relDirectiveObject?.let { it to true } diff --git a/core/src/main/kotlin/org/neo4j/graphql/handler/relation/BaseRelationHandler.kt b/core/src/main/kotlin/org/neo4j/graphql/handler/relation/BaseRelationHandler.kt index 6810b183..ea81fdfa 100644 --- a/core/src/main/kotlin/org/neo4j/graphql/handler/relation/BaseRelationHandler.kt +++ b/core/src/main/kotlin/org/neo4j/graphql/handler/relation/BaseRelationHandler.kt @@ -1,8 +1,11 @@ package org.neo4j.graphql.handler.relation -import graphql.Scalars -import graphql.language.Argument -import graphql.schema.* +import graphql.language.* +import graphql.schema.DataFetcher +import graphql.schema.GraphQLFieldDefinition +import graphql.schema.GraphQLFieldsContainer +import graphql.schema.GraphQLType +import graphql.schema.idl.TypeDefinitionRegistry import org.neo4j.cypherdsl.core.Condition import org.neo4j.cypherdsl.core.Node import org.neo4j.graphql.* @@ -11,57 +14,52 @@ import org.neo4j.graphql.handler.BaseDataFetcherForContainer /** * This is a base class for all handler acting on relations / edges */ -abstract class BaseRelationHandler( - type: GraphQLFieldsContainer, - val relation: RelationshipInfo, - private val startId: RelationshipInfo.RelatedField, - private val endId: RelationshipInfo.RelatedField, - fieldDefinition: GraphQLFieldDefinition, - schemaConfig: SchemaConfig) - : BaseDataFetcherForContainer(type, fieldDefinition, schemaConfig) { - - init { - propertyFields.remove(startId.argumentName) - propertyFields.remove(endId.argumentName) - } +abstract class BaseRelationHandler(val prefix: String, schemaConfig: SchemaConfig) : BaseDataFetcherForContainer(schemaConfig) { + + lateinit var relation: RelationshipInfo + lateinit var startId: RelatedField + lateinit var endId: RelatedField + + data class RelatedField( + val argumentName: String, + val field: GraphQLFieldDefinition, + val declaringType: GraphQLFieldsContainer + ) - abstract class BaseRelationFactory(val prefix: String, schemaConfig: SchemaConfig) : AugmentationHandler(schemaConfig) { + abstract class BaseRelationFactory( + val prefix: String, + schemaConfig: SchemaConfig, + typeDefinitionRegistry: TypeDefinitionRegistry, + neo4jTypeDefinitionRegistry: TypeDefinitionRegistry + ) : AugmentationHandler(schemaConfig, typeDefinitionRegistry, neo4jTypeDefinitionRegistry) { protected fun buildFieldDefinition( - source: GraphQLFieldsContainer, - targetField: GraphQLFieldDefinition, + source: ImplementingTypeDefinition<*>, + targetField: FieldDefinition, nullableResult: Boolean - ): GraphQLFieldDefinition.Builder? { + ): FieldDefinition.Builder? { - val targetType = targetField.type.getInnerFieldsContainer() - val sourceIdField = source.getIdField() - val targetIdField = targetType.getIdField() - if (sourceIdField == null || targetIdField == null) { - return null - } + val (sourceIdField, _) = getRelationFields(source, targetField) ?: return null val targetFieldName = targetField.name.capitalize() - val idType = GraphQLNonNull(Scalars.GraphQLID) - val targetIDType = if (targetField.type.isList()) GraphQLNonNull(GraphQLList(idType)) else idType + val idType = NonNullType(TypeID) + val targetIDType = if (targetField.type.isList()) NonNullType(ListType(idType)) else idType - var type: GraphQLOutputType = source + var type: Type<*> = TypeName(source.name) if (!nullableResult) { - type = GraphQLNonNull(type) + type = NonNullType(type) } - return GraphQLFieldDefinition.newFieldDefinition() + return FieldDefinition.newFieldDefinition() .name("$prefix${source.name}$targetFieldName") - .argument(input(sourceIdField.name, idType)) - .argument(input(targetField.name, targetIDType)) - .type(type.ref() as GraphQLOutputType) + .inputValueDefinition(input(sourceIdField.name, idType)) + .inputValueDefinition(input(targetField.name, targetIDType)) + .type(type) } - protected fun canHandleType(type: GraphQLFieldsContainer): Boolean { - if (type !is GraphQLObjectType) { - return false - } - if (!schemaConfig.mutation.enabled || schemaConfig.mutation.exclude.contains(type.name)) { + protected fun canHandleType(type: ImplementingTypeDefinition<*>): Boolean { + if (!schemaConfig.mutation.enabled || schemaConfig.mutation.exclude.contains(type.name) || isRootType(type)) { // TODO we do not mutate the node but the relation, I think this check should be different return false } @@ -71,9 +69,9 @@ abstract class BaseRelationHandler( return true } - protected fun canHandleField(targetField: GraphQLFieldDefinition): Boolean { - val type = targetField.type.inner() as? GraphQLObjectType ?: return false - if (targetField.getDirective(DirectiveConstants.RELATION) == null) { + protected fun canHandleField(targetField: FieldDefinition): Boolean { + val type = targetField.type.inner().resolve() as? ImplementingTypeDefinition<*> ?: return false + if (!targetField.hasDirective(DirectiveConstants.RELATION)) { return false } if (type.getIdField() == null) { @@ -82,15 +80,14 @@ abstract class BaseRelationHandler( return true } - final override fun createDataFetcher(operationType: OperationType, fieldDefinition: GraphQLFieldDefinition): DataFetcher? { + final override fun createDataFetcher(operationType: OperationType, fieldDefinition: FieldDefinition): DataFetcher? { if (operationType != OperationType.MUTATION) { return null } if (fieldDefinition.cypherDirective() != null) { return null } - val sourceType = fieldDefinition.type.inner() as? GraphQLFieldsContainer - ?: return null + val sourceType = fieldDefinition.type.inner().resolve() as? ImplementingTypeDefinition<*> ?: return null if (!canHandleType(sourceType)) { return null } @@ -107,32 +104,77 @@ abstract class BaseRelationHandler( if (!canHandleField(targetField)) { return null } - val relation = sourceType.relationshipFor(targetField.name) ?: return null + if (!sourceType.hasRelationshipFor(targetField.name)) { + return null + } + + val (_, _) = getRelationFields(sourceType, targetField) ?: return null + return createDataFetcher() + } - val targetType = targetField.type.getInnerFieldsContainer() - val sourceIdField = sourceType.getIdField() + private fun ImplementingTypeDefinition<*>.hasRelationshipFor(name: String): Boolean { + val field = getFieldDefinition(name) + ?: throw IllegalArgumentException("$name is not defined on ${this.name}") + val fieldObjectType = field.type.inner().resolve() as? ImplementingTypeDefinition<*> ?: return false + + val target = getDirectiveArgument(DirectiveConstants.RELATION, DirectiveConstants.RELATION_TO, null) + if (target != null) { + if (fieldObjectType.getFieldDefinition(target)?.name == this.name) return true + } else { + if (fieldObjectType.getDirective(DirectiveConstants.RELATION) != null) return true + if (field.getDirective(DirectiveConstants.RELATION) != null) return true + } + return false + } + + abstract fun createDataFetcher(): DataFetcher? + + private fun getRelationFields(source: ImplementingTypeDefinition<*>, targetField: FieldDefinition): Pair? { + val targetType = targetField.type.inner().resolve() as? ImplementingTypeDefinition<*> ?: return null + val sourceIdField = source.getIdField() val targetIdField = targetType.getIdField() if (sourceIdField == null || targetIdField == null) { return null } - val startIdField = RelationshipInfo.RelatedField(sourceIdField.name, sourceIdField, sourceType) - val endIdField = RelationshipInfo.RelatedField(targetField.name, targetIdField, targetType) - return createDataFetcher(sourceType, relation, startIdField, endIdField, fieldDefinition) + return sourceIdField to targetIdField } + } + + override fun initDataFetcher(fieldDefinition: GraphQLFieldDefinition, parentType: GraphQLType) { + super.initDataFetcher(fieldDefinition, parentType) + + initRelation(fieldDefinition) + + propertyFields.remove(startId.argumentName) + propertyFields.remove(endId.argumentName) + } + + protected open fun initRelation(fieldDefinition: GraphQLFieldDefinition) { + val p = "$prefix${type.name}" + + val targetField = fieldDefinition.name + .removePrefix(p) + .decapitalize() + .let { + type.getFieldDefinition(it) ?: throw IllegalStateException("Cannot find field $it on type ${type.name}") + } + - abstract fun createDataFetcher( - sourceType: GraphQLFieldsContainer, - relation: RelationshipInfo, - startIdField: RelationshipInfo.RelatedField, - endIdField: RelationshipInfo.RelatedField, - fieldDefinition: GraphQLFieldDefinition - ): DataFetcher? + relation = type.relationshipFor(targetField.name) + ?: throw IllegalStateException("Cannot resolve relationship for ${targetField.name} on type ${type.name}") + val targetType = targetField.type.getInnerFieldsContainer() + val sourceIdField = type.getIdField() + ?: throw IllegalStateException("Cannot find id field for type ${type.name}") + val targetIdField = targetType.getIdField() + ?: throw IllegalStateException("Cannot find id field for type ${targetType.name}") + startId = RelatedField(sourceIdField.name, sourceIdField, type) + endId = RelatedField(targetField.name, targetIdField, targetType) } fun getRelationSelect(start: Boolean, arguments: Map): Pair { val relFieldName: String - val idField: RelationshipInfo.RelatedField + val idField: RelatedField if (start) { relFieldName = relation.startField idField = startId diff --git a/core/src/main/kotlin/org/neo4j/graphql/handler/relation/CreateRelationHandler.kt b/core/src/main/kotlin/org/neo4j/graphql/handler/relation/CreateRelationHandler.kt index f649b220..22d2ee8e 100644 --- a/core/src/main/kotlin/org/neo4j/graphql/handler/relation/CreateRelationHandler.kt +++ b/core/src/main/kotlin/org/neo4j/graphql/handler/relation/CreateRelationHandler.kt @@ -1,10 +1,10 @@ package org.neo4j.graphql.handler.relation import graphql.language.Field +import graphql.language.ImplementingTypeDefinition import graphql.schema.DataFetcher import graphql.schema.DataFetchingEnvironment -import graphql.schema.GraphQLFieldDefinition -import graphql.schema.GraphQLFieldsContainer +import graphql.schema.idl.TypeDefinitionRegistry import org.neo4j.cypherdsl.core.Statement import org.neo4j.cypherdsl.core.StatementBuilder.OngoingUpdate import org.neo4j.graphql.* @@ -13,20 +13,25 @@ import org.neo4j.graphql.* * This class handles all the logic related to the creation of relations starting from an existing node. * This includes the augmentation of the add<Edge>-mutator and the related cypher generation */ -class CreateRelationHandler private constructor( - type: GraphQLFieldsContainer, - relation: RelationshipInfo, - startId: RelationshipInfo.RelatedField, - endId: RelationshipInfo.RelatedField, - fieldDefinition: GraphQLFieldDefinition, - schemaConfig: SchemaConfig -) : BaseRelationHandler(type, relation, startId, endId, fieldDefinition, schemaConfig) { - - class Factory(schemaConfig: SchemaConfig) : BaseRelationFactory("add", schemaConfig) { - override fun augmentType(type: GraphQLFieldsContainer, buildingEnv: BuildingEnv) { +class CreateRelationHandler private constructor(schemaConfig: SchemaConfig) : BaseRelationHandler("add", schemaConfig) { + + class Factory(schemaConfig: SchemaConfig, + typeDefinitionRegistry: TypeDefinitionRegistry, + neo4jTypeDefinitionRegistry: TypeDefinitionRegistry + ) : BaseRelationFactory("add", schemaConfig, typeDefinitionRegistry, neo4jTypeDefinitionRegistry) { + + override fun augmentType(type: ImplementingTypeDefinition<*>) { + if (!canHandleType(type)) { return } + + val richRelationTypes = typeDefinitionRegistry.types().values + .filterIsInstance>() + .filter { it.getDirective(DirectiveConstants.RELATION) != null } + .associate { it.getDirectiveArgument(DirectiveConstants.RELATION, DirectiveConstants.RELATION_NAME, null)!! to it.name } + + type.fieldDefinitions .filter { canHandleField(it) } .mapNotNull { targetField -> @@ -35,36 +40,30 @@ class CreateRelationHandler private constructor( val relationType = targetField .getDirectiveArgument(DirectiveConstants.RELATION, DirectiveConstants.RELATION_NAME, null) - ?.let { buildingEnv.getTypeForRelation(it) } + ?.let { it -> (richRelationTypes[it]) } + ?.let { typeDefinitionRegistry.getUnwrappedType(it) as? ImplementingTypeDefinition } relationType ?.fieldDefinitions - ?.filter { it.type.isScalar() && !it.isID() } - ?.forEach { builder.argument(input(it.name, it.type)) } + ?.filter { it.type.inner().isScalar() && !it.type.inner().isID() } + ?.forEach { builder.inputValueDefinition(input(it.name, it.type)) } - buildingEnv.addMutationField(builder.build()) + addMutationField(builder.build()) } } } - override fun createDataFetcher( - sourceType: GraphQLFieldsContainer, - relation: RelationshipInfo, - startIdField: RelationshipInfo.RelatedField, - endIdField: RelationshipInfo.RelatedField, - fieldDefinition: GraphQLFieldDefinition - ): DataFetcher { - return CreateRelationHandler(sourceType, relation, startIdField, endIdField, fieldDefinition, schemaConfig) + override fun createDataFetcher(): DataFetcher { + return CreateRelationHandler(schemaConfig) } - } override fun generateCypher(variable: String, field: Field, env: DataFetchingEnvironment): Statement { val properties = properties(variable, field.arguments) - val arguments = field.arguments.map { it.name to it }.toMap() + val arguments = field.arguments.associateBy { it.name } val (startNode, startWhere) = getRelationSelect(true, arguments) val (endNode, endWhere) = getRelationSelect(false, arguments) diff --git a/core/src/main/kotlin/org/neo4j/graphql/handler/relation/CreateRelationTypeHandler.kt b/core/src/main/kotlin/org/neo4j/graphql/handler/relation/CreateRelationTypeHandler.kt index bbd932ae..541d6b94 100644 --- a/core/src/main/kotlin/org/neo4j/graphql/handler/relation/CreateRelationTypeHandler.kt +++ b/core/src/main/kotlin/org/neo4j/graphql/handler/relation/CreateRelationTypeHandler.kt @@ -1,7 +1,11 @@ package org.neo4j.graphql.handler.relation import graphql.language.Field +import graphql.language.FieldDefinition +import graphql.language.ImplementingTypeDefinition +import graphql.language.InterfaceTypeDefinition import graphql.schema.* +import graphql.schema.idl.TypeDefinitionRegistry import org.neo4j.cypherdsl.core.Cypher.name import org.neo4j.cypherdsl.core.Statement import org.neo4j.cypherdsl.core.StatementBuilder @@ -11,83 +15,71 @@ import org.neo4j.graphql.* * This class handles all the logic related to the creation of relations. * This includes the augmentation of the create<Edge>-mutator and the related cypher generation */ -class CreateRelationTypeHandler private constructor( - type: GraphQLFieldsContainer, - relation: RelationshipInfo, - startId: RelationshipInfo.RelatedField, - endId: RelationshipInfo.RelatedField, - fieldDefinition: GraphQLFieldDefinition, - schemaConfig: SchemaConfig -) : BaseRelationHandler(type, relation, startId, endId, fieldDefinition, schemaConfig) { - - class Factory(schemaConfig: SchemaConfig) : AugmentationHandler(schemaConfig) { - override fun augmentType(type: GraphQLFieldsContainer, buildingEnv: BuildingEnv) { +class CreateRelationTypeHandler private constructor(schemaConfig: SchemaConfig) : BaseRelationHandler("create", schemaConfig) { + + class Factory(schemaConfig: SchemaConfig, + typeDefinitionRegistry: TypeDefinitionRegistry, + neo4jTypeDefinitionRegistry: TypeDefinitionRegistry + ) : AugmentationHandler(schemaConfig, typeDefinitionRegistry, neo4jTypeDefinitionRegistry) { + + override fun augmentType(type: ImplementingTypeDefinition<*>) { if (!canHandle(type)) { return } - val relation = type.relationship()!! - val startIdField = relation.getStartFieldId() - val endIdField = relation.getEndFieldId() + val relation = type.relationship() ?: return + val startIdField = getRelatedIdField(relation, relation.startField) + val endIdField = getRelatedIdField(relation, relation.endField) if (startIdField == null || endIdField == null) { return } - val relevantFields = getRelevantFields(type) val createArgs = getRelevantFields(type) .filter { !it.isNativeId() } .filter { it.name != startIdField.argumentName } .filter { it.name != endIdField.argumentName } - val builder = buildingEnv - .buildFieldDefinition("create", type, relevantFields, nullableResult = false) - .argument(input(startIdField.argumentName, startIdField.field.type)) - .argument(input(endIdField.argumentName, endIdField.field.type)) - - createArgs.forEach { builder.argument(input(it.name, it.type)) } + val builder = + buildFieldDefinition("create", type, createArgs, nullableResult = false) + .inputValueDefinition(input(startIdField.argumentName, startIdField.field.type)) + .inputValueDefinition(input(endIdField.argumentName, endIdField.field.type)) - buildingEnv.addMutationField(builder.build()) + addMutationField(builder.build()) } - override fun createDataFetcher(operationType: OperationType, fieldDefinition: GraphQLFieldDefinition): DataFetcher? { + override fun createDataFetcher(operationType: OperationType, fieldDefinition: FieldDefinition): DataFetcher? { if (operationType != OperationType.MUTATION) { return null } if (fieldDefinition.cypherDirective() != null) { return null } - val type = fieldDefinition.type.inner() as? GraphQLObjectType - ?: return null + val type = fieldDefinition.type.inner().resolve() as? ImplementingTypeDefinition<*> ?: return null if (!canHandle(type)) { return null } if (fieldDefinition.name != "create${type.name}") { return null } - - val relation = type.relationship() ?: return null - val startIdField = relation.getStartFieldId() ?: return null - val endIdField = relation.getEndFieldId() ?: return null - - return CreateRelationTypeHandler(type, relation, startIdField, endIdField, fieldDefinition, schemaConfig) + return CreateRelationTypeHandler(schemaConfig) } - private fun getRelevantFields(type: GraphQLFieldsContainer): List { + private fun getRelevantFields(type: ImplementingTypeDefinition<*>): List { return type - .relevantFields() + .getScalarFields() .filter { !it.isNativeId() } } - private fun canHandle(type: GraphQLFieldsContainer): Boolean { + private fun canHandle(type: ImplementingTypeDefinition<*>): Boolean { val typeName = type.name if (!schemaConfig.mutation.enabled || schemaConfig.mutation.exclude.contains(typeName)) { return false } - if (type !is GraphQLObjectType) { + if (type is InterfaceTypeDefinition) { return false } val relation = type.relationship() ?: return false - val startIdField = relation.getStartFieldId() - val endIdField = relation.getEndFieldId() + val startIdField = getRelatedIdField(relation, relation.startField) + val endIdField = getRelatedIdField(relation, relation.endField) if (startIdField == null || endIdField == null) { return false } @@ -99,12 +91,49 @@ class CreateRelationTypeHandler private constructor( return true } + + data class RelatedField( + val argumentName: String, + val field: FieldDefinition, + ) + + private fun getRelatedIdField(info: RelationshipInfo>, relFieldName: String?): RelatedField? { + if (relFieldName == null) return null + val relFieldDefinition = info.type.getFieldDefinition(relFieldName) + ?: throw IllegalArgumentException("field $relFieldName does not exists on ${info.typeName}") + + val relType = relFieldDefinition.type.inner().resolve() as? ImplementingTypeDefinition<*> + ?: throw IllegalArgumentException("type ${relFieldDefinition.type.name()} not found") + return relType.fieldDefinitions.filter { it.type.inner().isID() } + .map { RelatedField(normalizeFieldName(relFieldName, it.name), it) } + .firstOrNull() + } + + } + + private fun getRelatedIdField(info: RelationshipInfo, relFieldName: String): RelatedField { + val relFieldDefinition = info.type.getFieldDefinition(relFieldName) + ?: throw IllegalArgumentException("field $relFieldName does not exists on ${info.typeName}") + + val relType = relFieldDefinition.type.inner() as? GraphQLImplementingType + ?: throw IllegalArgumentException("type ${relFieldDefinition.type.name()} not found") + return relType.fieldDefinitions.filter { it.isID() } + .map { RelatedField(normalizeFieldName(relFieldName, it.name), it, relType) } + .firstOrNull() + ?: throw IllegalStateException("Cannot find id field for type ${info.typeName}") + } + + override fun initRelation(fieldDefinition: GraphQLFieldDefinition) { + relation = type.relationship() + ?: throw IllegalStateException("Cannot resolve relationship for type ${type.name}") + startId = getRelatedIdField(relation, relation.startField) + endId = getRelatedIdField(relation, relation.endField) } override fun generateCypher(variable: String, field: Field, env: DataFetchingEnvironment): Statement { val properties = properties(variable, field.arguments) - val arguments = field.arguments.map { it.name to it }.toMap() + val arguments = field.arguments.associateBy { it.name } val (startNode, startWhere) = getRelationSelect(true, arguments) val (endNode, endWhere) = getRelationSelect(false, arguments) val relName = name(variable) @@ -118,4 +147,12 @@ class CreateRelationTypeHandler private constructor( .returning(relName.project(mapProjection).`as`(field.aliasOrName())) .build() } + + companion object { + private fun normalizeFieldName(relFieldName: String?, name: String): String { + // TODO b/c we need to stay backwards compatible this is not caml case but with underscore + //val filedName = normalizeName(relFieldName, name) + return "${relFieldName}_${name}" + } + } } diff --git a/core/src/main/kotlin/org/neo4j/graphql/handler/relation/DeleteRelationHandler.kt b/core/src/main/kotlin/org/neo4j/graphql/handler/relation/DeleteRelationHandler.kt index 6dd47689..e5787f05 100644 --- a/core/src/main/kotlin/org/neo4j/graphql/handler/relation/DeleteRelationHandler.kt +++ b/core/src/main/kotlin/org/neo4j/graphql/handler/relation/DeleteRelationHandler.kt @@ -1,29 +1,29 @@ package org.neo4j.graphql.handler.relation import graphql.language.Field +import graphql.language.ImplementingTypeDefinition import graphql.schema.DataFetcher import graphql.schema.DataFetchingEnvironment -import graphql.schema.GraphQLFieldDefinition -import graphql.schema.GraphQLFieldsContainer +import graphql.schema.idl.TypeDefinitionRegistry import org.neo4j.cypherdsl.core.Statement import org.neo4j.cypherdsl.core.StatementBuilder -import org.neo4j.graphql.* +import org.neo4j.graphql.Cypher +import org.neo4j.graphql.SchemaConfig +import org.neo4j.graphql.aliasOrName /** * This class handles all the logic related to the deletion of relations starting from an existing node. * This includes the augmentation of the delete<Edge>-mutator and the related cypher generation */ -class DeleteRelationHandler private constructor( - type: GraphQLFieldsContainer, - relation: RelationshipInfo, - startId: RelationshipInfo.RelatedField, - endId: RelationshipInfo.RelatedField, - fieldDefinition: GraphQLFieldDefinition, - schemaConfig: SchemaConfig -) : BaseRelationHandler(type, relation, startId, endId, fieldDefinition, schemaConfig) { +class DeleteRelationHandler private constructor(schemaConfig: SchemaConfig) : BaseRelationHandler("delete", schemaConfig) { + + class Factory(schemaConfig: SchemaConfig, + typeDefinitionRegistry: TypeDefinitionRegistry, + neo4jTypeDefinitionRegistry: TypeDefinitionRegistry + ) : BaseRelationFactory("delete", schemaConfig, typeDefinitionRegistry, neo4jTypeDefinitionRegistry) { + + override fun augmentType(type: ImplementingTypeDefinition<*>) { - class Factory(schemaConfig: SchemaConfig) : BaseRelationFactory("delete", schemaConfig) { - override fun augmentType(type: GraphQLFieldsContainer, buildingEnv: BuildingEnv) { if (!canHandleType(type)) { return } @@ -31,24 +31,16 @@ class DeleteRelationHandler private constructor( .filter { canHandleField(it) } .mapNotNull { targetField -> buildFieldDefinition(type, targetField, true) - ?.let { builder -> buildingEnv.addMutationField(builder.build()) } + ?.let { builder -> addMutationField(builder.build()) } } } - override fun createDataFetcher( - sourceType: GraphQLFieldsContainer, - relation: RelationshipInfo, - startIdField: RelationshipInfo.RelatedField, - endIdField: RelationshipInfo.RelatedField, - fieldDefinition: GraphQLFieldDefinition - ): DataFetcher { - return DeleteRelationHandler(sourceType, relation, startIdField, endIdField, fieldDefinition, schemaConfig) - } + override fun createDataFetcher(): DataFetcher = DeleteRelationHandler(schemaConfig) } override fun generateCypher(variable: String, field: Field, env: DataFetchingEnvironment): Statement { - val arguments = field.arguments.map { it.name to it }.toMap() + val arguments = field.arguments.associateBy { it.name } val (startNode, startWhere) = getRelationSelect(true, arguments) val (endNode, endWhere) = getRelationSelect(false, arguments) val relName = org.neo4j.cypherdsl.core.Cypher.name("r") diff --git a/core/src/test/kotlin/org/neo4j/graphql/utils/GraphQLSchemaTestSuite.kt b/core/src/test/kotlin/org/neo4j/graphql/utils/GraphQLSchemaTestSuite.kt index b47efed4..308dbf90 100644 --- a/core/src/test/kotlin/org/neo4j/graphql/utils/GraphQLSchemaTestSuite.kt +++ b/core/src/test/kotlin/org/neo4j/graphql/utils/GraphQLSchemaTestSuite.kt @@ -11,6 +11,7 @@ import graphql.schema.idl.SchemaGenerator import graphql.schema.idl.SchemaParser import graphql.schema.idl.SchemaPrinter import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.Assertions import org.junit.jupiter.api.Assumptions import org.junit.jupiter.api.DynamicNode import org.junit.jupiter.api.DynamicTest @@ -60,6 +61,9 @@ class GraphQLSchemaTestSuite(fileName: String) : AsciiDocTestSuite( if (ignore) { Assumptions.assumeFalse(true, e.message) } else { + if (augmentedSchema == null) { + Assertions.fail(e) + } val actualSchema = SCHEMA_PRINTER.print(augmentedSchema) targetSchemaBlock.adjustedCode = actualSchema + "\n" + // this is added since the SCHEMA_PRINTER is not able to print global directives diff --git a/core/src/test/resources/augmentation-tests.adoc b/core/src/test/resources/augmentation-tests.adoc index 475dc525..96cced25 100644 --- a/core/src/test/resources/augmentation-tests.adoc +++ b/core/src/test/resources/augmentation-tests.adoc @@ -894,8 +894,8 @@ input _Neo4jLocalTimeInput { } input _Neo4jPointDistanceFilter { - distance: Float - point: _Neo4jPointInput + distance: Float! + point: _Neo4jPointInput! } input _Neo4jPointInput { diff --git a/examples/graphql-spring-boot/src/main/kotlin/org/neo4j/graphql/examples/graphqlspringboot/config/Neo4jConfiguration.kt b/examples/graphql-spring-boot/src/main/kotlin/org/neo4j/graphql/examples/graphqlspringboot/config/Neo4jConfiguration.kt index e41cbde9..b9991f1d 100644 --- a/examples/graphql-spring-boot/src/main/kotlin/org/neo4j/graphql/examples/graphqlspringboot/config/Neo4jConfiguration.kt +++ b/examples/graphql-spring-boot/src/main/kotlin/org/neo4j/graphql/examples/graphqlspringboot/config/Neo4jConfiguration.kt @@ -1,6 +1,5 @@ package org.neo4j.graphql.examples.graphqlspringboot.config -import graphql.language.VariableReference import graphql.schema.* import org.neo4j.driver.Driver import org.neo4j.graphql.Cypher