Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -292,15 +292,38 @@ public Expression visitIn(In node, AnalysisContext context) {

private Expression visitIn(
UnresolvedExpression field, List<UnresolvedExpression> valueList, AnalysisContext context) {
if (valueList.size() == 1) {
return visitCompare(new Compare("=", field, valueList.get(0)), context);
} else if (valueList.size() > 1) {
return DSL.or(
visitCompare(new Compare("=", field, valueList.get(0)), context),
visitIn(field, valueList.subList(1, valueList.size()), context));
} else {
if (valueList.isEmpty()) {
throw new SemanticCheckException("Values in In clause should not be empty");
}

Expression[] expressions = new Expression[valueList.size()];

for (int i = 0; i < expressions.length; i++) {
expressions[i] = visitCompare(new Compare("=", field, valueList.get(i)), context);
}

return buildOrTree(expressions, 0, expressions.length);
}

/**
* `DSL.or` can only take two arguments. To represent large lists without massive recursion, we
* want to represent the expression as a balanced tree. This builds that tree from a node list.
*
* @param children The list of expressions to merge.
* @param start The starting position (inclusive) for the current combination step.
* @param end The ending position (exclusive) for the current combination step. If <= start,
* children[start] is returned.
* @return The final `DSL.or` expression.
*/
private Expression buildOrTree(Expression[] children, int start, int end) {
if (end - start <= 1) {
return children[start];
}
if (end - start == 2) {
return DSL.or(children[start], children[end - 1]);
}
int split = start + (end - start) / 2;
return DSL.or(buildOrTree(children, start, split), buildOrTree(children, split, end));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import static org.opensearch.sql.expression.DSL.ref;

import com.google.common.collect.ImmutableMap;
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
Expand Down Expand Up @@ -401,6 +402,17 @@ void visit_in() {
() -> analyze(AstDSL.in(field("integer_value"), Collections.emptyList())));
}

@Test
void visit_in_large_list() {
List<UnresolvedExpression> ints = new ArrayList<>();
for (int i = 0; i < 10000; i++) {
ints.add(intLiteral(i));
}

// Shouldn't crash
analyze(AstDSL.in(field("integer_value"), ints));
}

@Test
void multi_match_expression() {
assertAnalyzeEqual(
Expand Down
Loading