Skip to content

Commit ba08a65

Browse files
authored
Nested sorting on fields (#103)
* Nested sorting on fields (GH #3) This commit add the following features: * sorting on multiple properties * sorting on fields * augmentation-tests provides a diff on failure that can be viewed in IntelliJ * The formatted field for temporal is now a string rather than an object * Add test for #115 * adjustments after review
1 parent ee8c897 commit ba08a65

File tree

6 files changed

+301
-95
lines changed

6 files changed

+301
-95
lines changed

src/main/kotlin/org/neo4j/graphql/SchemaBuilder.kt

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ object SchemaBuilder {
7474

7575
val handler = getHandler(config)
7676

77-
var targetSchema = augmentSchema(sourceSchema, handler)
77+
var targetSchema = augmentSchema(sourceSchema, handler, config)
7878
targetSchema = addDataFetcher(targetSchema, dataFetchingInterceptor, handler)
7979
return targetSchema
8080
}
@@ -99,7 +99,7 @@ object SchemaBuilder {
9999
return handler
100100
}
101101

102-
private fun augmentSchema(sourceSchema: GraphQLSchema, handler: List<AugmentationHandler>): GraphQLSchema {
102+
private fun augmentSchema(sourceSchema: GraphQLSchema, handler: List<AugmentationHandler>, config: SchemaConfig): GraphQLSchema {
103103
val types = sourceSchema.typeMap.toMutableMap()
104104
val env = BuildingEnv(types)
105105

@@ -116,22 +116,26 @@ object SchemaBuilder {
116116
handler.forEach { h -> h.augmentType(type, env) }
117117
}
118118

119-
types.replaceAll { _, sourceType ->
119+
// since new types my be added to `types` we copy the map, to safely modify the entries and later add these
120+
// modified entries back to the `types`
121+
val adjustedTypes = types.toMutableMap()
122+
adjustedTypes.replaceAll { _, sourceType ->
120123
when {
121124
sourceType.name.startsWith("__") -> sourceType
122125
sourceType is GraphQLObjectType -> sourceType.transform { builder ->
123126
builder.clearFields().clearInterfaces()
124127
// to prevent duplicated types in schema
125128
sourceType.interfaces.forEach { builder.withInterface(GraphQLTypeReference(it.name)) }
126-
sourceType.fieldDefinitions.forEach { f -> builder.field(enhanceRelations(f)) }
129+
sourceType.fieldDefinitions.forEach { f -> builder.field(enhanceRelations(f, env)) }
127130
}
128131
sourceType is GraphQLInterfaceType -> sourceType.transform { builder ->
129132
builder.clearFields()
130-
sourceType.fieldDefinitions.forEach { f -> builder.field(enhanceRelations(f)) }
133+
sourceType.fieldDefinitions.forEach { f -> builder.field(enhanceRelations(f, env)) }
131134
}
132135
else -> sourceType
133136
}
134137
}
138+
types.putAll(adjustedTypes)
135139

136140
return GraphQLSchema
137141
.newSchema(sourceSchema)
@@ -142,27 +146,29 @@ object SchemaBuilder {
142146
.build()
143147
}
144148

145-
private fun enhanceRelations(fd: GraphQLFieldDefinition): GraphQLFieldDefinition {
146-
return fd.transform {
149+
private fun enhanceRelations(fd: GraphQLFieldDefinition, env: BuildingEnv): GraphQLFieldDefinition {
150+
return fd.transform { fieldBuilder ->
147151
// to prevent duplicated types in schema
148-
it.type(fd.type.ref() as GraphQLOutputType)
152+
fieldBuilder.type(fd.type.ref() as GraphQLOutputType)
149153

150154
if (!fd.isRelationship() || !fd.type.isList()) {
151155
return@transform
152156
}
153157

154158
if (fd.getArgument(ProjectionBase.FIRST) == null) {
155-
it.argument { a -> a.name(ProjectionBase.FIRST).type(Scalars.GraphQLInt) }
159+
fieldBuilder.argument { a -> a.name(ProjectionBase.FIRST).type(Scalars.GraphQLInt) }
156160
}
157161
if (fd.getArgument(ProjectionBase.OFFSET) == null) {
158-
it.argument { a -> a.name(ProjectionBase.OFFSET).type(Scalars.GraphQLInt) }
162+
fieldBuilder.argument { a -> a.name(ProjectionBase.OFFSET).type(Scalars.GraphQLInt) }
163+
}
164+
if (fd.getArgument(ProjectionBase.ORDER_BY) == null && fd.type.isList()) {
165+
(fd.type.inner() as? GraphQLFieldsContainer)?.let { fieldType ->
166+
env.addOrdering(fieldType)?.let { orderingTypeName ->
167+
val orderType = GraphQLList(GraphQLNonNull(GraphQLTypeReference(orderingTypeName)))
168+
fieldBuilder.argument { a -> a.name(ProjectionBase.ORDER_BY).type(orderType) }
169+
}
170+
}
159171
}
160-
// TODO implement ordering
161-
// if (fd.getArgument(ProjectionBase.ORDER_BY) == null) {
162-
// val typeName = fd.type.name()!!
163-
// val orderingType = addOrdering(typeName, metaProvider.getNodeType(typeName)!!.fieldDefinitions().filter { it.type.isScalar() })
164-
// it.argument { a -> a.name(ProjectionBase.ORDER_BY).type(orderingType) }
165-
// }
166172
}
167173
}
168174

src/main/kotlin/org/neo4j/graphql/handler/QueryHandler.kt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@ class QueryHandler private constructor(
3434
.argument(input(OFFSET, Scalars.GraphQLInt))
3535
.type(GraphQLNonNull(GraphQLList(GraphQLNonNull(GraphQLTypeReference(type.name)))))
3636
if (orderingTypeName != null) {
37-
builder.argument(input(ORDER_BY, GraphQLTypeReference(orderingTypeName)))
37+
val orderType = GraphQLList(GraphQLNonNull(GraphQLTypeReference(orderingTypeName)))
38+
builder.argument(input(ORDER_BY, orderType))
3839
}
3940
val def = builder.build()
4041
buildingEnv.addOperation(QUERY, def)

src/main/kotlin/org/neo4j/graphql/handler/projection/ProjectionBase.kt

Lines changed: 63 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -16,21 +16,30 @@ open class ProjectionBase {
1616
}
1717

1818
fun orderBy(variable: String, args: MutableList<Argument>): String {
19+
val values = getOrderByArgs(args)
20+
if (values.isEmpty()) {
21+
return ""
22+
}
23+
return " ORDER BY " + values.joinToString(", ", transform = { (property, direction) -> "$variable.$property $direction" })
24+
}
25+
26+
private fun getOrderByArgs(args: MutableList<Argument>): List<Pair<String, Sort>> {
1927
val arg = args.find { it.name == ORDER_BY }
20-
val values = arg?.value?.let { it ->
21-
when (it) {
22-
is ArrayValue -> it.values.map { it.toJavaValue().toString() }
23-
is EnumValue -> listOf(it.name)
24-
is StringValue -> listOf(it.value)
25-
else -> null
28+
return arg?.value
29+
?.let { it ->
30+
when (it) {
31+
is ArrayValue -> it.values.map { it.toJavaValue().toString() }
32+
is EnumValue -> listOf(it.name)
33+
is StringValue -> listOf(it.value)
34+
else -> null
35+
}
2636
}
27-
}
28-
@Suppress("SimplifiableCallChain")
29-
return if (values == null) ""
30-
else " ORDER BY " + values
31-
.map { it.split("_") }
32-
.map { "$variable.${it[0]} ${it[1].toUpperCase()}" }
33-
.joinToString(", ")
37+
?.map {
38+
val index = it.lastIndexOf('_')
39+
val property = it.substring(0, index)
40+
val direction = Sort.valueOf(it.substring(index + 1).toUpperCase())
41+
property to direction
42+
} ?: emptyList()
3443
}
3544

3645
fun where(variable: String, fieldDefinition: GraphQLFieldDefinition, type: GraphQLFieldsContainer, arguments: List<Argument>, field: Field): Cypher {
@@ -139,9 +148,8 @@ open class ProjectionBase {
139148
return predicates.values + defaults
140149
}
141150

142-
fun projectFields(variable: String, field: Field, nodeType: GraphQLFieldsContainer, env: DataFetchingEnvironment, variableSuffix: String?): Cypher {
143-
val queries = projectSelectionSet(variable, field.selectionSet, nodeType, env, variableSuffix)
144-
151+
fun projectFields(variable: String, field: Field, nodeType: GraphQLFieldsContainer, env: DataFetchingEnvironment, variableSuffix: String?, propertiesToSkipDeepProjection: Set<String> = emptySet()): Cypher {
152+
val queries = projectSelection(variable, field.selectionSet.selections, nodeType, env, variableSuffix, propertiesToSkipDeepProjection)
145153
@Suppress("SimplifiableCallChain")
146154
val projection = queries
147155
.map { it.query }
@@ -152,18 +160,18 @@ open class ProjectionBase {
152160
return Cypher("$variable $projection", params)
153161
}
154162

155-
private fun projectSelectionSet(variable: String, selectionSet: SelectionSet, nodeType: GraphQLFieldsContainer, env: DataFetchingEnvironment, variableSuffix: String?): List<Cypher> {
163+
private fun projectSelection(variable: String, selection: List<Selection<*>>, nodeType: GraphQLFieldsContainer, env: DataFetchingEnvironment, variableSuffix: String?, propertiesToSkipDeepProjection: Set<String> = emptySet()): List<Cypher> {
156164
// TODO just render fragments on valid types (Labels) by using cypher like this:
157165
// apoc.map.mergeList([
158166
// a{.name},
159167
// CASE WHEN a:Location THEN a { .foo } ELSE {} END
160168
// ])
161169
var hasTypeName = false
162-
val projections = selectionSet.selections.flatMapTo(mutableListOf<Cypher>()) {
170+
val projections = selection.flatMapTo(mutableListOf<Cypher>()) {
163171
when (it) {
164172
is Field -> {
165173
hasTypeName = hasTypeName || (it.name == TYPE_NAME)
166-
listOf(projectField(variable, it, nodeType, env, variableSuffix))
174+
listOf(projectField(variable, it, nodeType, env, variableSuffix, propertiesToSkipDeepProjection))
167175
}
168176
is InlineFragment -> projectInlineFragment(variable, it, env, variableSuffix)
169177
is FragmentSpread -> projectNamedFragments(variable, it, env, variableSuffix)
@@ -180,7 +188,7 @@ open class ProjectionBase {
180188
return projections
181189
}
182190

183-
private fun projectField(variable: String, field: Field, type: GraphQLFieldsContainer, env: DataFetchingEnvironment, variableSuffix: String?): Cypher {
191+
private fun projectField(variable: String, field: Field, type: GraphQLFieldsContainer, env: DataFetchingEnvironment, variableSuffix: String?, propertiesToSkipDeepProjection: Set<String> = emptySet()): Cypher {
184192
if (field.name == TYPE_NAME) {
185193
return if (type.isRelationType()) {
186194
Cypher("${field.aliasOrName()}: '${type.name}'")
@@ -206,7 +214,14 @@ open class ProjectionBase {
206214
} ?: when {
207215
isObjectField -> {
208216
val patternComprehensions = if (fieldDefinition.isNeo4jType()) {
209-
projectNeo4jObjectType(variable, field)
217+
if (propertiesToSkipDeepProjection.contains(fieldDefinition.innerName())) {
218+
// if the property has an internal type like Date or DateTime and we want to compute on this
219+
// type (e.g sorting), we need to pass out the whole property and do the concrete projection
220+
// after the outer computation is done
221+
Cypher(variable + "." + fieldDefinition.propertyName().quote())
222+
} else {
223+
projectNeo4jObjectType(variable, field)
224+
}
210225
} else {
211226
projectRelationship(variable, field, fieldDefinition, type, env, variableSuffix)
212227
}
@@ -230,7 +245,7 @@ open class ProjectionBase {
230245
.filterIsInstance<Field>()
231246
.map {
232247
val value = when (it.name) {
233-
NEO4j_FORMATTED_PROPERTY_KEY -> "$variable.${field.name}"
248+
NEO4j_FORMATTED_PROPERTY_KEY -> "toString($variable.${field.name})"
234249
else -> "$variable.${field.name}.${it.name}"
235250
}
236251
"${it.name}: $value"
@@ -266,7 +281,7 @@ open class ProjectionBase {
266281
val fragmentType = env.graphQLSchema.getType(fragmentTypeName) as? GraphQLFieldsContainer ?: return emptyList()
267282
// these are the nested fields of the fragment
268283
// it could be that we have to adapt the variable name too, and perhaps add some kind of rename
269-
return projectSelectionSet(variable, selectionSet, fragmentType, env, variableSuffix)
284+
return projectSelection(variable, selectionSet.selections, fragmentType, env, variableSuffix)
270285
}
271286

272287

@@ -336,9 +351,27 @@ open class ProjectionBase {
336351
val relPattern = if (isRelFromType) "$childVariable:${relInfo.relType}" else ":${relInfo.relType}"
337352

338353
val where = where(childVariable, fieldDefinition, nodeType, propertyArguments(field), field)
339-
val fieldProjection = projectFields(childVariable, field, nodeType, env, variableSuffix)
340354

341-
val comprehension = "[($variable)$inArrow-[$relPattern]-$outArrow($endNodePattern)${where.query} | ${fieldProjection.query}]"
355+
val orderBy = getOrderByArgs(field.arguments)
356+
val sortByNeo4jTypeFields = orderBy
357+
.filter { (property, _) -> nodeType.getFieldDefinition(property)?.isNeo4jType() == true }
358+
.map { (property, _) -> property }
359+
.toSet()
360+
361+
val fieldProjection = projectFields(childVariable, field, nodeType, env, variableSuffix, sortByNeo4jTypeFields)
362+
var comprehension = "[($variable)$inArrow-[$relPattern]-$outArrow($endNodePattern)${where.query} | ${fieldProjection.query}]"
363+
if (orderBy.isNotEmpty()) {
364+
val sortArgs = orderBy.joinToString(", ", transform = { (property, direction) -> if (direction == Sort.ASC) "'^$property'" else "'$property'" })
365+
comprehension = "apoc.coll.sortMulti($comprehension, [$sortArgs])"
366+
if (sortByNeo4jTypeFields.isNotEmpty()) {
367+
val neo4jFieldSelection = field.selectionSet.selections
368+
.filter { selection -> sortByNeo4jTypeFields.contains((selection as? Field)?.name) }
369+
val deferredProjection = projectSelection("sortedElement", neo4jFieldSelection, nodeType, env, variableSuffix)
370+
.map { cypher -> cypher.query }
371+
.joinNonEmpty(", ")
372+
comprehension = "[sortedElement IN $comprehension | sortedElement { .*, $deferredProjection }]"
373+
}
374+
}
342375
val skipLimit = SkipLimit(childVariable, field.arguments)
343376
val slice = skipLimit.slice(fieldType.isList())
344377
return Cypher(comprehension + slice.query, (where.params + fieldProjection.params + slice.params))
@@ -392,4 +425,9 @@ open class ProjectionBase {
392425
}
393426
}
394427
}
428+
429+
enum class Sort {
430+
ASC,
431+
DESC
432+
}
395433
}

0 commit comments

Comments
 (0)