Merged
@@ -18,6 +18,12 @@
*/
package org.apache.pinot.calcite.rel.rules;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import java.util.ArrayList;
import java.util.EnumSet;
import java.util.List;
import javax.annotation.Nullable;
import org.apache.calcite.plan.Contexts;
import org.apache.calcite.plan.hep.HepRelVertex;
import org.apache.calcite.rel.RelNode;
@@ -28,7 +34,14 @@
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.core.RelFactories;
import org.apache.calcite.rel.core.TableScan;
import org.apache.calcite.rel.core.Window;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexWindowBound;
import org.apache.calcite.rex.RexWindowBounds;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql2rel.SqlToRelConverter;
import org.apache.calcite.tools.RelBuilder;
@@ -122,4 +135,129 @@ public static String extractFunctionName(RexCall function) {
SqlKind funcSqlKind = function.getOperator().getKind();
return funcSqlKind == SqlKind.OTHER_FUNCTION ? function.getOperator().getName() : funcSqlKind.name();
}

public static class WindowUtils {
Contributor Author

Note: lifted as-is from the window exchange rule.

// Supported window functions
// OTHER_FUNCTION supported are: BOOL_AND, BOOL_OR
private static final EnumSet<SqlKind> SUPPORTED_WINDOW_FUNCTION_KIND =
EnumSet.of(SqlKind.SUM, SqlKind.SUM0, SqlKind.MIN, SqlKind.MAX, SqlKind.COUNT, SqlKind.ROW_NUMBER, SqlKind.RANK,
SqlKind.DENSE_RANK, SqlKind.NTILE, SqlKind.LAG, SqlKind.LEAD, SqlKind.FIRST_VALUE, SqlKind.LAST_VALUE,
SqlKind.OTHER_FUNCTION);

public static void validateWindows(Window window) {
int numGroups = window.groups.size();
// For Phase 1 we only handle single window groups
Preconditions.checkState(numGroups == 1,
String.format("Currently only 1 window group is supported, query has %d groups", numGroups));

// Validate that only supported window aggregation functions are present
Window.Group windowGroup = window.groups.get(0);
validateWindowAggCallsSupported(windowGroup);

// Validate the frame
validateWindowFrames(windowGroup);
}

/**
* Replaces the reference to literal arguments in the window group with the actual literal values.
* NOTE: {@link Window} has a field called "constants" which contains the literal values. If the input reference is
* beyond the window input size, it is a reference to the constants.
*/
public static Window.Group updateLiteralArgumentsInWindowGroup(Window window) {
Contributor

Could you add more detailed comments on how RexInputRef(x) finds its element in the constants list? Here is what I got from Copilot:

1. Understanding the Inputs

When Calcite constructs a Window relational node, it conceptually sees the input as:

  • All columns from the input relation (e.g., your table columns)
  • Followed by any literal constants needed by window functions (these are placed in the constants list in the Window node)

So, if your input row has N columns, and your window function needs K constants, the total "input" to the Window node is N + K.


2. How RexInputRef Works in This Context

  • RexInputRef(i) refers to the i-th field in the Window node's input.
  • If i < N, it points to the i-th field of the original input row (e.g., a table column).
  • If i >= N, it points to the (i - N)-th entry in the constants list.

Example:

Suppose your table has 2 columns: employee_id, salary.
Suppose your window function is:

SQL
LAG(salary, 2, 0) OVER (PARTITION BY dept ORDER BY hire_date)
  • The constants needed are 2 and 0, so constants = [2, 0].

Input to the Window node:

| employee_id | salary | 2 | 0 |
| -- | -- | -- | -- |
| ... | ... | 2 | 0 |

So:

  • RexInputRef(0) → employee_id
  • RexInputRef(1) → salary
  • RexInputRef(2) → constants[0] = 2
  • RexInputRef(3) → constants[1] = 0

Contributor Author

This is lifted as-is from PinotWindowExchangeNodeInsertRule. The GPT response is largely right: a RexInputRef is an integer index into the input, but within a window group the index can exceed the input's field count, in which case it actually refers to an entry in window.constants.
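The index arithmetic discussed in this thread can be sketched without Calcite. The class below is a hypothetical, standalone model of what `getLiteral` does: plain strings stand in for `RexLiteral`, and `projectLiterals` is an assumed simplification of the `Project` expressions (an entry is non-null only where the projected expression is itself a literal). None of these names exist in Pinot or Calcite.

```java
import java.util.List;

/** Standalone model of the lookup; strings stand in for Calcite's RexLiteral. */
final class WindowConstantLookup {
  private WindowConstantLookup() {
  }

  /**
   * Returns the literal that an input reference resolves to, or null when it
   * resolves to an ordinary (non-literal) input column.
   *
   * @param index           integer carried by the input reference
   * @param numInputFields  number of fields produced by the window's input
   * @param constants       stand-in for window.constants
   * @param projectLiterals per input field, the literal it projects or null;
   *                        the list itself is null when the input is not a Project
   */
  static String resolveLiteral(int index, int numInputFields, List<String> constants,
      List<String> projectLiterals) {
    // Indexes past the input fields address the constants list.
    if (index >= numInputFields) {
      return constants.get(index - numInputFields);
    }
    // Otherwise the reference may still point at a literal produced by a Project.
    if (projectLiterals != null) {
      return projectLiterals.get(index);
    }
    return null;
  }
}
```

With the two-column input and `constants = [2, 0]` from the example above, index 2 resolves to the first constant, index 3 to the second, and index 1 resolves to no literal because it is a regular column.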

Window.Group oldWindowGroup = window.groups.get(0);
RelNode input = unboxRel(window.getInput());
int numInputFields = input.getRowType().getFieldCount();
List<RexNode> projects = input instanceof Project ? ((Project) input).getProjects() : null;

List<Window.RexWinAggCall> newAggCallWindow = new ArrayList<>(oldWindowGroup.aggCalls.size());
boolean windowChanged = false;
for (Window.RexWinAggCall oldAggCall : oldWindowGroup.aggCalls) {
boolean changed = false;
List<RexNode> oldOperands = oldAggCall.getOperands();
List<RexNode> newOperands = new ArrayList<>(oldOperands.size());
for (RexNode oldOperand : oldOperands) {
RexLiteral literal = getLiteral(oldOperand, numInputFields, window.constants, projects);
if (literal != null) {
newOperands.add(literal);
changed = true;
windowChanged = true;
} else {
newOperands.add(oldOperand);
}
}
if (changed) {
newAggCallWindow.add(
new Window.RexWinAggCall((SqlAggFunction) oldAggCall.getOperator(), oldAggCall.type, newOperands,
oldAggCall.ordinal, oldAggCall.distinct, oldAggCall.ignoreNulls));
} else {
newAggCallWindow.add(oldAggCall);
}
}

RexWindowBound lowerBound = oldWindowGroup.lowerBound;
RexNode offset = lowerBound.getOffset();
if (offset != null) {
RexLiteral literal = getLiteral(offset, numInputFields, window.constants, projects);
if (literal == null) {
throw new IllegalStateException(
"Could not read window lower bound literal value from window group: " + oldWindowGroup);
}
lowerBound = lowerBound.isPreceding() ? RexWindowBounds.preceding(literal) : RexWindowBounds.following(literal);
windowChanged = true;
}
RexWindowBound upperBound = oldWindowGroup.upperBound;
offset = upperBound.getOffset();
if (offset != null) {
Contributor

Please add a comment on when the offset is null (unbounded preceding / unbounded following / current row).

Contributor Author

Lifted as-is from PinotWindowExchangeNodeInsertRule. I suppose the offset is non-null when functions like LEAD and LAG are used.
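One possible reading of the offset question, sketched under stated assumptions: the offset appears to belong to the frame bound itself (`N PRECEDING` / `N FOLLOWING`), so it would be null for UNBOUNDED PRECEDING, UNBOUNDED FOLLOWING, and CURRENT ROW regardless of which window function is called. The class below is a hypothetical stand-in for a frame bound, not Calcite's `RexWindowBound`; its names and factory methods are invented for illustration.

```java
/** Hypothetical model of a window frame bound; not a Calcite class. */
final class FrameBoundSketch {
  enum Kind { UNBOUNDED_PRECEDING, PRECEDING, CURRENT_ROW, FOLLOWING, UNBOUNDED_FOLLOWING }

  private final Kind kind;
  private final Integer offset; // null unless kind is PRECEDING or FOLLOWING

  private FrameBoundSketch(Kind kind, Integer offset) {
    this.kind = kind;
    this.offset = offset;
  }

  static FrameBoundSketch unboundedPreceding() {
    return new FrameBoundSketch(Kind.UNBOUNDED_PRECEDING, null);
  }

  static FrameBoundSketch preceding(int rows) {
    return new FrameBoundSketch(Kind.PRECEDING, rows);
  }

  static FrameBoundSketch currentRow() {
    return new FrameBoundSketch(Kind.CURRENT_ROW, null);
  }

  static FrameBoundSketch following(int rows) {
    return new FrameBoundSketch(Kind.FOLLOWING, rows);
  }

  static FrameBoundSketch unboundedFollowing() {
    return new FrameBoundSketch(Kind.UNBOUNDED_FOLLOWING, null);
  }

  Kind getKind() {
    return kind;
  }

  /** Null for unbounded and current-row bounds; the row count otherwise. */
  Integer getOffset() {
    return offset;
  }

  /** Only a bound that carries an offset has a constant reference to replace. */
  boolean needsLiteralRewrite() {
    return offset != null;
  }
}
```

Under this model, a frame such as `ROWS BETWEEN 2 PRECEDING AND CURRENT ROW` has an offset only on its lower bound, so only that bound would go through the literal replacement.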

RexLiteral literal = getLiteral(offset, numInputFields, window.constants, projects);
if (literal == null) {
throw new IllegalStateException(
"Could not read window upper bound literal value from window group: " + oldWindowGroup);
}
upperBound = upperBound.isFollowing() ? RexWindowBounds.following(literal) : RexWindowBounds.preceding(literal);
windowChanged = true;
}

return windowChanged ? new Window.Group(oldWindowGroup.keys, oldWindowGroup.isRows, lowerBound, upperBound,
oldWindowGroup.exclude, oldWindowGroup.orderKeys, newAggCallWindow) : oldWindowGroup;
}

private static void validateWindowAggCallsSupported(Window.Group windowGroup) {
for (Window.RexWinAggCall aggCall : windowGroup.aggCalls) {
SqlKind aggKind = aggCall.getKind();
Preconditions.checkState(SUPPORTED_WINDOW_FUNCTION_KIND.contains(aggKind),
String.format("Unsupported Window function kind: %s. Only aggregation functions are supported!", aggKind));
}
}

private static void validateWindowFrames(Window.Group windowGroup) {
RexWindowBound lowerBound = windowGroup.lowerBound;
RexWindowBound upperBound = windowGroup.upperBound;

boolean hasOffset = (lowerBound.isPreceding() && !lowerBound.isUnbounded()) || (upperBound.isFollowing()
&& !upperBound.isUnbounded());

if (!windowGroup.isRows) {
Preconditions.checkState(!hasOffset, "RANGE window frame with offset PRECEDING / FOLLOWING is not supported");
Contributor

https://docs.pinot.apache.org/windows-functions claims that the RANGE frame clause with an offset is supported. Am I missing something?

Contributor Author

image

Contributor

my bad.

}
}
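The frame check in `validateWindowFrames` can be modeled with plain booleans. This is a minimal sketch, assuming the four flags carry the same meaning as the `isPreceding()`, `isFollowing()`, and `isUnbounded()` calls in the method above; the class and method names are hypothetical.

```java
/** Boolean model of the RANGE/ROWS frame check; names are illustrative only. */
final class FrameValidationSketch {
  private FrameValidationSketch() {
  }

  /** Mirrors the hasOffset expression: a bounded PRECEDING lower bound or a bounded FOLLOWING upper bound. */
  static boolean hasOffset(boolean lowerIsPreceding, boolean lowerIsUnbounded, boolean upperIsFollowing,
      boolean upperIsUnbounded) {
    return (lowerIsPreceding && !lowerIsUnbounded) || (upperIsFollowing && !upperIsUnbounded);
  }

  /** ROWS frames always pass; RANGE frames pass only when no offset was detected. */
  static boolean isFrameAccepted(boolean isRows, boolean hasOffset) {
    return isRows || !hasOffset;
  }
}
```

For example, assuming UNBOUNDED PRECEDING reports both "preceding" and "unbounded", `RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW` yields no offset and is accepted, while `RANGE BETWEEN 1 PRECEDING AND CURRENT ROW` yields an offset and is rejected. The same offset under ROWS is accepted.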

@Nullable
private static RexLiteral getLiteral(RexNode rexNode, int numInputFields, ImmutableList<RexLiteral> constants,
@Nullable List<RexNode> projects) {
if (!(rexNode instanceof RexInputRef)) {
return null;
}
int index = ((RexInputRef) rexNode).getIndex();
if (index >= numInputFields) {
return constants.get(index - numInputFields);
}
if (projects != null) {
RexNode project = projects.get(index);
if (project instanceof RexLiteral) {
return (RexLiteral) project;
}
}
return null;
}
}
}
@@ -18,15 +18,12 @@
*/
package org.apache.pinot.calcite.rel.rules;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import java.util.ArrayList;
import java.util.Collections;
import java.util.EnumSet;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import javax.annotation.Nullable;
import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
@@ -42,11 +39,7 @@
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexWindowBound;
import org.apache.calcite.rex.RexWindowBounds;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.tools.RelBuilderFactory;
@@ -87,10 +80,10 @@ public boolean matches(RelOptRuleCall call) {
public void onMatch(RelOptRuleCall call) {
Window window = call.rel(0);
// Perform all validations
- validateWindows(window);
+ PinotRuleUtils.WindowUtils.validateWindows(window);

RelNode input = window.getInput();
- Window.Group windowGroup = updateLiteralArgumentsInWindowGroup(window);
+ Window.Group windowGroup = PinotRuleUtils.WindowUtils.updateLiteralArgumentsInWindowGroup(window);
Exchange exchange;
if (windowGroup.keys.isEmpty()) {
// Empty OVER()
@@ -147,122 +140,6 @@ public void onMatch(RelOptRuleCall call) {
List.of(windowGroup)));
}

/**
* Replaces the reference to literal arguments in the window group with the actual literal values.
* NOTE: {@link Window} has a field called "constants" which contains the literal values. If the input reference is
* beyond the window input size, it is a reference to the constants.
*/
private Window.Group updateLiteralArgumentsInWindowGroup(Window window) {
Window.Group oldWindowGroup = window.groups.get(0);
RelNode input = ((HepRelVertex) window.getInput()).getCurrentRel();
int numInputFields = input.getRowType().getFieldCount();
List<RexNode> projects = input instanceof Project ? ((Project) input).getProjects() : null;

List<Window.RexWinAggCall> newAggCallWindow = new ArrayList<>(oldWindowGroup.aggCalls.size());
boolean windowChanged = false;
for (Window.RexWinAggCall oldAggCall : oldWindowGroup.aggCalls) {
boolean changed = false;
List<RexNode> oldOperands = oldAggCall.getOperands();
List<RexNode> newOperands = new ArrayList<>(oldOperands.size());
for (RexNode oldOperand : oldOperands) {
RexLiteral literal = getLiteral(oldOperand, numInputFields, window.constants, projects);
if (literal != null) {
newOperands.add(literal);
changed = true;
windowChanged = true;
} else {
newOperands.add(oldOperand);
}
}
if (changed) {
newAggCallWindow.add(
new Window.RexWinAggCall((SqlAggFunction) oldAggCall.getOperator(), oldAggCall.type, newOperands,
oldAggCall.ordinal, oldAggCall.distinct, oldAggCall.ignoreNulls));
} else {
newAggCallWindow.add(oldAggCall);
}
}

RexWindowBound lowerBound = oldWindowGroup.lowerBound;
RexNode offset = lowerBound.getOffset();
if (offset != null) {
RexLiteral literal = getLiteral(offset, numInputFields, window.constants, projects);
if (literal == null) {
throw new IllegalStateException(
"Could not read window lower bound literal value from window group: " + oldWindowGroup);
}
lowerBound = lowerBound.isPreceding() ? RexWindowBounds.preceding(literal) : RexWindowBounds.following(literal);
windowChanged = true;
}
RexWindowBound upperBound = oldWindowGroup.upperBound;
offset = upperBound.getOffset();
if (offset != null) {
RexLiteral literal = getLiteral(offset, numInputFields, window.constants, projects);
if (literal == null) {
throw new IllegalStateException(
"Could not read window upper bound literal value from window group: " + oldWindowGroup);
}
upperBound = upperBound.isFollowing() ? RexWindowBounds.following(literal) : RexWindowBounds.preceding(literal);
windowChanged = true;
}

return windowChanged ? new Window.Group(oldWindowGroup.keys, oldWindowGroup.isRows, lowerBound, upperBound,
oldWindowGroup.exclude, oldWindowGroup.orderKeys, newAggCallWindow) : oldWindowGroup;
}

@Nullable
private RexLiteral getLiteral(RexNode rexNode, int numInputFields, ImmutableList<RexLiteral> constants,
@Nullable List<RexNode> projects) {
if (!(rexNode instanceof RexInputRef)) {
return null;
}
int index = ((RexInputRef) rexNode).getIndex();
if (index >= numInputFields) {
return constants.get(index - numInputFields);
}
if (projects != null) {
RexNode project = projects.get(index);
if (project instanceof RexLiteral) {
return (RexLiteral) project;
}
}
return null;
}

private void validateWindows(Window window) {
int numGroups = window.groups.size();
// For Phase 1 we only handle single window groups
Preconditions.checkState(numGroups == 1,
String.format("Currently only 1 window group is supported, query has %d groups", numGroups));

// Validate that only supported window aggregation functions are present
Window.Group windowGroup = window.groups.get(0);
validateWindowAggCallsSupported(windowGroup);

// Validate the frame
validateWindowFrames(windowGroup);
}

private void validateWindowAggCallsSupported(Window.Group windowGroup) {
for (Window.RexWinAggCall aggCall : windowGroup.aggCalls) {
SqlKind aggKind = aggCall.getKind();
Preconditions.checkState(SUPPORTED_WINDOW_FUNCTION_KIND.contains(aggKind),
String.format("Unsupported Window function kind: %s. Only aggregation functions are supported!", aggKind));
}
}

private void validateWindowFrames(Window.Group windowGroup) {
RexWindowBound lowerBound = windowGroup.lowerBound;
RexWindowBound upperBound = windowGroup.upperBound;

boolean hasOffset = (lowerBound.isPreceding() && !lowerBound.isUnbounded()) || (upperBound.isFollowing()
&& !upperBound.isUnbounded());

if (!windowGroup.isRows) {
Preconditions.checkState(!hasOffset, "RANGE window frame with offset PRECEDING / FOLLOWING is not supported");
}
}

private boolean isPartitionByOnlyQuery(Window.Group windowGroup) {
boolean isPartitionByOnly = false;
if (windowGroup.orderKeys.getKeys().isEmpty()) {
@@ -31,12 +31,14 @@
import org.apache.calcite.rel.RelDistribution;
import org.apache.calcite.rel.RelDistributions;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Join;
import org.apache.calcite.rel.core.JoinInfo;
import org.apache.calcite.rel.core.Window;
import org.apache.pinot.calcite.rel.hint.PinotHintOptions;
import org.apache.pinot.query.context.PhysicalPlannerContext;
import org.apache.pinot.query.planner.physical.v2.PRelNode;
import org.apache.pinot.query.planner.physical.v2.nodes.PhysicalAggregate;
import org.apache.pinot.query.planner.physical.v2.nodes.PhysicalAsOfJoin;
import org.apache.pinot.query.planner.physical.v2.nodes.PhysicalJoin;
import org.apache.pinot.query.planner.physical.v2.nodes.PhysicalProject;
import org.apache.pinot.query.planner.physical.v2.nodes.PhysicalSort;
@@ -75,6 +77,8 @@ PRelNode assign(PRelNode pRelNode) {
return (PRelNode) assignSort((PhysicalSort) relNode);
} else if (relNode instanceof PhysicalJoin) {
return (PRelNode) assignJoin((PhysicalJoin) relNode);
} else if (relNode instanceof PhysicalAsOfJoin) {
return (PRelNode) assignJoin((PhysicalAsOfJoin) relNode);
} else if (relNode instanceof PhysicalAggregate) {
return (PRelNode) assignAggregate((PhysicalAggregate) relNode);
} else if (relNode instanceof PhysicalWindow) {
@@ -105,7 +109,7 @@ RelNode assignSort(PhysicalSort sort) {
* </p>
*/
@VisibleForTesting
- RelNode assignJoin(PhysicalJoin join) {
+ RelNode assignJoin(Join join) {
// Case-1: Handle lookup joins.
if (PinotHintOptions.JoinHintOptions.useLookupJoinStrategy(join)) {
return assignLookupJoin(join);
@@ -121,8 +125,8 @@ RelNode assignJoin(PhysicalJoin join) {
"Always expect left and right keys to be same size. Found: %s and %s",
joinInfo.leftKeys, joinInfo.rightKeys);
// Case-3: Default case.
- RelDistribution rightDistribution = joinInfo.isEqui() && !joinInfo.rightKeys.isEmpty()
- ? RelDistributions.hash(joinInfo.rightKeys) : RelDistributions.BROADCAST_DISTRIBUTED;
+ RelDistribution rightDistribution = !joinInfo.rightKeys.isEmpty() ? RelDistributions.hash(joinInfo.rightKeys)
+ : RelDistributions.BROADCAST_DISTRIBUTED;
RelDistribution leftDistribution;
if (joinInfo.leftKeys.isEmpty() || rightDistribution == RelDistributions.BROADCAST_DISTRIBUTED) {
leftDistribution = RelDistributions.RANDOM_DISTRIBUTED;
@@ -216,7 +220,7 @@ RelNode assignWindow(PhysicalWindow window) {
return window.copy(window.getTraitSet(), ImmutableList.of(input));
}

- private RelNode assignLookupJoin(PhysicalJoin join) {
+ private RelNode assignLookupJoin(Join join) {
/*
* Lookup join expects right input to have project and table-scan nodes exactly. Moreover, lookup join is used
* with Dimension tables only. Given this, we expect the entire right input to be available in all workers