Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import org.opensearch.sql.ast.expression.When;
import org.opensearch.sql.ast.expression.WindowFunction;
import org.opensearch.sql.ast.expression.Xor;
import org.opensearch.sql.ast.expression.subquery.InSubquery;
import org.opensearch.sql.ast.statement.Explain;
import org.opensearch.sql.ast.statement.Query;
import org.opensearch.sql.ast.statement.Statement;
Expand Down Expand Up @@ -314,6 +315,10 @@ public T visitExplain(Explain node, C context) {
return visitStatement(node, context);
}

public T visitInSubquery(InSubquery node, C context) {
return visitChildren(node, context);
}

public T visitPaginate(Paginate paginate, C context) {
return visitChildren(paginate, context);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.ast.expression.subquery;

import java.util.List;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.ToString;
import org.opensearch.sql.ast.AbstractNodeVisitor;
import org.opensearch.sql.ast.expression.UnresolvedExpression;
import org.opensearch.sql.ast.tree.UnresolvedPlan;

@Getter
@ToString
@EqualsAndHashCode(callSuper = false)
@RequiredArgsConstructor
public class InSubquery extends UnresolvedExpression {
private final List<UnresolvedExpression> value;
private final UnresolvedPlan query;

@Override
public List<UnresolvedExpression> getChild() {
return value;
}

@Override
public <R, C> R accept(AbstractNodeVisitor<R, C> nodeVisitor, C context) {
return nodeVisitor.visitInSubquery(this, context);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ public class CalciteRelNodeVisitor extends AbstractNodeVisitor<RelNode, CalciteP
private final CalciteAggCallVisitor aggVisitor;

public CalciteRelNodeVisitor() {
this.rexVisitor = new CalciteRexNodeVisitor();
this.rexVisitor = new CalciteRexNodeVisitor(this);
this.aggVisitor = new CalciteAggCallVisitor(rexVisitor);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import lombok.RequiredArgsConstructor;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rex.RexBuilder;
Expand All @@ -38,9 +40,14 @@
import org.opensearch.sql.ast.expression.SpanUnit;
import org.opensearch.sql.ast.expression.UnresolvedExpression;
import org.opensearch.sql.ast.expression.Xor;
import org.opensearch.sql.ast.expression.subquery.InSubquery;
import org.opensearch.sql.ast.tree.UnresolvedPlan;
import org.opensearch.sql.calcite.utils.BuiltinFunctionUtils;
import org.opensearch.sql.exception.SemanticCheckException;

@RequiredArgsConstructor
public class CalciteRexNodeVisitor extends AbstractNodeVisitor<RexNode, CalcitePlanContext> {
private final CalciteRelNodeVisitor planVisitor;

public RexNode analyze(UnresolvedExpression unresolved, CalcitePlanContext context) {
return unresolved.accept(this, context);
Expand Down Expand Up @@ -150,11 +157,18 @@ public RexNode visitEqualTo(EqualTo node, CalcitePlanContext context) {
public RexNode visitQualifiedName(QualifiedName node, CalcitePlanContext context) {
if (context.isResolvingJoinCondition()) {
List<String> parts = node.getParts();
if (parts.size() == 1) { // Handle the case of `id = cid`
if (parts.size() == 1) {
// Handle the case of `id = cid`
try {
return context.relBuilder.field(2, 0, parts.get(0));
} catch (IllegalArgumentException i) {
return context.relBuilder.field(2, 1, parts.get(0));
// TODO what if there is join clause in InSubquery in join condition
// for subquery in join condition
return context.relBuilder.field(parts.get(0));
} catch (IllegalArgumentException e) {
try {
return context.relBuilder.field(2, 0, parts.get(0));
} catch (IllegalArgumentException ee) {
return context.relBuilder.field(2, 1, parts.get(0));
}
}
} else if (parts.size()
== 2) { // Handle the case of `t1.id = t2.id` or `alias1.id = alias2.id`
Expand Down Expand Up @@ -242,4 +256,29 @@ public RexNode visitFunction(Function node, CalcitePlanContext context) {
return context.rexBuilder.makeCall(
BuiltinFunctionUtils.translate(node.getFuncName()), arguments);
}

@Override
public RexNode visitInSubquery(InSubquery node, CalcitePlanContext context) {
List<RexNode> nodes = node.getChild().stream().map(child -> analyze(child, context)).toList();
UnresolvedPlan subquery = node.getQuery();
RelNode subqueryRel = subquery.accept(planVisitor, context);
context.relBuilder.build();
try {
return context.relBuilder.in(subqueryRel, nodes);
// TODO
// The {@link org.apache.calcite.tools.RelBuilder#in(RexNode,java.util.function.Function)}
// only support one expression. Change to follow code after calcite fixed.
// return context.relBuilder.in(
// nodes.getFirst(),
// b -> {
// RelNode subqueryRel = subquery.accept(planVisitor, context);
// b.build();
// return subqueryRel;
// });
} catch (AssertionError e) {
throw new SemanticCheckException(
"The number of columns in the left hand side of an IN subquery does not match the number"
+ " of columns in the output of subquery");
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,10 @@ public static RelDataType convertExprTypeToRelDataType(ExprType fieldType, boole
return TYPE_FACTORY.createSqlType(SqlTypeName.BINARY, nullable);
} else if (fieldType.legacyTypeName().equalsIgnoreCase("timestamp")) {
return TYPE_FACTORY.createSqlType(SqlTypeName.TIMESTAMP, nullable);
} else if (fieldType.legacyTypeName().equalsIgnoreCase("date")) {
return TYPE_FACTORY.createSqlType(SqlTypeName.DATE, nullable);
} else if (fieldType.legacyTypeName().equalsIgnoreCase("time")) {
return TYPE_FACTORY.createSqlType(SqlTypeName.TIME, nullable);
} else if (fieldType.legacyTypeName().equalsIgnoreCase("geo_point")) {
return TYPE_FACTORY.createSqlType(SqlTypeName.GEOMETRY, nullable);
} else if (fieldType.legacyTypeName().equalsIgnoreCase("text")) {
Expand Down
Loading
Loading