Fix wrong group by column definition in druid connector
The previous code could generate a wrong GROUP BY column name ending with a "_<number>" suffix, so the resulting DQL could not be executed on the Druid side.
weidongduan37 authored and zhenxiao committed Sep 16, 2020
1 parent af273e1 commit df8228f
Showing 3 changed files with 19 additions and 12 deletions.
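
Why the old Set-based bookkeeping breaks: toQuery() only knew the group-by output variables and fell back to expression.getName() whenever a variable was absent from selections, so a derived name like city_0 could leak into the generated DQL. The following is a minimal, self-contained sketch of the failure mode and of the map-based fix; Variable and Selection here are hypothetical stand-ins for the connector's VariableReferenceExpression and Selection types, not the real code (which appears in the diff below).

import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

// Illustration only: hypothetical stand-in types, not the connector's classes.
public class GroupByColumnSketch
{
    record Variable(String name) {}

    record Selection(String escapedDefinition) {}

    public static void main(String[] args)
    {
        // Presto may assign a grouping key a derived output variable such as
        // "city_0" (the "_<number>" suffix from the commit message), while the
        // underlying Druid column is "city".
        Variable output = new Variable("city_0");
        Selection druidColumn = new Selection("\"city\"");

        // Before: only the output variable was remembered. If it is missing
        // from the selections map, rendering falls back to the variable name.
        Set<Variable> oldGroupByColumns = new LinkedHashSet<>();
        oldGroupByColumns.add(output);
        Map<Variable, Selection> selections = new LinkedHashMap<>(); // "city_0" not registered
        String broken = oldGroupByColumns.stream()
                .map(v -> selections.containsKey(v) ? selections.get(v).escapedDefinition() : v.name())
                .collect(Collectors.joining(", "));
        System.out.println("GROUP BY " + broken); // GROUP BY city_0 -- no such Druid column

        // After: the Druid column definition travels with each group-by key,
        // so rendering never falls back to the derived variable name.
        Map<Variable, Selection> newGroupByColumns = new LinkedHashMap<>();
        newGroupByColumns.put(output, druidColumn);
        String fixed = newGroupByColumns.values().stream()
                .map(Selection::escapedDefinition)
                .collect(Collectors.joining(", "));
        System.out.println("GROUP BY " + fixed); // GROUP BY "city"
    }
}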
@@ -40,7 +40,6 @@

 import java.util.HashSet;
 import java.util.LinkedHashMap;
-import java.util.LinkedHashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Optional;
@@ -300,7 +299,7 @@ public DruidQueryGeneratorContext visitAggregation(AggregationNode node, DruidQu

         // 2nd pass
         Map<VariableReferenceExpression, Selection> newSelections = new LinkedHashMap<>();
-        Set<VariableReferenceExpression> groupByColumns = new LinkedHashSet<>();
+        Map<VariableReferenceExpression, Selection> groupByColumns = new LinkedHashMap<>();
         Set<VariableReferenceExpression> hiddenColumnSet = new HashSet<>(context.getHiddenColumnSet());
         int aggregations = 0;
         boolean groupByExists = false;
@@ -314,7 +313,7 @@ public DruidQueryGeneratorContext visitAggregation(AggregationNode node, DruidQu
                 Selection druidColumn = requireNonNull(context.getSelections().get(groupByInputColumn), "Group By column " + groupByInputColumn + " doesn't exist in input " + context.getSelections());

                 newSelections.put(outputColumn, new Selection(druidColumn.getDefinition(), druidColumn.getOrigin()));
-                groupByColumns.add(outputColumn);
+                groupByColumns.put(outputColumn, new Selection(druidColumn.getDefinition(), druidColumn.getOrigin()));
                 groupByExists = true;
                 break;
             }
@@ -21,7 +21,6 @@

 import java.util.HashSet;
 import java.util.LinkedHashMap;
-import java.util.LinkedHashSet;
 import java.util.List;
 import java.util.Locale;
 import java.util.Map;
@@ -41,7 +40,7 @@
 public class DruidQueryGeneratorContext
 {
     private final Map<VariableReferenceExpression, Selection> selections;
-    private final Set<VariableReferenceExpression> groupByColumns;
+    private final Map<VariableReferenceExpression, Selection> groupByColumns;
     private final Set<VariableReferenceExpression> hiddenColumnSet;
     private final Set<VariableReferenceExpression> variablesInAggregation;
     private final Optional<String> from;
@@ -82,7 +81,7 @@ public String toString()
                 Optional.empty(),
                 OptionalLong.empty(),
                 0,
-                new HashSet<>(),
+                new LinkedHashMap<>(),
                 new HashSet<>(),
                 new HashSet<>(),
                 Optional.ofNullable(planNodeId));
@@ -94,7 +93,7 @@ private DruidQueryGeneratorContext(
             Optional<String> filter,
             OptionalLong limit,
             int aggregations,
-            Set<VariableReferenceExpression> groupByColumns,
+            Map<VariableReferenceExpression, Selection> groupByColumns,
             Set<VariableReferenceExpression> variablesInAggregation,
             Set<VariableReferenceExpression> hiddenColumnSet,
             Optional<PlanNodeId> tableScanNodeId)
@@ -104,7 +103,7 @@ private DruidQueryGeneratorContext(
         this.filter = requireNonNull(filter, "filter is null");
         this.limit = requireNonNull(limit, "limit is null");
         this.aggregations = aggregations;
-        this.groupByColumns = new LinkedHashSet<>(requireNonNull(groupByColumns, "groupByColumns can't be null. It could be empty if not available"));
+        this.groupByColumns = new LinkedHashMap<>(requireNonNull(groupByColumns, "groupByColumns can't be null. It could be empty if not available"));
         this.hiddenColumnSet = requireNonNull(hiddenColumnSet, "hidden column set is null");
         this.variablesInAggregation = requireNonNull(variablesInAggregation, "variables in aggregation is null");
         this.tableScanNodeId = requireNonNull(tableScanNodeId, "tableScanNodeId can't be null");
@@ -162,7 +161,7 @@ public DruidQueryGeneratorContext withLimit(long limit)

     public DruidQueryGeneratorContext withAggregation(
             Map<VariableReferenceExpression, Selection> newSelections,
-            Set<VariableReferenceExpression> newGroupByColumns,
+            Map<VariableReferenceExpression, Selection> newGroupByColumns,
             int newAggregations,
             Set<VariableReferenceExpression> newHiddenColumnSet)
     {
@@ -287,9 +286,7 @@ public DruidQueryGenerator.GeneratedDql toQuery()
         }

         if (!groupByColumns.isEmpty()) {
-            String groupByExpression = groupByColumns.stream()
-                    .map(expression -> selections.containsKey(expression) ? selections.get(expression).getEscapedDefinition() : expression.getName())
-                    .collect(Collectors.joining(", "));
+            String groupByExpression = groupByColumns.entrySet().stream().map(v -> v.getValue().getEscapedDefinition()).collect(Collectors.joining(", "));
             query = query + " GROUP BY " + groupByExpression;
             pushdown = true;
         }
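
A side note on the replacement line: only the map's values are consumed, so an equivalent and slightly more direct form would stream values() instead of entrySet(). A hedged sketch, reusing the Selection type and Collectors import already present in this file:

// Equivalent rendering of the new GROUP BY expression above; behavior is
// unchanged, since LinkedHashMap preserves the grouping keys' insertion order.
String groupByExpression = groupByColumns.values().stream()
        .map(Selection::getEscapedDefinition)
        .collect(Collectors.joining(", "));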
@@ -159,6 +159,17 @@ public void testDistinctCountPushdown()
                 "SELECT count ( distinct \"region.Id\") FROM \"realtimeOnly\"");
     }

+    @Test
+    public void testGroupByPushdown()
+    {
+        PlanNode justScan = buildPlan(planBuilder -> tableScan(planBuilder, druidTable, regionId, secondsSinceEpoch, city, fare));
+        testDQL(
+                planBuilder -> planBuilder.aggregation(
+                        aggBuilder -> aggBuilder.source(justScan).singleGroupingSet(variable("city"), variable("region.id"), variable("secondssinceepoch"))
+                                .addAggregation(variable("totalfare"), getRowExpression("sum(\"fare\")", defaultSessionHolder))),
+                "SELECT \"city\", \"region.Id\", \"secondsSinceEpoch\", sum(fare) FROM \"realtimeOnly\" GROUP BY \"city\", \"region.Id\", \"secondsSinceEpoch\"");
+    }
+
     @Test
     public void testDistinctCountGroupByPushdown()
     {
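The added test pins down the renaming case the fix targets: the plan's grouping variables are lower-cased output names, variable("region.id") and variable("secondssinceepoch"), while the expected DQL groups by the escaped Druid definitions "region.Id" and "secondsSinceEpoch". With the old Set-based bookkeeping, rendering could fall back to those output names and produce columns Druid does not know.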
