Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import org.opensearch.sql.ast.expression.When;
import org.opensearch.sql.ast.expression.WindowFunction;
import org.opensearch.sql.ast.expression.Xor;
import org.opensearch.sql.ast.expression.subquery.ExistsSubquery;
import org.opensearch.sql.ast.expression.subquery.InSubquery;
import org.opensearch.sql.ast.statement.Explain;
import org.opensearch.sql.ast.statement.Query;
Expand Down Expand Up @@ -346,4 +347,8 @@ public T visitLookup(Lookup node, C context) {
public T visitSubqueryAlias(SubqueryAlias node, C context) {
return visitChildren(node, context);
}

public T visitExistsSubquery(ExistsSubquery node, C context) {
return visitChildren(node, context);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.ast.expression.subquery;

import com.google.common.collect.ImmutableList;
import java.util.List;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import org.opensearch.sql.ast.AbstractNodeVisitor;
import org.opensearch.sql.ast.expression.UnresolvedExpression;
import org.opensearch.sql.ast.tree.UnresolvedPlan;
import org.opensearch.sql.common.utils.StringUtils;

@Getter
@EqualsAndHashCode(callSuper = false)
@RequiredArgsConstructor
public class ExistsSubquery extends UnresolvedExpression {
private final UnresolvedPlan query;

@Override
public <R, C> R accept(AbstractNodeVisitor<R, C> nodeVisitor, C context) {
return nodeVisitor.visitExistsSubquery(this, context);
}

@Override
public List<UnresolvedExpression> getChild() {
return ImmutableList.of();
}

@Override
public String toString() {
return StringUtils.format("exists ( %s )", query);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,12 @@
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.ToString;
import org.opensearch.sql.ast.AbstractNodeVisitor;
import org.opensearch.sql.ast.expression.UnresolvedExpression;
import org.opensearch.sql.ast.tree.UnresolvedPlan;
import org.opensearch.sql.common.utils.StringUtils;

@Getter
@ToString
@EqualsAndHashCode(callSuper = false)
@RequiredArgsConstructor
public class InSubquery extends UnresolvedExpression {
Expand All @@ -31,4 +30,9 @@ public List<UnresolvedExpression> getChild() {
public <R, C> R accept(AbstractNodeVisitor<R, C> nodeVisitor, C context) {
return nodeVisitor.visitInSubquery(this, context);
}

@Override
public String toString() {
return StringUtils.format("%s in ( %s )", value, query);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,12 @@
import static org.opensearch.sql.calcite.utils.OpenSearchTypeFactory.TYPE_FACTORY;

import java.sql.Connection;
import java.util.Optional;
import java.util.Stack;
import java.util.function.BiFunction;
import lombok.Getter;
import lombok.Setter;
import org.apache.calcite.rex.RexCorrelVariable;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.tools.FrameworkConfig;
import org.apache.calcite.tools.RelBuilder;
Expand All @@ -25,6 +28,8 @@ public class CalcitePlanContext {
public final ExtendedRexBuilder rexBuilder;

@Getter @Setter private boolean isResolvingJoinCondition = false;
@Getter @Setter private boolean isResolvingExistsSubquery = false;
private final Stack<RexCorrelVariable> correlVar = new Stack<>();

private CalcitePlanContext(FrameworkConfig config) {
this.config = config;
Expand All @@ -42,6 +47,26 @@ public RexNode resolveJoinCondition(
return result;
}

public Optional<RexCorrelVariable> popCorrelVar() {
if (!correlVar.empty()) {
return Optional.of(correlVar.pop());
} else {
return Optional.empty();
}
}

public void pushCorrelVar(RexCorrelVariable v) {
correlVar.push(v);
}

public Optional<RexCorrelVariable> peekCorrelVar() {
if (!correlVar.empty()) {
return Optional.of(correlVar.peek());
} else {
return Optional.empty();
}
}

public static CalcitePlanContext create(FrameworkConfig config) {
return new CalcitePlanContext(config);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import static org.opensearch.sql.ast.tree.Sort.SortOrder.ASC;
import static org.opensearch.sql.ast.tree.Sort.SortOrder.DESC;

import com.google.common.collect.ImmutableList;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
Expand All @@ -22,16 +23,22 @@
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexCorrelVariable;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.tools.RelBuilder.AggCall;
import org.apache.calcite.util.Holder;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.opensearch.sql.ast.AbstractNodeVisitor;
import org.opensearch.sql.ast.expression.AllFields;
import org.opensearch.sql.ast.expression.Argument;
import org.opensearch.sql.ast.expression.Compare;
import org.opensearch.sql.ast.expression.Field;
import org.opensearch.sql.ast.expression.Not;
import org.opensearch.sql.ast.expression.QualifiedName;
import org.opensearch.sql.ast.expression.UnresolvedExpression;
import org.opensearch.sql.ast.expression.subquery.ExistsSubquery;
import org.opensearch.sql.ast.tree.Aggregation;
import org.opensearch.sql.ast.tree.Eval;
import org.opensearch.sql.ast.tree.Filter;
Expand Down Expand Up @@ -84,11 +91,35 @@ private RelBuilder scan(RelOptTable tableSchema, CalcitePlanContext context) {
@Override
public RelNode visitFilter(Filter node, CalcitePlanContext context) {
visitChildren(node, context);
boolean containsExistsSubquery = containsExistsSubquery(node.getCondition());
final Holder<@Nullable RexCorrelVariable> v = Holder.empty();
if (containsExistsSubquery) {
context.relBuilder.variable(v::set);
context.pushCorrelVar(v.get());
}
RexNode condition = rexVisitor.analyze(node.getCondition(), context);
context.relBuilder.filter(condition);
if (containsExistsSubquery) {
context.relBuilder.filter(ImmutableList.of(v.get().id), condition);
context.popCorrelVar();
} else {
context.relBuilder.filter(condition);
}
return context.relBuilder.peek();
}

private boolean containsExistsSubquery(Object condition) {
if (condition instanceof ExistsSubquery) {
return true;
}
if (condition instanceof Not n) {
return containsExistsSubquery(n.getExpression());
}
if (condition instanceof Compare c) {
return containsExistsSubquery(c.getLeft()) || containsExistsSubquery(c.getRight());
}
return false;
}

@Override
public RelNode visitProject(Project node, CalcitePlanContext context) {
visitChildren(node, context);
Expand Down Expand Up @@ -174,6 +205,23 @@ public RelNode visitEval(Eval node, CalcitePlanContext context) {
if (!overriding.isEmpty()) {
List<RexNode> toDrop = context.relBuilder.fields(overriding);
context.relBuilder.projectExcept(toDrop);

// the overriding field in Calcite will add a numeric suffix, for example:
// `| eval SAL = SAL + 1` creates a field SAL0 to replace SAL, so we rename it back to SAL,
// or query `| eval SAL=SAL + 1 | where exists [ source=DEPT | where emp.SAL=HISAL ]` fails.
List<String> newNames =
context.relBuilder.peek().getRowType().getFieldNames().stream()
.map(
cur -> {
String noNumericSuffix = cur.replaceAll("\\d", "");
if (overriding.contains(noNumericSuffix)) {
return noNumericSuffix;
} else {
return cur;
}
})
.toList();
context.relBuilder.rename(newNames);
}
return context.relBuilder.peek();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,17 @@
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCorrelVariable;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlIntervalQualifier;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.parser.SqlParserUtil;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.util.DateString;
import org.apache.calcite.util.Holder;
import org.apache.calcite.util.TimeString;
import org.apache.calcite.util.TimestampString;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.opensearch.sql.ast.AbstractNodeVisitor;
import org.opensearch.sql.ast.expression.Alias;
import org.opensearch.sql.ast.expression.And;
Expand All @@ -41,6 +44,7 @@
import org.opensearch.sql.ast.expression.SpanUnit;
import org.opensearch.sql.ast.expression.UnresolvedExpression;
import org.opensearch.sql.ast.expression.Xor;
import org.opensearch.sql.ast.expression.subquery.ExistsSubquery;
import org.opensearch.sql.ast.expression.subquery.InSubquery;
import org.opensearch.sql.ast.tree.UnresolvedPlan;
import org.opensearch.sql.calcite.utils.BuiltinFunctionUtils;
Expand Down Expand Up @@ -156,40 +160,62 @@ public RexNode visitEqualTo(EqualTo node, CalcitePlanContext context) {

@Override
public RexNode visitQualifiedName(QualifiedName node, CalcitePlanContext context) {
// 1. resolve QualifiedName in join condition
if (context.isResolvingJoinCondition()) {
List<String> parts = node.getParts();
if (parts.size() == 1) {
// Handle the case of `id = cid`
// 1.1 Handle the case of `id = cid`
try {
return context.relBuilder.field(2, 0, parts.getFirst());
} catch (IllegalArgumentException ee) {
return context.relBuilder.field(2, 1, parts.getFirst());
}
} else if (parts.size() == 2) {
// Handle the case of `t1.id = t2.id` or `alias1.id = alias2.id`
// 1.2 Handle the case of `t1.id = t2.id` or `alias1.id = alias2.id`
return context.relBuilder.field(2, parts.get(0), parts.get(1));
} else if (parts.size() == 3) {
throw new UnsupportedOperationException("Unsupported qualified name: " + node);
}
}

// 2. resolve QualifiedName in non-join condition
String qualifiedName = node.toString();
List<String> currentFields = context.relBuilder.peek().getRowType().getFieldNames();
if (currentFields.contains(qualifiedName)) {
// 2.1 resolve QualifiedName from stack top
return context.relBuilder.field(qualifiedName);
} else if (node.getParts().size() == 2) {
// 2.2 resolve QualifiedName with an alias or table name
List<String> parts = node.getParts();
return context.relBuilder.field(parts.get(0), parts.get(1));
try {
return context.relBuilder.field(1, parts.get(0), parts.get(1));
} catch (IllegalArgumentException e) {
// 2.3 resolve QualifiedName with outer alias
return context
.peekCorrelVar()
.map(correlVar -> context.relBuilder.field(correlVar, parts.get(1)))
.orElseThrow(() -> e); // Re-throw the exception if no correlated variable exists
}
} else if (currentFields.stream().noneMatch(f -> f.startsWith(qualifiedName))) {
return context.relBuilder.field(qualifiedName);
// 2.4 try resolving combination of 2.1 and 2.3 to resolve rest cases
return context
.peekCorrelVar()
.map(correlVar -> context.relBuilder.field(correlVar, qualifiedName))
.orElseGet(() -> context.relBuilder.field(qualifiedName));
}
// Handle the overriding fields, for example, `eval SAL = SAL + 1` will delete the original SAL
// and add a SAL0
// 3. resolve overriding fields, for example, `eval SAL = SAL + 1` will delete the original SAL
// and add a SAL0. SAL0 in currentFields, but qualifiedName is SAL.
// TODO now we cannot handle the case using a overriding fields in subquery, for example
// source = EMP | eval DEPTNO = DEPTNO + 1 | where exists [ source = DEPT | where emp.DEPTNO =
// DEPTNO ]
Map<String, String> fieldMap =
currentFields.stream().collect(Collectors.toMap(s -> s.replaceAll("\\d", ""), s -> s));
if (fieldMap.containsKey(qualifiedName)) {
return context.relBuilder.field(fieldMap.get(qualifiedName));
} else {
return null;
throw new IllegalArgumentException(
String.format(
"field [%s] not found; input fields are: %s", qualifiedName, currentFields));
}
}

Expand Down Expand Up @@ -256,20 +282,8 @@ public RexNode visitFunction(Function node, CalcitePlanContext context) {
@Override
public RexNode visitInSubquery(InSubquery node, CalcitePlanContext context) {
List<RexNode> nodes = node.getChild().stream().map(child -> analyze(child, context)).toList();
// clear and store the outer state
boolean isResolvingJoinConditionOuter = context.isResolvingJoinCondition();
if (isResolvingJoinConditionOuter) {
context.setResolvingJoinCondition(false);
}
UnresolvedPlan subquery = node.getQuery();

RelNode subqueryRel = subquery.accept(planVisitor, context);
// pop the inner plan
context.relBuilder.build();
// restore to the previous state
if (isResolvingJoinConditionOuter) {
context.setResolvingJoinCondition(true);
}
RelNode subqueryRel = resolveSubqueryPlan(subquery, false, context);
try {
return context.relBuilder.in(subqueryRel, nodes);
// TODO
Expand All @@ -288,4 +302,32 @@ public RexNode visitInSubquery(InSubquery node, CalcitePlanContext context) {
+ " of columns in the output of subquery");
}
}

@Override
public RexNode visitExistsSubquery(ExistsSubquery node, CalcitePlanContext context) {
final Holder<@Nullable RexCorrelVariable> v = Holder.empty();
return context.relBuilder.exists(
b -> {
UnresolvedPlan subquery = node.getQuery();
return resolveSubqueryPlan(subquery, true, context);
});
}

private RelNode resolveSubqueryPlan(
UnresolvedPlan subquery, boolean isExists, CalcitePlanContext context) {
// clear and store the outer state
boolean isResolvingJoinConditionOuter = context.isResolvingJoinCondition();
if (isResolvingJoinConditionOuter) {
context.setResolvingJoinCondition(false);
}
RelNode subqueryRel = subquery.accept(planVisitor, context);
// pop the inner plan
context.relBuilder.build();
// clear the exists subquery resolving state
// restore to the previous state
if (isResolvingJoinConditionOuter) {
context.setResolvingJoinCondition(true);
}
return subqueryRel;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ public static ExprType convertRelDataTypeToExprType(RelDataType type) {
return FLOAT;
case DOUBLE:
return DOUBLE;
case CHAR:
case VARCHAR:
return STRING;
case BOOLEAN:
Expand Down
Loading
Loading