Skip to content

ES|QL: Resolve Keep plan added to FORK branches #129754

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -873,6 +873,43 @@ public void testWithUnsupportedFieldsAndConflicts() {
assertTrue(e.getMessage().contains("Column [embedding] has conflicting data types"));
}

public void testValidationsAfterFork() {
var firstQuery = """
FROM test*
| FORK ( WHERE true )
( WHERE true )
| DROP _fork
| STATS a = count_distinct(embedding)
""";

var e = expectThrows(VerificationException.class, () -> run(firstQuery));
assertTrue(
e.getMessage().contains("[count_distinct(embedding)] must be [any exact type except unsigned_long, _source, or counter types]")
);

var secondQuery = """
FROM test*
| FORK ( WHERE true )
( WHERE true )
| DROP _fork
| EVAL a = substring(1, 2, 3)
""";

e = expectThrows(VerificationException.class, () -> run(secondQuery));
assertTrue(e.getMessage().contains("first argument of [substring(1, 2, 3)] must be [string], found value [1] type [integer]"));

var thirdQuery = """
FROM test*
| FORK ( WHERE true )
( WHERE true )
| DROP _fork
| EVAL a = b + 2
""";

e = expectThrows(VerificationException.class, () -> run(thirdQuery));
assertTrue(e.getMessage().contains("Unknown column [b]"));
}

