Skip to content

Commit

Permalink
Migrate filter and field args into a where (#220)
Browse files Browse the repository at this point in the history
These changes are made in order to harmonize with the API of the js version of this library.

resolves #181
  • Loading branch information
Andy2003 authored May 4, 2021
1 parent 9acc31e commit 823fe6b
Show file tree
Hide file tree
Showing 22 changed files with 1,138 additions and 147 deletions.
17 changes: 9 additions & 8 deletions core/src/main/kotlin/org/neo4j/graphql/BuildingEnv.kt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ import org.neo4j.graphql.handler.projection.ProjectionBase

class BuildingEnv(
val types: MutableMap<String, GraphQLNamedType>,
private val sourceSchema: GraphQLSchema
private val sourceSchema: GraphQLSchema,
val schemaConfig: SchemaConfig
) {

private val typesForRelation = types.values
Expand Down Expand Up @@ -77,7 +78,7 @@ class BuildingEnv(
}

fun addFilterType(type: GraphQLFieldsContainer, createdTypes: MutableSet<String> = mutableSetOf()): String {
val filterName = "_${type.name}Filter"
val filterName = if (schemaConfig.useWhereFilter) type.name + "Where" else "_${type.name}Filter"
if (createdTypes.contains(filterName)) {
return filterName
}
Expand Down Expand Up @@ -138,7 +139,7 @@ class BuildingEnv(
?: throw IllegalStateException("Ordering type $type.name is already defined but not an input type")
}
val sortTypeName = addSortInputType(type)
val optionsTypeBuilder = GraphQLInputObjectType.newInputObject().name(optionsName)
val optionsTypeBuilder = GraphQLInputObjectType.newInputObject().name(optionsName)
if (sortTypeName != null) {
optionsTypeBuilder.field(GraphQLInputObjectField.newInputObjectField()
.name(ProjectionBase.SORT)
Expand All @@ -147,10 +148,10 @@ class BuildingEnv(
.build())
}
optionsTypeBuilder.field(GraphQLInputObjectField.newInputObjectField()
.name(ProjectionBase.LIMIT)
.type(Scalars.GraphQLInt)
.description("Defines the maximum amount of records returned")
.build())
.name(ProjectionBase.LIMIT)
.type(Scalars.GraphQLInt)
.description("Defines the maximum amount of records returned")
.build())
.field(GraphQLInputObjectField.newInputObjectField()
.name(ProjectionBase.SKIP)
.type(Scalars.GraphQLInt)
Expand All @@ -169,7 +170,7 @@ class BuildingEnv(
?: throw IllegalStateException("Ordering type $type.name is already defined but not an input type")
}
val relevantFields = type.relevantFields()
if (relevantFields.isEmpty()){
if (relevantFields.isEmpty()) {
return null
}
val builder = GraphQLInputObjectType.newInputObject()
Expand Down
15 changes: 8 additions & 7 deletions core/src/main/kotlin/org/neo4j/graphql/SchemaBuilder.kt
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ object SchemaBuilder {

private fun augmentSchema(sourceSchema: GraphQLSchema, handler: List<AugmentationHandler>, schemaConfig: SchemaConfig): GraphQLSchema {
val types = sourceSchema.typeMap.toMutableMap()
val env = BuildingEnv(types, sourceSchema)
val env = BuildingEnv(types, sourceSchema, schemaConfig)
val queryTypeName = sourceSchema.queryTypeName()
val mutationTypeName = sourceSchema.mutationTypeName()
val subscriptionTypeName = sourceSchema.subscriptionTypeName()
Expand All @@ -157,11 +157,11 @@ object SchemaBuilder {
builder.clearFields().clearInterfaces()
// to prevent duplicated types in schema
sourceType.interfaces.forEach { builder.withInterface(GraphQLTypeReference(it.name)) }
sourceType.fieldDefinitions.forEach { f -> builder.field(enhanceRelations(f, env, schemaConfig)) }
sourceType.fieldDefinitions.forEach { f -> builder.field(enhanceRelations(f, env)) }
}
sourceType is GraphQLInterfaceType -> sourceType.transform { builder ->
builder.clearFields()
sourceType.fieldDefinitions.forEach { f -> builder.field(enhanceRelations(f, env, schemaConfig)) }
sourceType.fieldDefinitions.forEach { f -> builder.field(enhanceRelations(f, env)) }
}
else -> sourceType
}
Expand All @@ -177,7 +177,7 @@ object SchemaBuilder {
.build()
}

private fun enhanceRelations(fd: GraphQLFieldDefinition, env: BuildingEnv, schemaConfig: SchemaConfig): GraphQLFieldDefinition {
private fun enhanceRelations(fd: GraphQLFieldDefinition, env: BuildingEnv): GraphQLFieldDefinition {
return fd.transform { fieldBuilder ->
// to prevent duplicated types in schema
fieldBuilder.type(fd.type.ref() as GraphQLOutputType)
Expand All @@ -188,7 +188,7 @@ object SchemaBuilder {

val fieldType = fd.type.inner() as? GraphQLFieldsContainer ?: return@transform

if (schemaConfig.queryOptionStyle == SchemaConfig.InputStyle.INPUT_TYPE){
if (env.schemaConfig.queryOptionStyle == SchemaConfig.InputStyle.INPUT_TYPE) {

val optionsTypeName = env.addOptions(fieldType)
val optionsType = GraphQLTypeReference(optionsTypeName)
Expand All @@ -212,9 +212,10 @@ object SchemaBuilder {

}

if (schemaConfig.query.enabled && !schemaConfig.query.exclude.contains(fieldType.name) && fd.getArgument(ProjectionBase.FILTER) == null) {
val filterFieldName = if (env.schemaConfig.useWhereFilter) ProjectionBase.WHERE else ProjectionBase.FILTER
if (env.schemaConfig.query.enabled && !env.schemaConfig.query.exclude.contains(fieldType.name) && fd.getArgument(filterFieldName) == null) {
val filterTypeName = env.addFilterType(fieldType)
fieldBuilder.argument(input(ProjectionBase.FILTER, GraphQLTypeReference(filterTypeName)))
fieldBuilder.argument(input(filterFieldName, GraphQLTypeReference(filterTypeName)))
}
}
}
Expand Down
8 changes: 8 additions & 0 deletions core/src/main/kotlin/org/neo4j/graphql/SchemaConfig.kt
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,22 @@ package org.neo4j.graphql
data class SchemaConfig @JvmOverloads constructor(
val query: CRUDConfig = CRUDConfig(),
val mutation: CRUDConfig = CRUDConfig(),

/**
* if true, the top level fields of the Query-type will be capitalized
*/
val capitalizeQueryFields: Boolean = false,

/**
* Defines the way the input for queries and mutations are generated
*/
val queryOptionStyle: InputStyle = InputStyle.ARGUMENT_PER_FIELD,

/**
* if enabled the `filter` argument will be named `where` and the input type will be named `<typeName>Where`.
* additionally the separated filter arguments will no longer be generated.
*/
val useWhereFilter: Boolean = false,
) {
data class CRUDConfig(val enabled: Boolean = true, val exclude: List<String> = emptyList())

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,14 @@ import org.neo4j.cypherdsl.core.Statement
import org.neo4j.cypherdsl.core.renderer.Configuration
import org.neo4j.cypherdsl.core.renderer.Renderer
import org.neo4j.graphql.Cypher
import org.neo4j.graphql.SchemaConfig
import org.neo4j.graphql.aliasOrName
import org.neo4j.graphql.handler.projection.ProjectionBase

/**
* The is a base class for the implementation of graphql data fetcher used in this project
*/
abstract class BaseDataFetcher(val fieldDefinition: GraphQLFieldDefinition) : ProjectionBase(), DataFetcher<Cypher> {
abstract class BaseDataFetcher(val fieldDefinition: GraphQLFieldDefinition, schemaConfig: SchemaConfig) : ProjectionBase(schemaConfig), DataFetcher<Cypher> {

override fun get(env: DataFetchingEnvironment?): Cypher {
val field = env?.mergedField?.singleField
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@ import org.neo4j.graphql.*
*/
abstract class BaseDataFetcherForContainer(
val type: GraphQLFieldsContainer,
fieldDefinition: GraphQLFieldDefinition
) : BaseDataFetcher(fieldDefinition) {
fieldDefinition: GraphQLFieldDefinition,
schemaConfig: SchemaConfig
) : BaseDataFetcher(fieldDefinition, schemaConfig) {

val propertyFields: MutableMap<String, (Any) -> List<PropertyAccessor>?> = mutableMapOf()
val defaultFields: MutableMap<String, Any> = mutableMapOf()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@ import org.neo4j.graphql.*
*/
class CreateTypeHandler private constructor(
type: GraphQLFieldsContainer,
fieldDefinition: GraphQLFieldDefinition
) : BaseDataFetcherForContainer(type, fieldDefinition) {
fieldDefinition: GraphQLFieldDefinition,
schemaConfig: SchemaConfig
) : BaseDataFetcherForContainer(type, fieldDefinition, schemaConfig) {

class Factory(schemaConfig: SchemaConfig) : AugmentationHandler(schemaConfig) {
override fun augmentType(type: GraphQLFieldsContainer, buildingEnv: BuildingEnv) {
Expand Down Expand Up @@ -41,7 +42,7 @@ class CreateTypeHandler private constructor(
return null
}
return when {
fieldDefinition.name == "create${type.name}" -> CreateTypeHandler(type, fieldDefinition)
fieldDefinition.name == "create${type.name}" -> CreateTypeHandler(type, fieldDefinition, schemaConfig)
else -> null
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,17 @@ class CypherDirectiveHandler(
private val type: GraphQLFieldsContainer?,
private val isQuery: Boolean,
private val cypherDirective: CypherDirective,
fieldDefinition: GraphQLFieldDefinition)
: BaseDataFetcher(fieldDefinition) {
fieldDefinition: GraphQLFieldDefinition,
schemaConfig: SchemaConfig)
: BaseDataFetcher(fieldDefinition, schemaConfig) {

class Factory(schemaConfig: SchemaConfig) : AugmentationHandler(schemaConfig) {

override fun createDataFetcher(operationType: OperationType, fieldDefinition: GraphQLFieldDefinition): DataFetcher<Cypher>? {
val cypherDirective = fieldDefinition.cypherDirective() ?: return null
val type = fieldDefinition.type.inner() as? GraphQLFieldsContainer
val isQuery = operationType == OperationType.QUERY
return CypherDirectiveHandler(type, isQuery, cypherDirective, fieldDefinition)
return CypherDirectiveHandler(type, isQuery, cypherDirective, fieldDefinition, schemaConfig)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@ class DeleteHandler private constructor(
type: GraphQLFieldsContainer,
private val idField: GraphQLFieldDefinition,
fieldDefinition: GraphQLFieldDefinition,
schemaConfig: SchemaConfig,
private val isRelation: Boolean = type.isRelationType()
) : BaseDataFetcherForContainer(type, fieldDefinition) {
) : BaseDataFetcherForContainer(type, fieldDefinition, schemaConfig) {

class Factory(schemaConfig: SchemaConfig) : AugmentationHandler(schemaConfig) {
override fun augmentType(type: GraphQLFieldsContainer, buildingEnv: BuildingEnv) {
Expand Down Expand Up @@ -48,7 +49,7 @@ class DeleteHandler private constructor(
}
val idField = type.getIdField() ?: return null
return when (fieldDefinition.name) {
"delete${type.name}" -> DeleteHandler(type, idField, fieldDefinition)
"delete${type.name}" -> DeleteHandler(type, idField, fieldDefinition, schemaConfig)
else -> null
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@ class MergeOrUpdateHandler private constructor(
private val merge: Boolean,
private val idField: GraphQLFieldDefinition,
fieldDefinition: GraphQLFieldDefinition,
schemaConfig: SchemaConfig,
private val isRelation: Boolean = type.isRelationType()
) : BaseDataFetcherForContainer(type, fieldDefinition) {
) : BaseDataFetcherForContainer(type, fieldDefinition, schemaConfig) {

class Factory(schemaConfig: SchemaConfig) : AugmentationHandler(schemaConfig) {
override fun augmentType(type: GraphQLFieldsContainer, buildingEnv: BuildingEnv) {
Expand Down Expand Up @@ -55,8 +56,8 @@ class MergeOrUpdateHandler private constructor(
}
val idField = type.getIdField() ?: return null
return when (fieldDefinition.name) {
"merge${type.name}" -> MergeOrUpdateHandler(type, true, idField, fieldDefinition)
"update${type.name}" -> MergeOrUpdateHandler(type, false, idField, fieldDefinition)
"merge${type.name}" -> MergeOrUpdateHandler(type, true, idField, fieldDefinition, schemaConfig)
"update${type.name}" -> MergeOrUpdateHandler(type, false, idField, fieldDefinition, schemaConfig)
else -> null
}
}
Expand Down Expand Up @@ -107,7 +108,7 @@ class MergeOrUpdateHandler private constructor(
}
}
val properties = properties(variable, field.arguments)
val mapProjection = projectFields(propertyContainer,field, type, env)
val mapProjection = projectFields(propertyContainer, field, type, env)
val update: OngoingMatchAndUpdate = select
.mutate(propertyContainer, org.neo4j.cypherdsl.core.Cypher.mapOf(*properties))

Expand Down
20 changes: 12 additions & 8 deletions core/src/main/kotlin/org/neo4j/graphql/handler/QueryHandler.kt
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@ import org.neo4j.graphql.handler.filter.OptimizedFilterHandler
*/
class QueryHandler private constructor(
type: GraphQLFieldsContainer,
fieldDefinition: GraphQLFieldDefinition)
: BaseDataFetcherForContainer(type, fieldDefinition) {
fieldDefinition: GraphQLFieldDefinition,
schemaConfig: SchemaConfig
) : BaseDataFetcherForContainer(type, fieldDefinition, schemaConfig) {

class Factory(schemaConfig: SchemaConfig) : AugmentationHandler(schemaConfig) {
override fun augmentType(type: GraphQLFieldsContainer, buildingEnv: BuildingEnv) {
Expand All @@ -25,15 +26,18 @@ class QueryHandler private constructor(
val typeName = type.name
val relevantFields = getRelevantFields(type)

// TODO not just generate the input type but use it as well
buildingEnv.addInputType("_${typeName}Input", type.relevantFields())
val filterTypeName = buildingEnv.addFilterType(type)
val arguments = if (schemaConfig.useWhereFilter) {
listOf(input(WHERE, GraphQLTypeReference(filterTypeName)))
} else {
buildingEnv.getInputValueDefinitions(relevantFields, { true }) +
input(FILTER, GraphQLTypeReference(filterTypeName))
}

val builder = GraphQLFieldDefinition
.newFieldDefinition()
.name(if (schemaConfig.capitalizeQueryFields) typeName else typeName.decapitalize())
.arguments(buildingEnv.getInputValueDefinitions(relevantFields) { true })
.argument(input(FILTER, GraphQLTypeReference(filterTypeName)))
.arguments(arguments)
.type(GraphQLNonNull(GraphQLList(GraphQLNonNull(GraphQLTypeReference(type.name)))))

if (schemaConfig.queryOptionStyle == SchemaConfig.InputStyle.INPUT_TYPE) {
Expand Down Expand Up @@ -68,7 +72,7 @@ class QueryHandler private constructor(
if (!canHandle(type)) {
return null
}
return QueryHandler(type, fieldDefinition)
return QueryHandler(type, fieldDefinition, schemaConfig)
}

private fun canHandle(type: GraphQLFieldsContainer): Boolean {
Expand Down Expand Up @@ -102,7 +106,7 @@ class QueryHandler private constructor(

val ongoingReading = if ((env.getContext() as? QueryContext)?.optimizedQuery?.contains(QueryContext.OptimizationStrategy.FILTER_AS_MATCH) == true) {

OptimizedFilterHandler(type).generateFilterQuery(variable, fieldDefinition, field, match, propertyContainer, env.variables)
OptimizedFilterHandler(type, schemaConfig).generateFilterQuery(variable, fieldDefinition, field, match, propertyContainer, env.variables)

} else {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ typealias ConditionBuilder = (ExposesWith) -> OrderableOngoingReadingAndWithWith
* If this handler cannot generate an optimization for the passed filter, an [OptimizedQueryException] will be
* thrown, so the calling site can fall back to the non-optimized logic
*/
class OptimizedFilterHandler(val type: GraphQLFieldsContainer) : ProjectionBase() {
class OptimizedFilterHandler(val type: GraphQLFieldsContainer, schemaConfig: SchemaConfig) : ProjectionBase(schemaConfig) {

fun generateFilterQuery(variable: String, fieldDefinition: GraphQLFieldDefinition, field: Field, readingWithoutWhere: OngoingReadingWithoutWhere, rootNode: PropertyContainer, variables: Map<String, Any>): OngoingReading {
if (type.isRelationType()) {
Expand All @@ -41,23 +41,23 @@ class OptimizedFilterHandler(val type: GraphQLFieldsContainer) : ProjectionBase(

var ongoingReading: OngoingReading? = null

val filteredArguments = field.arguments.filterNot { SPECIAL_FIELDS.contains(it.name) }
if (filteredArguments.isNotEmpty()) {
val parsedQuery = QueryParser.parseArguments(filteredArguments, fieldDefinition, type, variables)
val condition = handleQuery(variable, "", rootNode, parsedQuery, type, variables)
ongoingReading = readingWithoutWhere.where(condition)
if (!schemaConfig.useWhereFilter) {
val filteredArguments = field.arguments.filterNot { SPECIAL_FIELDS.contains(it.name) }
if (filteredArguments.isNotEmpty()) {
val parsedQuery = QueryParser.parseArguments(filteredArguments, fieldDefinition, type, variables)
val condition = handleQuery(variable, "", rootNode, parsedQuery, type, variables)
ongoingReading = readingWithoutWhere.where(condition)
}
}
for (argument in field.arguments) {
if (argument.name == FILTER) {
return field.arguments.find { filterFieldName() == it.name }
?.let { argument ->
val parsedQuery = parseFilter(argument.value as ObjectValue, type, variables)
ongoingReading = NestingLevelHandler(parsedQuery, false, rootNode, variable, ongoingReading
NestingLevelHandler(parsedQuery, false, rootNode, variable, ongoingReading
?: readingWithoutWhere,
type, argument.value, linkedSetOf(rootNode.requiredSymbolicName), variables)
.parseFilter()
}
}

return ongoingReading ?: readingWithoutWhere
?: readingWithoutWhere
}

/**
Expand Down
Loading

0 comments on commit 823fe6b

Please sign in to comment.