Skip to content

Commit 2206491

Browse files
committed
SQL: Refactor args verification of In & conditionals (elastic#40916)
Move verification of arguments for Conditional functions and IN from `Verifier` to the `resolveType()` method of the functions. (cherry picked from commit 241644a)
1 parent 8eef92f commit 2206491

File tree

11 files changed

+212
-253
lines changed

11 files changed

+212
-253
lines changed

x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/analysis/analyzer/Verifier.java

Lines changed: 0 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,6 @@
2626
import org.elasticsearch.xpack.sql.expression.function.grouping.GroupingFunction;
2727
import org.elasticsearch.xpack.sql.expression.function.grouping.GroupingFunctionAttribute;
2828
import org.elasticsearch.xpack.sql.expression.function.scalar.ScalarFunction;
29-
import org.elasticsearch.xpack.sql.expression.predicate.conditional.ConditionalFunction;
30-
import org.elasticsearch.xpack.sql.expression.predicate.operator.comparison.In;
3129
import org.elasticsearch.xpack.sql.plan.logical.Aggregate;
3230
import org.elasticsearch.xpack.sql.plan.logical.Distinct;
3331
import org.elasticsearch.xpack.sql.plan.logical.Filter;
@@ -228,9 +226,6 @@ Collection<Failure> verify(LogicalPlan plan) {
228226

229227
Set<Failure> localFailures = new LinkedHashSet<>();
230228

231-
validateInExpression(p, localFailures);
232-
validateConditional(p, localFailures);
233-
234229
checkGroupingFunctionInGroupBy(p, localFailures);
235230
checkFilterOnAggs(p, localFailures);
236231
checkFilterOnGrouping(p, localFailures);
@@ -724,52 +719,4 @@ private static void checkNestedUsedInGroupByOrHaving(LogicalPlan p, Set<Failure>
724719
fail(nested.get(0), "HAVING isn't (yet) compatible with nested fields " + new AttributeSet(nested).names()));
725720
}
726721
}
727-
728-
private static void validateInExpression(LogicalPlan p, Set<Failure> localFailures) {
729-
p.forEachExpressions(e ->
730-
e.forEachUp((In in) -> {
731-
DataType dt = in.value().dataType();
732-
for (Expression value : in.list()) {
733-
if (areTypesCompatible(dt, value.dataType()) == false) {
734-
localFailures.add(fail(value, "expected data type [{}], value provided is of type [{}]",
735-
dt.typeName, value.dataType().typeName));
736-
return;
737-
}
738-
}
739-
},
740-
In.class));
741-
}
742-
743-
private static void validateConditional(LogicalPlan p, Set<Failure> localFailures) {
744-
p.forEachExpressions(e ->
745-
e.forEachUp((ConditionalFunction cf) -> {
746-
DataType dt = DataType.NULL;
747-
748-
for (Expression child : cf.children()) {
749-
if (dt == DataType.NULL) {
750-
if (Expressions.isNull(child) == false) {
751-
dt = child.dataType();
752-
}
753-
} else {
754-
if (areTypesCompatible(dt, child.dataType()) == false) {
755-
localFailures.add(fail(child, "expected data type [{}], value provided is of type [{}]",
756-
dt.typeName, child.dataType().typeName));
757-
return;
758-
}
759-
}
760-
}
761-
},
762-
ConditionalFunction.class));
763-
}
764-
765-
private static boolean areTypesCompatible(DataType left, DataType right) {
766-
if (left == right) {
767-
return true;
768-
} else {
769-
return
770-
(left == DataType.NULL || right == DataType.NULL) ||
771-
(left.isString() && right.isString()) ||
772-
(left.isNumeric() && right.isNumeric());
773-
}
774-
}
775722
}

