Skip to content

Commit 4c36f74

Browse files
authored
Extract aggregate query fields for nested associations into dedicated by field (#490)
1 parent 56c5872 commit 4c36f74

File tree

3 files changed

+91
-68
lines changed

3 files changed

+91
-68
lines changed

schema/src/main/java/com/introproventures/graphql/jpa/query/schema/impl/GraphQLJpaQueryDataFetcher.java

Lines changed: 53 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,11 @@
3535
import graphql.schema.DataFetchingEnvironment;
3636
import graphql.schema.GraphQLScalarType;
3737
import java.util.ArrayList;
38-
import java.util.Arrays;
3938
import java.util.LinkedHashMap;
4039
import java.util.List;
4140
import java.util.Map;
4241
import java.util.Optional;
42+
import java.util.function.Predicate;
4343
import java.util.stream.Stream;
4444
import org.slf4j.Logger;
4545
import org.slf4j.LoggerFactory;
@@ -176,54 +176,58 @@ public PagedResult<Object> get(DataFetchingEnvironment environment) {
176176
aggregate.put(getAliasOrName(groupField), resultList);
177177
});
178178

179-
aggregateField
180-
.getSelectionSet()
181-
.getSelections()
182-
.stream()
183-
.filter(Field.class::isInstance)
184-
.map(Field.class::cast)
185-
.filter(it -> !Arrays.asList("count", "group").contains(it.getName()))
186-
.forEach(groupField -> {
187-
var countField = getFields(groupField.getSelectionSet(), "count")
188-
.stream()
189-
.findFirst()
190-
.orElseThrow(() -> new GraphQLException("Missing aggregate count for group: " + groupField));
191-
192-
Map.Entry<String, String>[] groupings = getFields(groupField.getSelectionSet(), "by")
193-
.stream()
194-
.map(GraphQLJpaQueryDataFetcher::groupByFieldEntry)
195-
.toArray(Map.Entry[]::new);
196-
197-
if (groupings.length == 0) {
198-
throw new GraphQLException("At least one field is required for aggregate group: " + groupField);
199-
}
200-
201-
var resultList = queryFactory
202-
.queryAggregateGroupByAssociationCount(
203-
getAliasOrName(countField),
204-
groupField.getName(),
205-
environment,
206-
restrictedKeys,
207-
groupings
208-
)
209-
.stream()
210-
.peek(map ->
211-
Stream
212-
.of(groupings)
213-
.forEach(group -> {
214-
var value = map.get(group.getKey());
215-
216-
Optional
217-
.ofNullable(value)
218-
.map(Object::getClass)
219-
.map(JavaScalars::of)
220-
.map(GraphQLScalarType::getCoercing)
221-
.ifPresent(coercing -> map.put(group.getKey(), coercing.serialize(value)));
222-
})
223-
)
224-
.toList();
225-
226-
aggregate.put(getAliasOrName(groupField), resultList);
179+
getSelectionField(aggregateField, "by")
180+
.map(byField -> byField.getSelectionSet().getSelections().stream().map(Field.class::cast).toList())
181+
.filter(Predicate.not(List::isEmpty))
182+
.ifPresent(aggregateBySelections -> {
183+
var aggregatesBy = new LinkedHashMap<>();
184+
aggregate.put("by", aggregatesBy);
185+
186+
aggregateBySelections.forEach(groupField -> {
187+
var countField = getFields(groupField.getSelectionSet(), "count")
188+
.stream()
189+
.findFirst()
190+
.orElseThrow(() -> new GraphQLException("Missing aggregate count for group: " + groupField)
191+
);
192+
193+
Map.Entry<String, String>[] groupings = getFields(groupField.getSelectionSet(), "by")
194+
.stream()
195+
.map(GraphQLJpaQueryDataFetcher::groupByFieldEntry)
196+
.toArray(Map.Entry[]::new);
197+
198+
if (groupings.length == 0) {
199+
throw new GraphQLException(
200+
"At least one field is required for aggregate group: " + groupField
201+
);
202+
}
203+
204+
var resultList = queryFactory
205+
.queryAggregateGroupByAssociationCount(
206+
getAliasOrName(countField),
207+
groupField.getName(),
208+
environment,
209+
restrictedKeys,
210+
groupings
211+
)
212+
.stream()
213+
.peek(map ->
214+
Stream
215+
.of(groupings)
216+
.forEach(group -> {
217+
var value = map.get(group.getKey());
218+
219+
Optional
220+
.ofNullable(value)
221+
.map(Object::getClass)
222+
.map(JavaScalars::of)
223+
.map(GraphQLScalarType::getCoercing)
224+
.ifPresent(coercing -> map.put(group.getKey(), coercing.serialize(value)));
225+
})
226+
)
227+
.toList();
228+
229+
aggregatesBy.put(getAliasOrName(groupField), resultList);
230+
});
227231
});
228232

229233
pagedResult.withAggregate(aggregate);

schema/src/main/java/com/introproventures/graphql/jpa/query/schema/impl/GraphQLJpaSchemaBuilder.java

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -525,6 +525,10 @@ private GraphQLFieldDefinition getAggregateFieldDefinition(EntityType<?> entityT
525525
)
526526
);
527527

528+
var aggregateByObjectType = newObject()
529+
.name(selectTypeName.concat("AggregateBy"))
530+
.description("%s aggregate query type groups by nested associations".formatted(selectTypeName));
531+
528532
entityType
529533
.getAttributes()
530534
.stream()
@@ -544,11 +548,11 @@ private GraphQLFieldDefinition getAggregateFieldDefinition(EntityType<?> entityT
544548
.toList();
545549

546550
if (!fields.isEmpty()) {
547-
aggregateObjectType.field(
551+
aggregateByObjectType.field(
548552
newFieldDefinition()
549553
.name(association.getName())
550554
.description(
551-
"Aggregate %s query field definition for the associated %s entity".formatted(
555+
"Aggregate by %s query field definition for the associated %s entity".formatted(
552556
selectTypeName,
553557
association.getName()
554558
)
@@ -617,6 +621,15 @@ private GraphQLFieldDefinition getAggregateFieldDefinition(EntityType<?> entityT
617621

618622
aggregateObjectType.field(countFieldDefinition).field(groupFieldDefinition);
619623

624+
if (!aggregateByObjectType.build().getFieldDefinitions().isEmpty()) {
625+
aggregateObjectType.field(
626+
newFieldDefinition()
627+
.name("by")
628+
.description("Nested aggregate by query field for %s entity".formatted(selectTypeName))
629+
.type(aggregateByObjectType)
630+
);
631+
}
632+
620633
var aggregateFieldDefinition = newFieldDefinition()
621634
.name("aggregate")
622635
.description("Aggregate data query field for %s entity".formatted(selectTypeName))

schema/src/test/java/com/introproventures/graphql/jpa/query/converter/GraphQLJpaQueryAggregateTests.java

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -602,17 +602,19 @@ public void queryVariablesTaskNestedAggregateCountByNestedAssociation() {
602602
# count by variables
603603
variables: count
604604
# Count by associated tasks
605-
tasks: task {
606-
by(field: status)
607-
count
605+
by {
606+
tasks: task {
607+
by(field: status)
608+
count
609+
}
608610
}
609611
}
610612
}
611613
}
612614
""";
613615

614616
String expected =
615-
"{TaskVariables={aggregate={variables=3, tasks=[{by=COMPLETED, count=1}, {by=CREATED, count=2}]}}}";
617+
"{TaskVariables={aggregate={variables=3, by={tasks=[{by=COMPLETED, count=1}, {by=CREATED, count=2}]}}}}";
616618

617619
//when
618620
ExecutionResult result = executor.execute(query);
@@ -636,17 +638,19 @@ public void queryVariablesTaskNestedAggregateCountByNestedAssociationAlias() {
636638
# count by variables
637639
variables: count
638640
# Count by associated tasks
639-
tasks: task {
640-
status: by(field: status)
641-
count
641+
by {
642+
tasks: task {
643+
status: by(field: status)
644+
count
645+
}
642646
}
643647
}
644648
}
645649
}
646650
""";
647651

648652
String expected =
649-
"{TaskVariables={aggregate={variables=3, tasks=[{status=COMPLETED, count=1}, {status=CREATED, count=2}]}}}";
653+
"{TaskVariables={aggregate={variables=3, by={tasks=[{status=COMPLETED, count=1}, {status=CREATED, count=2}]}}}}";
650654

651655
//when
652656
ExecutionResult result = executor.execute(query);
@@ -674,22 +678,24 @@ public void queryVariablesTaskNestedAggregateCountByNestedAssociationMultipleAli
674678
name: by(field: name)
675679
count
676680
}
677-
groupByTaskStatus: task {
678-
status: by(field: status)
679-
count
680-
}
681-
# Count by associated tasks
682-
groupByTaskAssignee: task {
683-
assignee: by(field: assignee)
684-
count
681+
by {
682+
groupByTaskStatus: task {
683+
status: by(field: status)
684+
count
685+
}
686+
# Count by associated tasks
687+
groupByTaskAssignee: task {
688+
assignee: by(field: assignee)
689+
count
690+
}
685691
}
686692
}
687693
}
688694
}
689695
""";
690696

691697
String expected =
692-
"{TaskVariables={aggregate={variables=3, groupByVariableName=[{name=variable1, count=1}, {name=variable5, count=2}], groupByTaskStatus=[{status=COMPLETED, count=1}, {status=CREATED, count=2}], groupByTaskAssignee=[{assignee=assignee, count=3}]}}}";
698+
"{TaskVariables={aggregate={variables=3, groupByVariableName=[{name=variable1, count=1}, {name=variable5, count=2}], by={groupByTaskStatus=[{status=COMPLETED, count=1}, {status=CREATED, count=2}], groupByTaskAssignee=[{assignee=assignee, count=3}]}}}}";
693699

694700
//when
695701
ExecutionResult result = executor.execute(query);

0 commit comments

Comments
 (0)