Skip to content

Commit 891365c

Browse files
committed
Nested sorting on fields (GH #3)
This commit add the following features: * sorting on multiple properties * sorting on fields * augmentation-tests provides a diff on failure that can be viewed in IntelliJ Bugfix: * The formatted field for temporal is now a string rather than an object
1 parent f8e26b3 commit 891365c

File tree

6 files changed

+280
-99
lines changed

6 files changed

+280
-99
lines changed

src/main/kotlin/org/neo4j/graphql/SchemaBuilder.kt

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import org.neo4j.graphql.handler.projection.ProjectionBase
1313
import org.neo4j.graphql.handler.relation.CreateRelationHandler
1414
import org.neo4j.graphql.handler.relation.CreateRelationTypeHandler
1515
import org.neo4j.graphql.handler.relation.DeleteRelationHandler
16+
import java.util.concurrent.ConcurrentHashMap
1617

1718
object SchemaBuilder {
1819
private const val MUTATION = "Mutation"
@@ -74,7 +75,7 @@ object SchemaBuilder {
7475

7576
val handler = getHandler(config)
7677

77-
var targetSchema = augmentSchema(sourceSchema, handler)
78+
var targetSchema = augmentSchema(sourceSchema, handler, config)
7879
targetSchema = addDataFetcher(targetSchema, dataFetchingInterceptor, handler)
7980
return targetSchema
8081
}
@@ -99,8 +100,8 @@ object SchemaBuilder {
99100
return handler
100101
}
101102

102-
private fun augmentSchema(sourceSchema: GraphQLSchema, handler: List<AugmentationHandler>): GraphQLSchema {
103-
val types = sourceSchema.typeMap.toMutableMap()
103+
private fun augmentSchema(sourceSchema: GraphQLSchema, handler: List<AugmentationHandler>, config: SchemaConfig): GraphQLSchema {
104+
val types = sourceSchema.typeMap.toMap(ConcurrentHashMap())
104105
val env = BuildingEnv(types)
105106

106107
types.values
@@ -123,11 +124,11 @@ object SchemaBuilder {
123124
builder.clearFields().clearInterfaces()
124125
// to prevent duplicated types in schema
125126
sourceType.interfaces.forEach { builder.withInterface(GraphQLTypeReference(it.name)) }
126-
sourceType.fieldDefinitions.forEach { f -> builder.field(enhanceRelations(f)) }
127+
sourceType.fieldDefinitions.forEach { f -> builder.field(enhanceRelations(f, env, config)) }
127128
}
128129
sourceType is GraphQLInterfaceType -> sourceType.transform { builder ->
129130
builder.clearFields()
130-
sourceType.fieldDefinitions.forEach { f -> builder.field(enhanceRelations(f)) }
131+
sourceType.fieldDefinitions.forEach { f -> builder.field(enhanceRelations(f, env, config)) }
131132
}
132133
else -> sourceType
133134
}
@@ -142,27 +143,29 @@ object SchemaBuilder {
142143
.build()
143144
}
144145

145-
private fun enhanceRelations(fd: GraphQLFieldDefinition): GraphQLFieldDefinition {
146-
return fd.transform {
146+
private fun enhanceRelations(fd: GraphQLFieldDefinition, env: BuildingEnv, config: SchemaConfig): GraphQLFieldDefinition {
147+
return fd.transform { fieldBuilder ->
147148
// to prevent duplicated types in schema
148-
it.type(fd.type.ref() as GraphQLOutputType)
149+
fieldBuilder.type(fd.type.ref() as GraphQLOutputType)
149150

150151
if (!fd.isRelationship() || !fd.type.isList()) {
151152
return@transform
152153
}
153154

154155
if (fd.getArgument(ProjectionBase.FIRST) == null) {
155-
it.argument { a -> a.name(ProjectionBase.FIRST).type(Scalars.GraphQLInt) }
156+
fieldBuilder.argument { a -> a.name(ProjectionBase.FIRST).type(Scalars.GraphQLInt) }
156157
}
157158
if (fd.getArgument(ProjectionBase.OFFSET) == null) {
158-
it.argument { a -> a.name(ProjectionBase.OFFSET).type(Scalars.GraphQLInt) }
159+
fieldBuilder.argument { a -> a.name(ProjectionBase.OFFSET).type(Scalars.GraphQLInt) }
160+
}
161+
if (fd.getArgument(ProjectionBase.ORDER_BY) == null && fd.type.isList()) {
162+
(fd.type.inner() as? GraphQLFieldsContainer)?.let { fieldType ->
163+
env.addOrdering(fieldType)?.let { orderingTypeName ->
164+
val orderType = GraphQLList(GraphQLNonNull(GraphQLTypeReference(orderingTypeName)))
165+
fieldBuilder.argument { a -> a.name(ProjectionBase.ORDER_BY).type(orderType) }
166+
}
167+
}
159168
}
160-
// TODO implement ordering
161-
// if (fd.getArgument(ProjectionBase.ORDER_BY) == null) {
162-
// val typeName = fd.type.name()!!
163-
// val orderingType = addOrdering(typeName, metaProvider.getNodeType(typeName)!!.fieldDefinitions().filter { it.type.isScalar() })
164-
// it.argument { a -> a.name(ProjectionBase.ORDER_BY).type(orderingType) }
165-
// }
166169
}
167170
}
168171

src/main/kotlin/org/neo4j/graphql/handler/QueryHandler.kt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@ class QueryHandler private constructor(
3434
.argument(input(OFFSET, Scalars.GraphQLInt))
3535
.type(GraphQLNonNull(GraphQLList(GraphQLNonNull(GraphQLTypeReference(type.name)))))
3636
if (orderingTypeName != null) {
37-
builder.argument(input(ORDER_BY, GraphQLTypeReference(orderingTypeName)))
37+
val orderType = GraphQLList(GraphQLNonNull(GraphQLTypeReference(orderingTypeName)))
38+
builder.argument(input(ORDER_BY, orderType))
3839
}
3940
val def = builder.build()
4041
buildingEnv.addOperation(QUERY, def)

src/main/kotlin/org/neo4j/graphql/handler/projection/ProjectionBase.kt

Lines changed: 60 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -16,21 +16,30 @@ open class ProjectionBase {
1616
}
1717

1818
fun orderBy(variable: String, args: MutableList<Argument>): String {
19+
val values = getOrderByArgs(args)
20+
if (values.isEmpty()) {
21+
return ""
22+
}
23+
return " ORDER BY " + values.joinToString(", ", transform = { (property, direction) -> "$variable.$property $direction" })
24+
}
25+
26+
private fun getOrderByArgs(args: MutableList<Argument>): List<Pair<String, Sort>> {
1927
val arg = args.find { it.name == ORDER_BY }
20-
val values = arg?.value?.let { it ->
21-
when (it) {
22-
is ArrayValue -> it.values.map { it.toJavaValue().toString() }
23-
is EnumValue -> listOf(it.name)
24-
is StringValue -> listOf(it.value)
25-
else -> null
28+
return arg?.value
29+
?.let { it ->
30+
when (it) {
31+
is ArrayValue -> it.values.map { it.toJavaValue().toString() }
32+
is EnumValue -> listOf(it.name)
33+
is StringValue -> listOf(it.value)
34+
else -> null
35+
}
2636
}
27-
}
28-
@Suppress("SimplifiableCallChain")
29-
return if (values == null) ""
30-
else " ORDER BY " + values
31-
.map { it.split("_") }
32-
.map { "$variable.${it[0]} ${it[1].toUpperCase()}" }
33-
.joinToString(", ")
37+
?.map {
38+
val index = it.lastIndexOf('_')
39+
val property = it.substring(0, index)
40+
val direction = Sort.valueOf(it.substring(index + 1).toUpperCase())
41+
property to direction
42+
} ?: emptyList()
3443
}
3544

3645
fun where(variable: String, fieldDefinition: GraphQLFieldDefinition, type: GraphQLFieldsContainer, arguments: List<Argument>, field: Field): Cypher {
@@ -139,9 +148,8 @@ open class ProjectionBase {
139148
return predicates.values + defaults
140149
}
141150

142-
fun projectFields(variable: String, field: Field, nodeType: GraphQLFieldsContainer, env: DataFetchingEnvironment, variableSuffix: String?): Cypher {
143-
val queries = projectSelectionSet(variable, field.selectionSet, nodeType, env, variableSuffix)
144-
151+
fun projectFields(variable: String, field: Field, nodeType: GraphQLFieldsContainer, env: DataFetchingEnvironment, variableSuffix: String?, neo4jFieldsToPass: Set<String> = emptySet()): Cypher {
152+
val queries = projectSelection(variable, field.selectionSet.selections, nodeType, env, variableSuffix, neo4jFieldsToPass)
145153
@Suppress("SimplifiableCallChain")
146154
val projection = queries
147155
.map { it.query }
@@ -152,18 +160,18 @@ open class ProjectionBase {
152160
return Cypher("$variable $projection", params)
153161
}
154162

155-
private fun projectSelectionSet(variable: String, selectionSet: SelectionSet, nodeType: GraphQLFieldsContainer, env: DataFetchingEnvironment, variableSuffix: String?): List<Cypher> {
163+
private fun projectSelection(variable: String, selection: List<Selection<*>>, nodeType: GraphQLFieldsContainer, env: DataFetchingEnvironment, variableSuffix: String?, neo4jFieldsToPass: Set<String> = emptySet()): List<Cypher> {
156164
// TODO just render fragments on valid types (Labels) by using cypher like this:
157165
// apoc.map.mergeList([
158166
// a{.name},
159167
// CASE WHEN a:Location THEN a { .foo } ELSE {} END
160168
// ])
161169
var hasTypeName = false
162-
val projections = selectionSet.selections.flatMapTo(mutableListOf<Cypher>()) {
170+
val projections = selection.flatMapTo(mutableListOf<Cypher>()) {
163171
when (it) {
164172
is Field -> {
165173
hasTypeName = hasTypeName || (it.name == TYPE_NAME)
166-
listOf(projectField(variable, it, nodeType, env, variableSuffix))
174+
listOf(projectField(variable, it, nodeType, env, variableSuffix, neo4jFieldsToPass))
167175
}
168176
is InlineFragment -> projectInlineFragment(variable, it, env, variableSuffix)
169177
is FragmentSpread -> projectNamedFragments(variable, it, env, variableSuffix)
@@ -180,7 +188,7 @@ open class ProjectionBase {
180188
return projections
181189
}
182190

183-
private fun projectField(variable: String, field: Field, type: GraphQLFieldsContainer, env: DataFetchingEnvironment, variableSuffix: String?): Cypher {
191+
private fun projectField(variable: String, field: Field, type: GraphQLFieldsContainer, env: DataFetchingEnvironment, variableSuffix: String?, neo4jFieldsToPass: Set<String> = emptySet()): Cypher {
184192
if (field.name == TYPE_NAME) {
185193
return if (type.isRelationType()) {
186194
Cypher("${field.aliasOrName()}: '${type.name}'")
@@ -206,7 +214,11 @@ open class ProjectionBase {
206214
} ?: when {
207215
isObjectField -> {
208216
val patternComprehensions = if (fieldDefinition.isNeo4jType()) {
209-
projectNeo4jObjectType(variable, field)
217+
if (neo4jFieldsToPass.contains(fieldDefinition.innerName())) {
218+
Cypher(variable + "." + fieldDefinition.propertyName().quote())
219+
} else {
220+
projectNeo4jObjectType(variable, field)
221+
}
210222
} else {
211223
projectRelationship(variable, field, fieldDefinition, type, env, variableSuffix)
212224
}
@@ -230,7 +242,7 @@ open class ProjectionBase {
230242
.filterIsInstance<Field>()
231243
.map {
232244
val value = when (it.name) {
233-
NEO4j_FORMATTED_PROPERTY_KEY -> "$variable.${field.name}"
245+
NEO4j_FORMATTED_PROPERTY_KEY -> "toString($variable.${field.name})"
234246
else -> "$variable.${field.name}.${it.name}"
235247
}
236248
"${it.name}: $value"
@@ -266,7 +278,7 @@ open class ProjectionBase {
266278
val fragmentType = env.graphQLSchema.getType(fragmentTypeName) as? GraphQLFieldsContainer ?: return emptyList()
267279
// these are the nested fields of the fragment
268280
// it could be that we have to adapt the variable name too, and perhaps add some kind of rename
269-
return projectSelectionSet(variable, selectionSet, fragmentType, env, variableSuffix)
281+
return projectSelection(variable, selectionSet.selections, fragmentType, env, variableSuffix)
270282
}
271283

272284

@@ -336,9 +348,27 @@ open class ProjectionBase {
336348
val relPattern = if (isRelFromType) "$childVariable:${relInfo.relType}" else ":${relInfo.relType}"
337349

338350
val where = where(childVariable, fieldDefinition, nodeType, propertyArguments(field), field)
339-
val fieldProjection = projectFields(childVariable, field, nodeType, env, variableSuffix)
340351

341-
val comprehension = "[($variable)$inArrow-[$relPattern]-$outArrow($endNodePattern)${where.query} | ${fieldProjection.query}]"
352+
val orderBy = getOrderByArgs(field.arguments)
353+
val sortByNeo4jTypeFields = orderBy
354+
.filter { (property, _) -> nodeType.getFieldDefinition(property)?.isNeo4jType() == true }
355+
.map { (property, _) -> property }
356+
.toSet()
357+
358+
val fieldProjection = projectFields(childVariable, field, nodeType, env, variableSuffix, sortByNeo4jTypeFields)
359+
var comprehension = "[($variable)$inArrow-[$relPattern]-$outArrow($endNodePattern)${where.query} | ${fieldProjection.query}]"
360+
if (orderBy.isNotEmpty()) {
361+
val sortArgs = orderBy.joinToString(", ", transform = { (property, direction) -> if (direction == Sort.ASC) "'^$property'" else "'$property'" })
362+
comprehension = "apoc.coll.sortMulti($comprehension, [$sortArgs])"
363+
if (sortByNeo4jTypeFields.isNotEmpty()) {
364+
val neo4jFiledSelection = field.selectionSet.selections
365+
.filter { selection -> sortByNeo4jTypeFields.contains((selection as? Field)?.name) }
366+
val deferredProjection = projectSelection("sortedElement", neo4jFiledSelection, nodeType, env, variableSuffix)
367+
.map { cypher -> cypher.query }
368+
.joinNonEmpty(", ")
369+
comprehension = "[sortedElement IN $comprehension | sortedElement { .*, $deferredProjection }]"
370+
}
371+
}
342372
val skipLimit = SkipLimit(childVariable, field.arguments)
343373
val slice = skipLimit.slice(fieldType.isList())
344374
return Cypher(comprehension + slice.query, (where.params + fieldProjection.params + slice.params))
@@ -392,4 +422,9 @@ open class ProjectionBase {
392422
}
393423
}
394424
}
425+
426+
enum class Sort {
427+
ASC,
428+
DESC
429+
}
395430
}

src/test/kotlin/org/neo4j/graphql/utils/GraphQLSchemaTestSuite.kt

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import org.junit.jupiter.api.DynamicTest
1515
import org.neo4j.graphql.DynamicProperties
1616
import org.neo4j.graphql.SchemaBuilder
1717
import org.neo4j.graphql.SchemaConfig
18+
import org.opentest4j.AssertionFailedError
1819
import java.io.File
1920
import java.util.regex.Pattern
2021
import javax.ws.rs.core.UriBuilder
@@ -62,31 +63,33 @@ class GraphQLSchemaTestSuite(fileName: String) : AsciiDocTestSuite() {
6263
private val ignore: Boolean) {
6364

6465
fun run() {
65-
println(title)
66+
var augmentedSchema: GraphQLSchema? = null
67+
var expectedSchema: GraphQLSchema? = null
6668
try {
67-
val augmentedSchema = SchemaBuilder.buildSchema(suite.schema, config)
69+
augmentedSchema = SchemaBuilder.buildSchema(suite.schema, config)
6870
val schemaParser = SchemaParser()
6971

70-
println("Augmented Schema:")
71-
println(suite.schemaPrinter.print(augmentedSchema))
72-
7372
val reg = schemaParser.parse(targetSchema)
7473
val schemaGenerator = SchemaGenerator()
7574
val runtimeWiring = RuntimeWiring.newRuntimeWiring()
7675
reg
7776
.getTypes(InterfaceTypeDefinition::class.java)
7877
.forEach { typeDefinition -> runtimeWiring.type(typeDefinition.name) { it.typeResolver { null } } }
79-
val expected = schemaGenerator.makeExecutableSchema(reg, runtimeWiring
78+
expectedSchema = schemaGenerator.makeExecutableSchema(reg, runtimeWiring
8079
.scalar(DynamicProperties.INSTANCE)
8180
.build())
8281

83-
diff(expected, augmentedSchema)
84-
diff(augmentedSchema, expected)
82+
diff(expectedSchema, augmentedSchema)
83+
diff(augmentedSchema, expectedSchema)
8584
} catch (e: Throwable) {
8685
if (ignore) {
8786
Assumptions.assumeFalse(true, e.message)
8887
} else {
89-
throw e
88+
throw AssertionFailedError("augmented schema differs for '$title'",
89+
expectedSchema?.let { suite.schemaPrinter.print(it) } ?: targetSchema,
90+
suite.schemaPrinter.print(augmentedSchema),
91+
e)
92+
9093
}
9194
}
9295
}
@@ -139,4 +142,4 @@ class GraphQLSchemaTestSuite(fileName: String) : AsciiDocTestSuite() {
139142
}
140143
}
141144
}
142-
}
145+
}

0 commit comments

Comments
 (0)