Skip to content

Extract aggregate query fields for nested associations into dedicated by field #490

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,11 @@
import graphql.schema.DataFetchingEnvironment;
import graphql.schema.GraphQLScalarType;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Predicate;
import java.util.stream.Stream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand Down Expand Up @@ -176,54 +176,58 @@ public PagedResult<Object> get(DataFetchingEnvironment environment) {
aggregate.put(getAliasOrName(groupField), resultList);
});

aggregateField
.getSelectionSet()
.getSelections()
.stream()
.filter(Field.class::isInstance)
.map(Field.class::cast)
.filter(it -> !Arrays.asList("count", "group").contains(it.getName()))
.forEach(groupField -> {
var countField = getFields(groupField.getSelectionSet(), "count")
.stream()
.findFirst()
.orElseThrow(() -> new GraphQLException("Missing aggregate count for group: " + groupField));

Map.Entry<String, String>[] groupings = getFields(groupField.getSelectionSet(), "by")
.stream()
.map(GraphQLJpaQueryDataFetcher::groupByFieldEntry)
.toArray(Map.Entry[]::new);

if (groupings.length == 0) {
throw new GraphQLException("At least one field is required for aggregate group: " + groupField);
}

var resultList = queryFactory
.queryAggregateGroupByAssociationCount(
getAliasOrName(countField),
groupField.getName(),
environment,
restrictedKeys,
groupings
)
.stream()
.peek(map ->
Stream
.of(groupings)
.forEach(group -> {
var value = map.get(group.getKey());

Optional
.ofNullable(value)
.map(Object::getClass)
.map(JavaScalars::of)
.map(GraphQLScalarType::getCoercing)
.ifPresent(coercing -> map.put(group.getKey(), coercing.serialize(value)));
})
)
.toList();

aggregate.put(getAliasOrName(groupField), resultList);
getSelectionField(aggregateField, "by")
.map(byField -> byField.getSelectionSet().getSelections().stream().map(Field.class::cast).toList())
.filter(Predicate.not(List::isEmpty))
.ifPresent(aggregateBySelections -> {
var aggregatesBy = new LinkedHashMap<>();
aggregate.put("by", aggregatesBy);

aggregateBySelections.forEach(groupField -> {
var countField = getFields(groupField.getSelectionSet(), "count")
.stream()
.findFirst()
.orElseThrow(() -> new GraphQLException("Missing aggregate count for group: " + groupField)
);

Map.Entry<String, String>[] groupings = getFields(groupField.getSelectionSet(), "by")
.stream()
.map(GraphQLJpaQueryDataFetcher::groupByFieldEntry)
.toArray(Map.Entry[]::new);

if (groupings.length == 0) {
throw new GraphQLException(
"At least one field is required for aggregate group: " + groupField
);
}

var resultList = queryFactory
.queryAggregateGroupByAssociationCount(
getAliasOrName(countField),
groupField.getName(),
environment,
restrictedKeys,
groupings
)
.stream()
.peek(map ->
Stream
.of(groupings)
.forEach(group -> {
var value = map.get(group.getKey());

Optional
.ofNullable(value)
.map(Object::getClass)
.map(JavaScalars::of)
.map(GraphQLScalarType::getCoercing)
.ifPresent(coercing -> map.put(group.getKey(), coercing.serialize(value)));
})
)
.toList();

aggregatesBy.put(getAliasOrName(groupField), resultList);
});
});

pagedResult.withAggregate(aggregate);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -525,6 +525,10 @@ private GraphQLFieldDefinition getAggregateFieldDefinition(EntityType<?> entityT
)
);

var aggregateByObjectType = newObject()
.name(selectTypeName.concat("AggregateBy"))
.description("%s aggregate query type groups by nested associations".formatted(selectTypeName));

entityType
.getAttributes()
.stream()
Expand All @@ -544,11 +548,11 @@ private GraphQLFieldDefinition getAggregateFieldDefinition(EntityType<?> entityT
.toList();

if (!fields.isEmpty()) {
aggregateObjectType.field(
aggregateByObjectType.field(
newFieldDefinition()
.name(association.getName())
.description(
"Aggregate %s query field definition for the associated %s entity".formatted(
"Aggregate by %s query field definition for the associated %s entity".formatted(
selectTypeName,
association.getName()
)
Expand Down Expand Up @@ -617,6 +621,15 @@ private GraphQLFieldDefinition getAggregateFieldDefinition(EntityType<?> entityT

aggregateObjectType.field(countFieldDefinition).field(groupFieldDefinition);

if (!aggregateByObjectType.build().getFieldDefinitions().isEmpty()) {
aggregateObjectType.field(
newFieldDefinition()
.name("by")
.description("Nested aggregate by query field for %s entity".formatted(selectTypeName))
.type(aggregateByObjectType)
);
}

var aggregateFieldDefinition = newFieldDefinition()
.name("aggregate")
.description("Aggregate data query field for %s entity".formatted(selectTypeName))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -602,17 +602,19 @@ public void queryVariablesTaskNestedAggregateCountByNestedAssociation() {
# count by variables
variables: count
# Count by associated tasks
tasks: task {
by(field: status)
count
by {
tasks: task {
by(field: status)
count
}
}
}
}
}
""";

String expected =
"{TaskVariables={aggregate={variables=3, tasks=[{by=COMPLETED, count=1}, {by=CREATED, count=2}]}}}";
"{TaskVariables={aggregate={variables=3, by={tasks=[{by=COMPLETED, count=1}, {by=CREATED, count=2}]}}}}";

//when
ExecutionResult result = executor.execute(query);
Expand All @@ -636,17 +638,19 @@ public void queryVariablesTaskNestedAggregateCountByNestedAssociationAlias() {
# count by variables
variables: count
# Count by associated tasks
tasks: task {
status: by(field: status)
count
by {
tasks: task {
status: by(field: status)
count
}
}
}
}
}
""";

String expected =
"{TaskVariables={aggregate={variables=3, tasks=[{status=COMPLETED, count=1}, {status=CREATED, count=2}]}}}";
"{TaskVariables={aggregate={variables=3, by={tasks=[{status=COMPLETED, count=1}, {status=CREATED, count=2}]}}}}";

//when
ExecutionResult result = executor.execute(query);
Expand Down Expand Up @@ -674,22 +678,24 @@ public void queryVariablesTaskNestedAggregateCountByNestedAssociationMultipleAli
name: by(field: name)
count
}
groupByTaskStatus: task {
status: by(field: status)
count
}
# Count by associated tasks
groupByTaskAssignee: task {
assignee: by(field: assignee)
count
by {
groupByTaskStatus: task {
status: by(field: status)
count
}
# Count by associated tasks
groupByTaskAssignee: task {
assignee: by(field: assignee)
count
}
}
}
}
}
""";

String expected =
"{TaskVariables={aggregate={variables=3, groupByVariableName=[{name=variable1, count=1}, {name=variable5, count=2}], groupByTaskStatus=[{status=COMPLETED, count=1}, {status=CREATED, count=2}], groupByTaskAssignee=[{assignee=assignee, count=3}]}}}";
"{TaskVariables={aggregate={variables=3, groupByVariableName=[{name=variable1, count=1}, {name=variable5, count=2}], by={groupByTaskStatus=[{status=COMPLETED, count=1}, {status=CREATED, count=2}], groupByTaskAssignee=[{assignee=assignee, count=3}]}}}}";

//when
ExecutionResult result = executor.execute(query);
Expand Down
Loading