public void testWithEvalWithConflictingTypes() {
var query = """
FROM test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -788,10 +788,8 @@ private LogicalPlan resolveFork(Fork fork, AnalyzerContext context) {
}

List<String> subPlanColumns = logicalPlan.output().stream().map(Attribute::name).toList();
// We need to add an explicit Keep even if the outputs align
// This is because at the moment the sub plans are executed and optimized separately and the output might change
// during optimizations. Once we add streaming we might not need to add a Keep when the outputs already align.
if (logicalPlan instanceof Keep == false || subPlanColumns.equals(forkColumns) == false) {
// We need to add an explicit EsqlProject to align the outputs.
if (logicalPlan instanceof Project == false || subPlanColumns.equals(forkColumns) == false) {
changed = true;
List<Attribute> newOutput = new ArrayList<>();
for (String attrName : forkColumns) {
Expand All @@ -801,7 +799,7 @@ private LogicalPlan resolveFork(Fork fork, AnalyzerContext context) {
}
}
}
logicalPlan = new Keep(logicalPlan.source(), logicalPlan, newOutput);
logicalPlan = resolveKeep(new Keep(logicalPlan.source(), logicalPlan, newOutput), logicalPlan.output());
}

newSubPlans.add(logicalPlan);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,6 @@
import org.elasticsearch.xpack.esql.plan.logical.Filter;
import org.elasticsearch.xpack.esql.plan.logical.Fork;
import org.elasticsearch.xpack.esql.plan.logical.Insist;
import org.elasticsearch.xpack.esql.plan.logical.Keep;
import org.elasticsearch.xpack.esql.plan.logical.Limit;
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
import org.elasticsearch.xpack.esql.plan.logical.Lookup;
Expand Down Expand Up @@ -3090,27 +3089,27 @@ public void testBasicFork() {
// fork branch 1
limit = as(subPlans.get(0), Limit.class);
assertThat(as(limit.limit(), Literal.class).value(), equalTo(DEFAULT_LIMIT));
Keep keep = as(limit.child(), Keep.class);
List<String> keptColumns = keep.expressions().stream().map(exp -> as(exp, Attribute.class).name()).toList();
assertThat(keptColumns, equalTo(expectedOutput));
Eval eval = as(keep.child(), Eval.class);
EsqlProject project = as(limit.child(), EsqlProject.class);
List<String> projectColumns = project.expressions().stream().map(exp -> as(exp, Attribute.class).name()).toList();
assertThat(projectColumns, equalTo(expectedOutput));
Eval eval = as(project.child(), Eval.class);
assertThat(as(eval.fields().get(0), Alias.class), equalTo(alias("_fork", string("fork1"))));
Filter filter = as(eval.child(), Filter.class);
assertThat(as(filter.condition(), GreaterThan.class).right(), equalTo(literal(1)));

filter = as(filter.child(), Filter.class);
assertThat(as(filter.condition(), Equals.class).right(), equalTo(string("Chris")));
EsqlProject project = as(filter.child(), EsqlProject.class);
project = as(filter.child(), EsqlProject.class);
var esRelation = as(project.child(), EsRelation.class);
assertThat(esRelation.indexPattern(), equalTo("test"));

// fork branch 2
limit = as(subPlans.get(1), Limit.class);
assertThat(as(limit.limit(), Literal.class).value(), equalTo(DEFAULT_LIMIT));
keep = as(limit.child(), Keep.class);
keptColumns = keep.expressions().stream().map(exp -> as(exp, Attribute.class).name()).toList();
assertThat(keptColumns, equalTo(expectedOutput));
eval = as(keep.child(), Eval.class);
project = as(limit.child(), EsqlProject.class);
projectColumns = project.expressions().stream().map(exp -> as(exp, Attribute.class).name()).toList();
assertThat(projectColumns, equalTo(expectedOutput));
eval = as(project.child(), Eval.class);
assertThat(as(eval.fields().get(0), Alias.class), equalTo(alias("_fork", string("fork2"))));
filter = as(eval.child(), Filter.class);
assertThat(as(filter.condition(), GreaterThan.class).right(), equalTo(literal(2)));
Expand All @@ -3124,10 +3123,10 @@ public void testBasicFork() {
// fork branch 3
limit = as(subPlans.get(2), Limit.class);
assertThat(as(limit.limit(), Literal.class).value(), equalTo(MAX_LIMIT));
keep = as(limit.child(), Keep.class);
keptColumns = keep.expressions().stream().map(exp -> as(exp, Attribute.class).name()).toList();
assertThat(keptColumns, equalTo(expectedOutput));
eval = as(keep.child(), Eval.class);
project = as(limit.child(), EsqlProject.class);
projectColumns = project.expressions().stream().map(exp -> as(exp, Attribute.class).name()).toList();
assertThat(projectColumns, equalTo(expectedOutput));
eval = as(project.child(), Eval.class);
assertThat(as(eval.fields().get(0), Alias.class), equalTo(alias("_fork", string("fork3"))));
limit = as(eval.child(), Limit.class);
assertThat(as(limit.limit(), Literal.class).value(), equalTo(7));
Expand All @@ -3143,10 +3142,10 @@ public void testBasicFork() {
// fork branch 4
limit = as(subPlans.get(3), Limit.class);
assertThat(as(limit.limit(), Literal.class).value(), equalTo(DEFAULT_LIMIT));
keep = as(limit.child(), Keep.class);
keptColumns = keep.expressions().stream().map(exp -> as(exp, Attribute.class).name()).toList();
assertThat(keptColumns, equalTo(expectedOutput));
eval = as(keep.child(), Eval.class);
project = as(limit.child(), EsqlProject.class);
projectColumns = project.expressions().stream().map(exp -> as(exp, Attribute.class).name()).toList();
assertThat(projectColumns, equalTo(expectedOutput));
eval = as(project.child(), Eval.class);
assertThat(as(eval.fields().get(0), Alias.class), equalTo(alias("_fork", string("fork4"))));
orderBy = as(eval.child(), OrderBy.class);
filter = as(orderBy.child(), Filter.class);
Expand All @@ -3158,10 +3157,10 @@ public void testBasicFork() {
// fork branch 5
limit = as(subPlans.get(4), Limit.class);
assertThat(as(limit.limit(), Literal.class).value(), equalTo(MAX_LIMIT));
keep = as(limit.child(), Keep.class);
keptColumns = keep.expressions().stream().map(exp -> as(exp, Attribute.class).name()).toList();
assertThat(keptColumns, equalTo(expectedOutput));
eval = as(keep.child(), Eval.class);
project = as(limit.child(), EsqlProject.class);
projectColumns = project.expressions().stream().map(exp -> as(exp, Attribute.class).name()).toList();
assertThat(projectColumns, equalTo(expectedOutput));
eval = as(project.child(), Eval.class);
assertThat(as(eval.fields().get(0), Alias.class), equalTo(alias("_fork", string("fork5"))));
limit = as(eval.child(), Limit.class);
assertThat(as(limit.limit(), Literal.class).value(), equalTo(9));
Expand Down Expand Up @@ -3193,11 +3192,11 @@ public void testForkBranchesWithDifferentSchemas() {
// fork branch 1
limit = as(subPlans.get(0), Limit.class);
assertThat(as(limit.limit(), Literal.class).value(), equalTo(MAX_LIMIT));
Keep keep = as(limit.child(), Keep.class);
List<String> keptColumns = keep.expressions().stream().map(exp -> as(exp, Attribute.class).name()).toList();
assertThat(keptColumns, equalTo(expectedOutput));
EsqlProject project = as(limit.child(), EsqlProject.class);
List<String> projectColumns = project.expressions().stream().map(exp -> as(exp, Attribute.class).name()).toList();
assertThat(projectColumns, equalTo(expectedOutput));

Eval eval = as(keep.child(), Eval.class);
Eval eval = as(project.child(), Eval.class);
assertEquals(eval.fields().size(), 3);

Set<String> evalFieldNames = eval.fields().stream().map(a -> a.name()).collect(Collectors.toSet());
Expand All @@ -3215,7 +3214,7 @@ public void testForkBranchesWithDifferentSchemas() {
Filter filter = as(orderBy.child(), Filter.class);
assertThat(as(filter.condition(), GreaterThan.class).right(), equalTo(literal(3)));

EsqlProject project = as(filter.child(), EsqlProject.class);
project = as(filter.child(), EsqlProject.class);
filter = as(project.child(), Filter.class);
assertThat(as(filter.condition(), Equals.class).right(), equalTo(string("Chris")));
var esRelation = as(filter.child(), EsRelation.class);
Expand All @@ -3224,10 +3223,10 @@ public void testForkBranchesWithDifferentSchemas() {
// fork branch 2
limit = as(subPlans.get(1), Limit.class);
assertThat(as(limit.limit(), Literal.class).value(), equalTo(DEFAULT_LIMIT));
keep = as(limit.child(), Keep.class);
keptColumns = keep.expressions().stream().map(exp -> as(exp, Attribute.class).name()).toList();
assertThat(keptColumns, equalTo(expectedOutput));
eval = as(keep.child(), Eval.class);
project = as(limit.child(), EsqlProject.class);
projectColumns = project.expressions().stream().map(exp -> as(exp, Attribute.class).name()).toList();
assertThat(projectColumns, equalTo(expectedOutput));
eval = as(project.child(), Eval.class);
assertEquals(eval.fields().size(), 2);
evalFieldNames = eval.fields().stream().map(a -> a.name()).collect(Collectors.toSet());
assertThat(evalFieldNames, equalTo(Set.of("x", "y")));
Expand All @@ -3254,10 +3253,10 @@ public void testForkBranchesWithDifferentSchemas() {
// fork branch 3
limit = as(subPlans.get(2), Limit.class);
assertThat(as(limit.limit(), Literal.class).value(), equalTo(DEFAULT_LIMIT));
keep = as(limit.child(), Keep.class);
keptColumns = keep.expressions().stream().map(exp -> as(exp, Attribute.class).name()).toList();
assertThat(keptColumns, equalTo(expectedOutput));
eval = as(keep.child(), Eval.class);
project = as(limit.child(), EsqlProject.class);
projectColumns = project.expressions().stream().map(exp -> as(exp, Attribute.class).name()).toList();
assertThat(projectColumns, equalTo(expectedOutput));
eval = as(project.child(), Eval.class);
assertEquals(eval.fields().size(), 2);
evalFieldNames = eval.fields().stream().map(a -> a.name()).collect(Collectors.toSet());
assertThat(evalFieldNames, equalTo(Set.of("emp_no", "first_name")));
Expand Down