Skip to content

Migrate filter and field args into a where #220

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 4, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions core/src/main/kotlin/org/neo4j/graphql/BuildingEnv.kt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ import org.neo4j.graphql.handler.projection.ProjectionBase

class BuildingEnv(
val types: MutableMap<String, GraphQLNamedType>,
private val sourceSchema: GraphQLSchema
private val sourceSchema: GraphQLSchema,
val schemaConfig: SchemaConfig
) {

private val typesForRelation = types.values
Expand Down Expand Up @@ -77,7 +78,7 @@ class BuildingEnv(
}

fun addFilterType(type: GraphQLFieldsContainer, createdTypes: MutableSet<String> = mutableSetOf()): String {
val filterName = "_${type.name}Filter"
val filterName = if (schemaConfig.useWhereFilter) type.name + "Where" else "_${type.name}Filter"
if (createdTypes.contains(filterName)) {
return filterName
}
Expand Down Expand Up @@ -138,7 +139,7 @@ class BuildingEnv(
?: throw IllegalStateException("Ordering type $type.name is already defined but not an input type")
}
val sortTypeName = addSortInputType(type)
val optionsTypeBuilder = GraphQLInputObjectType.newInputObject().name(optionsName)
val optionsTypeBuilder = GraphQLInputObjectType.newInputObject().name(optionsName)
if (sortTypeName != null) {
optionsTypeBuilder.field(GraphQLInputObjectField.newInputObjectField()
.name(ProjectionBase.SORT)
Expand All @@ -147,10 +148,10 @@ class BuildingEnv(
.build())
}
optionsTypeBuilder.field(GraphQLInputObjectField.newInputObjectField()
.name(ProjectionBase.LIMIT)
.type(Scalars.GraphQLInt)
.description("Defines the maximum amount of records returned")
.build())
.name(ProjectionBase.LIMIT)
.type(Scalars.GraphQLInt)
.description("Defines the maximum amount of records returned")
.build())
.field(GraphQLInputObjectField.newInputObjectField()
.name(ProjectionBase.SKIP)
.type(Scalars.GraphQLInt)
Expand All @@ -169,7 +170,7 @@ class BuildingEnv(
?: throw IllegalStateException("Ordering type $type.name is already defined but not an input type")
}
val relevantFields = type.relevantFields()
if (relevantFields.isEmpty()){
if (relevantFields.isEmpty()) {
return null
}
val builder = GraphQLInputObjectType.newInputObject()
Expand Down
15 changes: 8 additions & 7 deletions core/src/main/kotlin/org/neo4j/graphql/SchemaBuilder.kt
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ object SchemaBuilder {

private fun augmentSchema(sourceSchema: GraphQLSchema, handler: List<AugmentationHandler>, schemaConfig: SchemaConfig): GraphQLSchema {
val types = sourceSchema.typeMap.toMutableMap()
val env = BuildingEnv(types, sourceSchema)
val env = BuildingEnv(types, sourceSchema, schemaConfig)
val queryTypeName = sourceSchema.queryTypeName()
val mutationTypeName = sourceSchema.mutationTypeName()
val subscriptionTypeName = sourceSchema.subscriptionTypeName()
Expand All @@ -157,11 +157,11 @@ object SchemaBuilder {
builder.clearFields().clearInterfaces()
// to prevent duplicated types in schema
sourceType.interfaces.forEach { builder.withInterface(GraphQLTypeReference(it.name)) }
sourceType.fieldDefinitions.forEach { f -> builder.field(enhanceRelations(f, env, schemaConfig)) }
sourceType.fieldDefinitions.forEach { f -> builder.field(enhanceRelations(f, env)) }
}
sourceType is GraphQLInterfaceType -> sourceType.transform { builder ->
builder.clearFields()
sourceType.fieldDefinitions.forEach { f -> builder.field(enhanceRelations(f, env, schemaConfig)) }
sourceType.fieldDefinitions.forEach { f -> builder.field(enhanceRelations(f, env)) }
}
else -> sourceType
}
Expand All @@ -177,7 +177,7 @@ object SchemaBuilder {
.build()
}

private fun enhanceRelations(fd: GraphQLFieldDefinition, env: BuildingEnv, schemaConfig: SchemaConfig): GraphQLFieldDefinition {
private fun enhanceRelations(fd: GraphQLFieldDefinition, env: BuildingEnv): GraphQLFieldDefinition {
return fd.transform { fieldBuilder ->
// to prevent duplicated types in schema
fieldBuilder.type(fd.type.ref() as GraphQLOutputType)
Expand All @@ -188,7 +188,7 @@ object SchemaBuilder {

val fieldType = fd.type.inner() as? GraphQLFieldsContainer ?: return@transform

if (schemaConfig.queryOptionStyle == SchemaConfig.InputStyle.INPUT_TYPE){
if (env.schemaConfig.queryOptionStyle == SchemaConfig.InputStyle.INPUT_TYPE) {

val optionsTypeName = env.addOptions(fieldType)
val optionsType = GraphQLTypeReference(optionsTypeName)
Expand All @@ -212,9 +212,10 @@ object SchemaBuilder {

}

if (schemaConfig.query.enabled && !schemaConfig.query.exclude.contains(fieldType.name) && fd.getArgument(ProjectionBase.FILTER) == null) {
val filterFieldName = if (env.schemaConfig.useWhereFilter) ProjectionBase.WHERE else ProjectionBase.FILTER
if (env.schemaConfig.query.enabled && !env.schemaConfig.query.exclude.contains(fieldType.name) && fd.getArgument(filterFieldName) == null) {
val filterTypeName = env.addFilterType(fieldType)
fieldBuilder.argument(input(ProjectionBase.FILTER, GraphQLTypeReference(filterTypeName)))
fieldBuilder.argument(input(filterFieldName, GraphQLTypeReference(filterTypeName)))
}
}
}
Expand Down
8 changes: 8 additions & 0 deletions core/src/main/kotlin/org/neo4j/graphql/SchemaConfig.kt
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,22 @@ package org.neo4j.graphql
data class SchemaConfig @JvmOverloads constructor(
val query: CRUDConfig = CRUDConfig(),
val mutation: CRUDConfig = CRUDConfig(),

/**
* if true, the top level fields of the Query-type will be capitalized
*/
val capitalizeQueryFields: Boolean = false,

/**
* Defines the way the input for queries and mutations are generated
*/
val queryOptionStyle: InputStyle = InputStyle.ARGUMENT_PER_FIELD,

/**
* if enabled the `filter` argument will be named `where` and the input type will be named `<typeName>Where`.
* additionally the separated filter arguments will no longer be generated.
*/
val useWhereFilter: Boolean = false,
) {
data class CRUDConfig(val enabled: Boolean = true, val exclude: List<String> = emptyList())

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,14 @@ import org.neo4j.cypherdsl.core.Statement
import org.neo4j.cypherdsl.core.renderer.Configuration
import org.neo4j.cypherdsl.core.renderer.Renderer
import org.neo4j.graphql.Cypher
import org.neo4j.graphql.SchemaConfig
import org.neo4j.graphql.aliasOrName
import org.neo4j.graphql.handler.projection.ProjectionBase

/**
* The is a base class for the implementation of graphql data fetcher used in this project
*/
abstract class BaseDataFetcher(val fieldDefinition: GraphQLFieldDefinition) : ProjectionBase(), DataFetcher<Cypher> {
abstract class BaseDataFetcher(val fieldDefinition: GraphQLFieldDefinition, schemaConfig: SchemaConfig) : ProjectionBase(schemaConfig), DataFetcher<Cypher> {

override fun get(env: DataFetchingEnvironment?): Cypher {
val field = env?.mergedField?.singleField
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@ import org.neo4j.graphql.*
*/
abstract class BaseDataFetcherForContainer(
val type: GraphQLFieldsContainer,
fieldDefinition: GraphQLFieldDefinition
) : BaseDataFetcher(fieldDefinition) {
fieldDefinition: GraphQLFieldDefinition,
schemaConfig: SchemaConfig
) : BaseDataFetcher(fieldDefinition, schemaConfig) {

val propertyFields: MutableMap<String, (Any) -> List<PropertyAccessor>?> = mutableMapOf()
val defaultFields: MutableMap<String, Any> = mutableMapOf()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@ import org.neo4j.graphql.*
*/
class CreateTypeHandler private constructor(
type: GraphQLFieldsContainer,
fieldDefinition: GraphQLFieldDefinition
) : BaseDataFetcherForContainer(type, fieldDefinition) {
fieldDefinition: GraphQLFieldDefinition,
schemaConfig: SchemaConfig
) : BaseDataFetcherForContainer(type, fieldDefinition, schemaConfig) {

class Factory(schemaConfig: SchemaConfig) : AugmentationHandler(schemaConfig) {
override fun augmentType(type: GraphQLFieldsContainer, buildingEnv: BuildingEnv) {
Expand Down Expand Up @@ -41,7 +42,7 @@ class CreateTypeHandler private constructor(
return null
}
return when {
fieldDefinition.name == "create${type.name}" -> CreateTypeHandler(type, fieldDefinition)
fieldDefinition.name == "create${type.name}" -> CreateTypeHandler(type, fieldDefinition, schemaConfig)
else -> null
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,17 @@ class CypherDirectiveHandler(
private val type: GraphQLFieldsContainer?,
private val isQuery: Boolean,
private val cypherDirective: CypherDirective,
fieldDefinition: GraphQLFieldDefinition)
: BaseDataFetcher(fieldDefinition) {
fieldDefinition: GraphQLFieldDefinition,
schemaConfig: SchemaConfig)
: BaseDataFetcher(fieldDefinition, schemaConfig) {

class Factory(schemaConfig: SchemaConfig) : AugmentationHandler(schemaConfig) {

override fun createDataFetcher(operationType: OperationType, fieldDefinition: GraphQLFieldDefinition): DataFetcher<Cypher>? {
val cypherDirective = fieldDefinition.cypherDirective() ?: return null
val type = fieldDefinition.type.inner() as? GraphQLFieldsContainer
val isQuery = operationType == OperationType.QUERY
return CypherDirectiveHandler(type, isQuery, cypherDirective, fieldDefinition)
return CypherDirectiveHandler(type, isQuery, cypherDirective, fieldDefinition, schemaConfig)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@ class DeleteHandler private constructor(
type: GraphQLFieldsContainer,
private val idField: GraphQLFieldDefinition,
fieldDefinition: GraphQLFieldDefinition,
schemaConfig: SchemaConfig,
private val isRelation: Boolean = type.isRelationType()
) : BaseDataFetcherForContainer(type, fieldDefinition) {
) : BaseDataFetcherForContainer(type, fieldDefinition, schemaConfig) {

class Factory(schemaConfig: SchemaConfig) : AugmentationHandler(schemaConfig) {
override fun augmentType(type: GraphQLFieldsContainer, buildingEnv: BuildingEnv) {
Expand Down Expand Up @@ -48,7 +49,7 @@ class DeleteHandler private constructor(
}
val idField = type.getIdField() ?: return null
return when (fieldDefinition.name) {
"delete${type.name}" -> DeleteHandler(type, idField, fieldDefinition)
"delete${type.name}" -> DeleteHandler(type, idField, fieldDefinition, schemaConfig)
else -> null
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@ class MergeOrUpdateHandler private constructor(
private val merge: Boolean,
private val idField: GraphQLFieldDefinition,
fieldDefinition: GraphQLFieldDefinition,
schemaConfig: SchemaConfig,
private val isRelation: Boolean = type.isRelationType()
) : BaseDataFetcherForContainer(type, fieldDefinition) {
) : BaseDataFetcherForContainer(type, fieldDefinition, schemaConfig) {

class Factory(schemaConfig: SchemaConfig) : AugmentationHandler(schemaConfig) {
override fun augmentType(type: GraphQLFieldsContainer, buildingEnv: BuildingEnv) {
Expand Down Expand Up @@ -55,8 +56,8 @@ class MergeOrUpdateHandler private constructor(
}
val idField = type.getIdField() ?: return null
return when (fieldDefinition.name) {
"merge${type.name}" -> MergeOrUpdateHandler(type, true, idField, fieldDefinition)
"update${type.name}" -> MergeOrUpdateHandler(type, false, idField, fieldDefinition)
"merge${type.name}" -> MergeOrUpdateHandler(type, true, idField, fieldDefinition, schemaConfig)
"update${type.name}" -> MergeOrUpdateHandler(type, false, idField, fieldDefinition, schemaConfig)
else -> null
}
}
Expand Down Expand Up @@ -107,7 +108,7 @@ class MergeOrUpdateHandler private constructor(
}
}
val properties = properties(variable, field.arguments)
val mapProjection = projectFields(propertyContainer,field, type, env)
val mapProjection = projectFields(propertyContainer, field, type, env)
val update: OngoingMatchAndUpdate = select
.mutate(propertyContainer, org.neo4j.cypherdsl.core.Cypher.mapOf(*properties))

Expand Down
20 changes: 12 additions & 8 deletions core/src/main/kotlin/org/neo4j/graphql/handler/QueryHandler.kt
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@ import org.neo4j.graphql.handler.filter.OptimizedFilterHandler
*/
class QueryHandler private constructor(
type: GraphQLFieldsContainer,
fieldDefinition: GraphQLFieldDefinition)
: BaseDataFetcherForContainer(type, fieldDefinition) {
fieldDefinition: GraphQLFieldDefinition,
schemaConfig: SchemaConfig
) : BaseDataFetcherForContainer(type, fieldDefinition, schemaConfig) {

class Factory(schemaConfig: SchemaConfig) : AugmentationHandler(schemaConfig) {
override fun augmentType(type: GraphQLFieldsContainer, buildingEnv: BuildingEnv) {
Expand All @@ -25,15 +26,18 @@ class QueryHandler private constructor(
val typeName = type.name
val relevantFields = getRelevantFields(type)

// TODO not just generate the input type but use it as well
buildingEnv.addInputType("_${typeName}Input", type.relevantFields())
val filterTypeName = buildingEnv.addFilterType(type)
val arguments = if (schemaConfig.useWhereFilter) {
listOf(input(WHERE, GraphQLTypeReference(filterTypeName)))
} else {
buildingEnv.getInputValueDefinitions(relevantFields, { true }) +
input(FILTER, GraphQLTypeReference(filterTypeName))
}

val builder = GraphQLFieldDefinition
.newFieldDefinition()
.name(if (schemaConfig.capitalizeQueryFields) typeName else typeName.decapitalize())
.arguments(buildingEnv.getInputValueDefinitions(relevantFields) { true })
.argument(input(FILTER, GraphQLTypeReference(filterTypeName)))
.arguments(arguments)
.type(GraphQLNonNull(GraphQLList(GraphQLNonNull(GraphQLTypeReference(type.name)))))

if (schemaConfig.queryOptionStyle == SchemaConfig.InputStyle.INPUT_TYPE) {
Expand Down Expand Up @@ -68,7 +72,7 @@ class QueryHandler private constructor(
if (!canHandle(type)) {
return null
}
return QueryHandler(type, fieldDefinition)
return QueryHandler(type, fieldDefinition, schemaConfig)
}

private fun canHandle(type: GraphQLFieldsContainer): Boolean {
Expand Down Expand Up @@ -102,7 +106,7 @@ class QueryHandler private constructor(

val ongoingReading = if ((env.getContext() as? QueryContext)?.optimizedQuery?.contains(QueryContext.OptimizationStrategy.FILTER_AS_MATCH) == true) {

OptimizedFilterHandler(type).generateFilterQuery(variable, fieldDefinition, field, match, propertyContainer, env.variables)
OptimizedFilterHandler(type, schemaConfig).generateFilterQuery(variable, fieldDefinition, field, match, propertyContainer, env.variables)

} else {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ typealias ConditionBuilder = (ExposesWith) -> OrderableOngoingReadingAndWithWith
* If this handler cannot generate an optimization for the passed filter, an [OptimizedQueryException] will be
* thrown, so the calling site can fall back to the non-optimized logic
*/
class OptimizedFilterHandler(val type: GraphQLFieldsContainer) : ProjectionBase() {
class OptimizedFilterHandler(val type: GraphQLFieldsContainer, schemaConfig: SchemaConfig) : ProjectionBase(schemaConfig) {

fun generateFilterQuery(variable: String, fieldDefinition: GraphQLFieldDefinition, field: Field, readingWithoutWhere: OngoingReadingWithoutWhere, rootNode: PropertyContainer, variables: Map<String, Any>): OngoingReading {
if (type.isRelationType()) {
Expand All @@ -41,23 +41,23 @@ class OptimizedFilterHandler(val type: GraphQLFieldsContainer) : ProjectionBase(

var ongoingReading: OngoingReading? = null

val filteredArguments = field.arguments.filterNot { SPECIAL_FIELDS.contains(it.name) }
if (filteredArguments.isNotEmpty()) {
val parsedQuery = QueryParser.parseArguments(filteredArguments, fieldDefinition, type, variables)
val condition = handleQuery(variable, "", rootNode, parsedQuery, type, variables)
ongoingReading = readingWithoutWhere.where(condition)
if (!schemaConfig.useWhereFilter) {
val filteredArguments = field.arguments.filterNot { SPECIAL_FIELDS.contains(it.name) }
if (filteredArguments.isNotEmpty()) {
val parsedQuery = QueryParser.parseArguments(filteredArguments, fieldDefinition, type, variables)
val condition = handleQuery(variable, "", rootNode, parsedQuery, type, variables)
ongoingReading = readingWithoutWhere.where(condition)
}
}
for (argument in field.arguments) {
if (argument.name == FILTER) {
return field.arguments.find { filterFieldName() == it.name }
?.let { argument ->
val parsedQuery = parseFilter(argument.value as ObjectValue, type, variables)
ongoingReading = NestingLevelHandler(parsedQuery, false, rootNode, variable, ongoingReading
NestingLevelHandler(parsedQuery, false, rootNode, variable, ongoingReading
?: readingWithoutWhere,
type, argument.value, linkedSetOf(rootNode.requiredSymbolicName), variables)
.parseFilter()
}
}

return ongoingReading ?: readingWithoutWhere
?: readingWithoutWhere
}

/**
Expand Down
Loading