Skip to content

Commit 3d2c261

Browse files
authored
Fix: Long IN-lists causes crash (#3660)
1 parent 92cb089 commit 3d2c261

File tree

2 files changed

+42
-7
lines changed

2 files changed

+42
-7
lines changed

core/src/main/java/org/opensearch/sql/analysis/ExpressionAnalyzer.java

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -292,15 +292,38 @@ public Expression visitIn(In node, AnalysisContext context) {
292292

293293
private Expression visitIn(
294294
UnresolvedExpression field, List<UnresolvedExpression> valueList, AnalysisContext context) {
295-
if (valueList.size() == 1) {
296-
return visitCompare(new Compare("=", field, valueList.get(0)), context);
297-
} else if (valueList.size() > 1) {
298-
return DSL.or(
299-
visitCompare(new Compare("=", field, valueList.get(0)), context),
300-
visitIn(field, valueList.subList(1, valueList.size()), context));
301-
} else {
295+
if (valueList.isEmpty()) {
302296
throw new SemanticCheckException("Values in In clause should not be empty");
303297
}
298+
299+
Expression[] expressions = new Expression[valueList.size()];
300+
301+
for (int i = 0; i < expressions.length; i++) {
302+
expressions[i] = visitCompare(new Compare("=", field, valueList.get(i)), context);
303+
}
304+
305+
return buildOrTree(expressions, 0, expressions.length);
306+
}
307+
308+
/**
309+
* `DSL.or` can only take two arguments. To represent large lists without massive recursion, we
310+
* want to represent the expression as a balanced tree. This builds that tree from a node list.
311+
*
312+
* @param children The list of expressions to merge.
313+
* @param start The starting position (inclusive) for the current combination step.
314+
* @param end The ending position (exclusive) for the current combination step. If the range
315+
* holds at most one element (end - start <= 1), children[start] is returned.
316+
* @return The final `DSL.or` expression.
317+
*/
318+
private Expression buildOrTree(Expression[] children, int start, int end) {
319+
if (end - start <= 1) {
320+
return children[start];
321+
}
322+
if (end - start == 2) {
323+
return DSL.or(children[start], children[end - 1]);
324+
}
325+
int split = start + (end - start) / 2;
326+
return DSL.or(buildOrTree(children, start, split), buildOrTree(children, split, end));
304327
}
305328

306329
@Override

core/src/test/java/org/opensearch/sql/analysis/ExpressionAnalyzerTest.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import static org.opensearch.sql.expression.DSL.ref;
2828

2929
import com.google.common.collect.ImmutableMap;
30+
import java.util.ArrayList;
3031
import java.util.Collections;
3132
import java.util.LinkedHashMap;
3233
import java.util.List;
@@ -401,6 +402,17 @@ void visit_in() {
401402
() -> analyze(AstDSL.in(field("integer_value"), Collections.emptyList())));
402403
}
403404

405+
@Test
406+
void visit_in_large_list() {
407+
List<UnresolvedExpression> ints = new ArrayList<>();
408+
for (int i = 0; i < 10000; i++) {
409+
ints.add(intLiteral(i));
410+
}
411+
412+
// Shouldn't crash
413+
analyze(AstDSL.in(field("integer_value"), ints));
414+
}
415+
404416
@Test
405417
void multi_match_expression() {
406418
assertAnalyzeEqual(

0 commit comments

Comments
 (0)