x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/FunctionRegistry.java

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,7 @@ void addToMap(FunctionDefinition...functions) {
262262
for (String alias : f.aliases()) {
263263
Object old = batchMap.put(alias, f);
264264
if (old != null || defs.containsKey(alias)) {
265-
throw new IllegalArgumentException("alias [" + alias + "] is used by "
265+
throw new SqlIllegalArgumentException("alias [" + alias + "] is used by "
266266
+ "[" + (old != null ? old : defs.get(alias).name()) + "] and [" + f.name() + "]");
267267
}
268268
aliases.put(alias, f.name());
@@ -321,10 +321,10 @@ static <T extends Function> FunctionDefinition def(Class<T> function,
321321
java.util.function.Function<Source, T> ctorRef, String... names) {
322322
FunctionBuilder builder = (source, children, distinct, cfg) -> {
323323
if (false == children.isEmpty()) {
324-
throw new IllegalArgumentException("expects no arguments");
324+
throw new SqlIllegalArgumentException("expects no arguments");
325325
}
326326
if (distinct) {
327-
throw new IllegalArgumentException("does not support DISTINCT yet it was specified");
327+
throw new SqlIllegalArgumentException("does not support DISTINCT yet it was specified");
328328
}
329329
return ctorRef.apply(source);
330330
};
@@ -341,10 +341,10 @@ static <T extends Function> FunctionDefinition def(Class<T> function,
341341
ConfigurationAwareFunctionBuilder<T> ctorRef, String... names) {
342342
FunctionBuilder builder = (source, children, distinct, cfg) -> {
343343
if (false == children.isEmpty()) {
344-
throw new IllegalArgumentException("expects no arguments");
344+
throw new SqlIllegalArgumentException("expects no arguments");
345345
}
346346
if (distinct) {
347-
throw new IllegalArgumentException("does not support DISTINCT yet it was specified");
347+
throw new SqlIllegalArgumentException("does not support DISTINCT yet it was specified");
348348
}
349349
return ctorRef.build(source, cfg);
350350
};
@@ -365,10 +365,10 @@ static <T extends Function> FunctionDefinition def(Class<T> function,
365365
UnaryConfigurationAwareFunctionBuilder<T> ctorRef, String... names) {
366366
FunctionBuilder builder = (source, children, distinct, cfg) -> {
367367
if (children.size() > 1) {
368-
throw new IllegalArgumentException("expects exactly one argument");
368+
throw new SqlIllegalArgumentException("expects exactly one argument");
369369
}
370370
if (distinct) {
371-
throw new IllegalArgumentException("does not support DISTINCT yet it was specified");
371+
throw new SqlIllegalArgumentException("does not support DISTINCT yet it was specified");
372372
}
373373
Expression ex = children.size() == 1 ? children.get(0) : null;
374374
return ctorRef.build(source, ex, cfg);
@@ -390,10 +390,10 @@ static <T extends Function> FunctionDefinition def(Class<T> function,
390390
BiFunction<Source, Expression, T> ctorRef, String... names) {
391391
FunctionBuilder builder = (source, children, distinct, cfg) -> {
392392
if (children.size() != 1) {
393-
throw new IllegalArgumentException("expects exactly one argument");
393+
throw new SqlIllegalArgumentException("expects exactly one argument");
394394
}
395395
if (distinct) {
396-
throw new IllegalArgumentException("does not support DISTINCT yet it was specified");
396+
throw new SqlIllegalArgumentException("does not support DISTINCT yet it was specified");
397397
}
398398
return ctorRef.apply(source, children.get(0));
399399
};
@@ -409,7 +409,7 @@ static <T extends Function> FunctionDefinition def(Class<T> function,
409409
MultiFunctionBuilder<T> ctorRef, String... names) {
410410
FunctionBuilder builder = (source, children, distinct, cfg) -> {
411411
if (distinct) {
412-
throw new IllegalArgumentException("does not support DISTINCT yet it was specified");
412+
throw new SqlIllegalArgumentException("does not support DISTINCT yet it was specified");
413413
}
414414
return ctorRef.build(source, children);
415415
};
@@ -429,7 +429,7 @@ static <T extends Function> FunctionDefinition def(Class<T> function,
429429
DistinctAwareUnaryFunctionBuilder<T> ctorRef, String... names) {
430430
FunctionBuilder builder = (source, children, distinct, cfg) -> {
431431
if (children.size() != 1) {
432-
throw new IllegalArgumentException("expects exactly one argument");
432+
throw new SqlIllegalArgumentException("expects exactly one argument");
433433
}
434434
return ctorRef.build(source, children.get(0), distinct);
435435
};
@@ -449,10 +449,10 @@ static <T extends Function> FunctionDefinition def(Class<T> function,
449449
DatetimeUnaryFunctionBuilder<T> ctorRef, String... names) {
450450
FunctionBuilder builder = (source, children, distinct, cfg) -> {
451451
if (children.size() != 1) {
452-
throw new IllegalArgumentException("expects exactly one argument");
452+
throw new SqlIllegalArgumentException("expects exactly one argument");
453453
}
454454
if (distinct) {
455-
throw new IllegalArgumentException("does not support DISTINCT yet it was specified");
455+
throw new SqlIllegalArgumentException("does not support DISTINCT yet it was specified");
456456
}
457457
return ctorRef.build(source, children.get(0), cfg.zoneId());
458458
};
@@ -471,10 +471,10 @@ interface DatetimeUnaryFunctionBuilder<T> {
471471
static <T extends Function> FunctionDefinition def(Class<T> function, DatetimeBinaryFunctionBuilder<T> ctorRef, String... names) {
472472
FunctionBuilder builder = (source, children, distinct, cfg) -> {
473473
if (children.size() != 2) {
474-
throw new IllegalArgumentException("expects exactly two arguments");
474+
throw new SqlIllegalArgumentException("expects exactly two arguments");
475475
}
476476
if (distinct) {
477-
throw new IllegalArgumentException("does not support DISTINCT yet it was specified");
477+
throw new SqlIllegalArgumentException("does not support DISTINCT yet it was specified");
478478
}
479479
return ctorRef.build(source, children.get(0), children.get(1), cfg.zoneId());
480480
};
@@ -496,13 +496,13 @@ static <T extends Function> FunctionDefinition def(Class<T> function,
496496
boolean isBinaryOptionalParamFunction = function.isAssignableFrom(Round.class) || function.isAssignableFrom(Truncate.class)
497497
|| TopHits.class.isAssignableFrom(function);
498498
if (isBinaryOptionalParamFunction && (children.size() > 2 || children.size() < 1)) {
499-
throw new IllegalArgumentException("expects one or two arguments");
499+
throw new SqlIllegalArgumentException("expects one or two arguments");
500500
} else if (!isBinaryOptionalParamFunction && children.size() != 2) {
501-
throw new IllegalArgumentException("expects exactly two arguments");
501+
throw new SqlIllegalArgumentException("expects exactly two arguments");
502502
}
503503

504504
if (distinct) {
505-
throw new IllegalArgumentException("does not support DISTINCT yet it was specified");
505+
throw new SqlIllegalArgumentException("does not support DISTINCT yet it was specified");
506506
}
507507
return ctorRef.build(source, children.get(0), children.size() == 2 ? children.get(1) : null);
508508
};
@@ -527,7 +527,7 @@ private static FunctionDefinition def(Class<? extends Function> function, Functi
527527
FunctionDefinition.Builder realBuilder = (uf, distinct, cfg) -> {
528528
try {
529529
return builder.build(uf.source(), uf.children(), distinct, cfg);
530-
} catch (IllegalArgumentException e) {
530+
} catch (SqlIllegalArgumentException e) {
531531
throw new ParsingException(uf.source(), "error building [" + primaryName + "]: " + e.getMessage(), e);
532532
}
533533
};
@@ -544,12 +544,12 @@ static <T extends Function> FunctionDefinition def(Class<T> function,
544544
FunctionBuilder builder = (source, children, distinct, cfg) -> {
545545
boolean isLocateFunction = function.isAssignableFrom(Locate.class);
546546
if (isLocateFunction && (children.size() > 3 || children.size() < 2)) {
547-
throw new IllegalArgumentException("expects two or three arguments");
547+
throw new SqlIllegalArgumentException("expects two or three arguments");
548548
} else if (!isLocateFunction && children.size() != 3) {
549-
throw new IllegalArgumentException("expects exactly three arguments");
549+
throw new SqlIllegalArgumentException("expects exactly three arguments");
550550
}
551551
if (distinct) {
552-
throw new IllegalArgumentException("does not support DISTINCT yet it was specified");
552+
throw new SqlIllegalArgumentException("does not support DISTINCT yet it was specified");
553553
}
554554
return ctorRef.build(source, children.get(0), children.get(1), children.size() == 3 ? children.get(2) : null);
555555
};
@@ -565,10 +565,10 @@ static <T extends Function> FunctionDefinition def(Class<T> function,
565565
FourParametersFunctionBuilder<T> ctorRef, String... names) {
566566
FunctionBuilder builder = (source, children, distinct, cfg) -> {
567567
if (children.size() != 4) {
568-
throw new IllegalArgumentException("expects exactly four arguments");
568+
throw new SqlIllegalArgumentException("expects exactly four arguments");
569569
}
570570
if (distinct) {
571-
throw new IllegalArgumentException("does not support DISTINCT yet it was specified");
571+
throw new SqlIllegalArgumentException("does not support DISTINCT yet it was specified");
572572
}
573573
return ctorRef.build(source, children.get(0), children.get(1), children.get(2), children.get(3));
574574
};

x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/predicate/conditional/ArbitraryConditionalFunction.java

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
import org.elasticsearch.xpack.sql.expression.gen.script.ScriptTemplate;
1414
import org.elasticsearch.xpack.sql.expression.predicate.conditional.ConditionalProcessor.ConditionalOperation;
1515
import org.elasticsearch.xpack.sql.tree.Source;
16-
import org.elasticsearch.xpack.sql.type.DataTypeConversion;
1716

1817
import java.util.ArrayList;
1918
import java.util.List;
@@ -33,14 +32,6 @@ public abstract class ArbitraryConditionalFunction extends ConditionalFunction {
3332
this.operation = operation;
3433
}
3534

36-
@Override
37-
protected TypeResolution resolveType() {
38-
for (Expression e : children()) {
39-
dataType = DataTypeConversion.commonType(dataType, e.dataType());
40-
}
41-
return TypeResolution.TYPE_RESOLVED;
42-
}
43-
4435
@Override
4536
protected Pipe makePipe() {
4637
return new ConditionalPipe(source(), this, Expressions.pipe(children()), operation);

x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/predicate/conditional/ConditionalFunction.java

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,14 @@
1212
import org.elasticsearch.xpack.sql.expression.function.scalar.ScalarFunction;
1313
import org.elasticsearch.xpack.sql.tree.Source;
1414
import org.elasticsearch.xpack.sql.type.DataType;
15+
import org.elasticsearch.xpack.sql.type.DataTypeConversion;
1516

1617
import java.util.List;
1718

19+
import static org.elasticsearch.common.logging.LoggerMessageFormat.format;
20+
import static org.elasticsearch.xpack.sql.type.DataTypes.areTypesCompatible;
21+
import static org.elasticsearch.xpack.sql.util.StringUtils.ordinal;
22+
1823
/**
1924
* Base class for conditional predicates.
2025
*/
@@ -36,6 +41,31 @@ public boolean foldable() {
3641
return Expressions.foldable(children());
3742
}
3843

44+
@Override
45+
protected TypeResolution resolveType() {
46+
DataType dt = DataType.NULL;
47+
48+
for (int i = 0; i < children().size(); i++) {
49+
Expression child = children().get(i);
50+
if (dt == DataType.NULL) {
51+
if (Expressions.isNull(child) == false) {
52+
dt = child.dataType();
53+
}
54+
} else {
55+
if (areTypesCompatible(dt, child.dataType()) == false) {
56+
return new TypeResolution(format(null, "{} argument of [{}] must be [{}], found value [{}] type [{}]",
57+
ordinal(i + 1),
58+
sourceText(),
59+
dt.typeName,
60+
Expressions.name(child),
61+
child.dataType().typeName));
62+
}
63+
}
64+
dataType = DataTypeConversion.commonType(dataType, child.dataType());
65+
}
66+
return TypeResolution.TYPE_RESOLVED;
67+
}
68+
3969
@Override
4070
public Nullability nullable() {
4171
return Nullability.UNKNOWN;

x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/predicate/conditional/NullIf.java

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,12 +39,6 @@ public Expression replaceChildren(List<Expression> newChildren) {
3939
return new NullIf(source(), newChildren.get(0), newChildren.get(1));
4040
}
4141

42-
@Override
43-
protected TypeResolution resolveType() {
44-
dataType = children().get(0).dataType();
45-
return TypeResolution.TYPE_RESOLVED;
46-
}
47-
4842
@Override
4943
public Object fold() {
5044
return NullIfProcessor.apply(children().get(0).fold(), children().get(1).fold());

x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/predicate/operator/comparison/In.java

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626

2727
import static org.elasticsearch.common.logging.LoggerMessageFormat.format;
2828
import static org.elasticsearch.xpack.sql.expression.gen.script.ParamsBuilder.paramsBuilder;
29+
import static org.elasticsearch.xpack.sql.type.DataTypes.areTypesCompatible;
30+
import static org.elasticsearch.xpack.sql.util.StringUtils.ordinal;
2931

3032
public class In extends ScalarFunction {
3133

@@ -109,7 +111,7 @@ protected Pipe makePipe() {
109111
@Override
110112
protected TypeResolution resolveType() {
111113
TypeResolution resolution = TypeResolutions.isExact(value, functionName(), Expressions.ParamOrdinal.DEFAULT);
112-
if (resolution != TypeResolution.TYPE_RESOLVED) {
114+
if (resolution.unresolved()) {
113115
return resolution;
114116
}
115117

@@ -120,6 +122,20 @@ protected TypeResolution resolveType() {
120122
name()));
121123
}
122124
}
125+
126+
DataType dt = value.dataType();
127+
for (int i = 0; i < list.size(); i++) {
128+
Expression listValue = list.get(i);
129+
if (areTypesCompatible(dt, listValue.dataType()) == false) {
130+
return new TypeResolution(format(null, "{} argument of [{}] must be [{}], found value [{}] type [{}]",
131+
ordinal(i + 1),
132+
sourceText(),
133+
dt.typeName,
134+
Expressions.name(listValue),
135+
listValue.dataType().typeName));
136+
}
137+
}
138+
123139
return super.resolveType();
124140
}
125141

x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/type/DataTypes.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,4 +230,15 @@ public static Integer precision(DataType t) {
230230
}
231231
return t.displaySize;
232232
}
233+
234+
public static boolean areTypesCompatible(DataType left, DataType right) {
235+
if (left == right) {
236+
return true;
237+
} else {
238+
return
239+
(left == DataType.NULL || right == DataType.NULL) ||
240+
(left.isString() && right.isString()) ||
241+
(left.isNumeric() && right.isNumeric());
242+
}
243+
}
233244
}

0 commit comments

Comments
 (0)