Skip to content

Commit

Permalink
[BugFix] Fix mishandled type null (#34985)
Browse files Browse the repository at this point in the history
Signed-off-by: liuyehcf <1559500551@qq.com>
  • Loading branch information
liuyehcf authored Nov 16, 2023
1 parent ce62989 commit 4a990fd
Show file tree
Hide file tree
Showing 12 changed files with 261 additions and 26 deletions.
5 changes: 3 additions & 2 deletions fe/fe-core/src/main/java/com/starrocks/analysis/Expr.java
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
import com.starrocks.planner.FragmentNormalizer;
import com.starrocks.qe.ConnectContext;
import com.starrocks.server.GlobalStateMgr;
import com.starrocks.sql.analyzer.AnalyzerUtils;
import com.starrocks.sql.analyzer.AstToSQLBuilder;
import com.starrocks.sql.analyzer.ExpressionAnalyzer;
import com.starrocks.sql.analyzer.SemanticException;
Expand Down Expand Up @@ -826,8 +827,8 @@ final void treeToThriftHelper(TExpr container, ExprVisitor visitor) {

TExprNode msg = new TExprNode();

Preconditions.checkState(!type.isNull(), "NULL_TYPE is illegal in thrift stage");
Preconditions.checkState(!Objects.equal(Type.ARRAY_NULL, type), "Array<NULL_TYPE> is illegal in thrift stage");
Preconditions.checkState(java.util.Objects.equals(type, AnalyzerUtils.replaceNullType2Boolean(type)),
"NULL_TYPE is illegal in thrift stage");

msg.type = type.toThrift();
msg.num_children = children.size();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import com.google.common.collect.Lists;
import com.starrocks.analysis.ArithmeticExpr;
import com.starrocks.analysis.BinaryType;
import com.starrocks.catalog.ArrayType;
import com.starrocks.catalog.Function;
import com.starrocks.catalog.FunctionSet;
import com.starrocks.catalog.MapType;
Expand Down Expand Up @@ -91,17 +90,14 @@ public ScalarOperator visitCall(CallOperator call, ScalarOperatorRewriteContext
Type type = fn.getArgs()[i];
ScalarOperator child = call.getChild(i);

//Cast from array(null), direct assignment type to avoid passing null_literal into be
if (type.isArrayType() && child.getType().isArrayType()
&& ((ArrayType) child.getType()).getItemType().isNull()) {
child.setType(type);
// For compatibility, decimal ArithmeticExpr(+-*/%) use Type::equals instead of Type::matchesType to
// determine whether to cast child of the ArithmeticExpr
if (needAdjustScale && type.isDecimalOfAnyVersion() && !type.equals(child.getType())) {
addCastChild(type, call, i);
continue;
}

// for compatibility, decimal ArithmeticExpr(+-*/%) use Type::equals instead of Type::matchesType to
// determine whether to cast child of the ArithmeticExpr
if ((needAdjustScale && type.isDecimalOfAnyVersion() && !type.equals(child.getType())) ||
!type.matchesType(child.getType())) {
if (!type.matchesType(child.getType())) {
addCastChild(type, call, i);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;

public class ScalarOperatorToExpr {
Expand Down Expand Up @@ -138,8 +139,12 @@ public static class Formatter extends ScalarOperatorVisitor<Expr, FormatterConte
*/
private static void hackTypeNull(Expr expr) {
// For primitive types, this can be any legitimate type, for simplicity, we pick boolean.
Type type = AnalyzerUtils.replaceNullType2Boolean(expr.getType());
expr.setType(type);
Type previousType = expr.getType();
Type type = AnalyzerUtils.replaceNullType2Boolean(previousType);
// If actual type of expr is SlotRef, avoid change desc type if no hack happens.
if (!Objects.equals(previousType, type)) {
expr.setType(type);
}
}

@Override
Expand All @@ -158,9 +163,7 @@ public Expr visitVariableReference(ColumnRefOperator node, FormatterContext cont
return expr;
}

if (expr.getType().isNull()) {
hackTypeNull(expr);
}
hackTypeNull(expr);
return expr;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ public void testArrayFnTransform() throws Exception {
@Test
public void testArrayFnWithLambdaExpr() throws Exception {
String sql = "select filter(array[], x -> true);";
assertPlanContains(sql, "array_filter([], array_map(<slot 2> -> TRUE, []))");
assertPlanContains(sql, "array_filter(CAST([] AS ARRAY<BOOLEAN>), array_map(<slot 2> -> TRUE, []))");

sql = "select filter(array[5, -6, NULL, 7], x -> x > 0);";
assertPlanContains(sql, " array_filter([5,-6,NULL,7], array_map(<slot 2> -> <slot 2> > 0, [5,-6,NULL,7]))");
Expand Down
29 changes: 24 additions & 5 deletions fe/fe-core/src/test/java/com/starrocks/sql/plan/ArrayTypeTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,26 @@ public void testConcatArray() throws Exception {
plan = getFragmentPlan(sql);
assertContains(plan, "array_concat(CAST([1] AS ARRAY<VARCHAR>), CAST([2] AS ARRAY<VARCHAR>), " +
"CAST([1,2] AS ARRAY<VARCHAR>), ['a'], ['b'], CAST([1.1] AS ARRAY<VARCHAR>)");

sql = "with t0 as (\n" +
" select c1 from (values([])) as t(c1)\n" +
")\n" +
"select \n" +
"array_concat(c1, [1])\n" +
"from t0;";
plan = getFragmentPlan(sql);
getThriftPlan(sql); // Check null type handling
assertContains(plan, "<slot 3> : array_concat(CAST(2: c1 AS ARRAY<TINYINT>), [1])");

sql = "with t0 as (\n" +
" select c1 from (values([])) as t(c1)\n" +
")\n" +
"select \n" +
"array_concat(c1, [1])\n" +
"from t0;";
plan = getFragmentPlan(sql);
getThriftPlan(sql); // Check null type handling
assertContains(plan, "<slot 3> : array_concat(CAST(2: c1 AS ARRAY<TINYINT>), [1])");
}

@Test
Expand Down Expand Up @@ -320,13 +340,13 @@ public void testEmptyArray() throws Exception {
{
String sql = "select array_append([[1,2,3]], [])";
String plan = getFragmentPlan(sql);
assertContains(plan, "<slot 2> : array_append([[1,2,3]], [])");
assertContains(plan, "<slot 2> : array_append([[1,2,3]], CAST([] AS ARRAY<TINYINT>))");
}
{
String sql = "select array_append([[1,2,3]], [null])";
String plan = getFragmentPlan(sql);
assertContains(plan,
"<slot 2> : array_append([[1,2,3]], [NULL])");
"<slot 2> : array_append([[1,2,3]], CAST([NULL] AS ARRAY<TINYINT>))");
}
{
starRocksAssert.withTable("create table test_literal_array_insert_t0(" +
Expand Down Expand Up @@ -708,8 +728,7 @@ public void testArrayAgg() throws Exception {
public void testEmptyArrayOlap() throws Exception {
String sql = "select arrays_overlap([[],[]],[])";
String plan = getVerboseExplain(sql);
assertCContains(plan, "arrays_overlap[([[],[]], []); " +
"args: INVALID_TYPE,INVALID_TYPE; " +
"result: BOOLEAN; args nullable: true; result nullable: true]");
assertCContains(plan, "arrays_overlap[([[],[]], cast([] as ARRAY<ARRAY<BOOLEAN>>)); " +
"args: INVALID_TYPE,INVALID_TYPE; result: BOOLEAN; args nullable: true; result nullable: true]");
}
}
16 changes: 16 additions & 0 deletions fe/fe-core/src/test/java/com/starrocks/sql/plan/MapTypeTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,22 @@ public void testMapFunc() throws Exception { // get super common return type
String sql = "select map_concat(map{16865432442:3},map{3.323777777:'3'})";
String plan = getFragmentPlan(sql);
assertContains(plan, "MAP<DECIMAL128(28,9),VARCHAR>");

sql = "with t0 as (\n" +
" select c1 from (values(map())) as t(c1)\n" +
")\n" +
"select map_concat(map('a',1, 'b',2), c1)\n" +
"from t0;";
plan = getFragmentPlan(sql);
assertContains(plan, "map_concat(map{'a':1,'b':2}, CAST(2: c1 AS MAP<VARCHAR,TINYINT>))");

sql = "with t0 as (\n" +
" select c1 from (values(map())) as t(c1)\n" +
")\n" +
"select map_concat(c1, map('a',1, 'b',2))\n" +
"from t0;";
plan = getFragmentPlan(sql);
assertContains(plan, "map_concat(CAST(2: c1 AS MAP<VARCHAR,TINYINT>), map{'a':1,'b':2})");
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -645,7 +645,7 @@ public void testLiteralArrayPredicates() throws Exception {
assertContains(plan, " 0:OlapScanNode\n" +
" table: pc0, rollup: pc0\n" +
" preAggregation: on\n" +
" Predicates: array_length([]) IS NOT NULL\n" +
" Predicates: array_length(CAST([] AS ARRAY<BOOLEAN>)) IS NOT NULL\n" +
" partitionsRatio=0/1, tabletsRatio=0/0\n" +
" tabletList=\n" +
" actualRows=0, avgRowSize=1.0\n" +
Expand Down Expand Up @@ -677,7 +677,7 @@ public void testCommonPathMerge() throws Exception {
assertContains(plan, " 0:OlapScanNode\n" +
" table: pc0, rollup: pc0\n" +
" preAggregation: on\n" +
" Predicates: array_length([]) IS NOT NULL\n" +
" Predicates: array_length(CAST([] AS ARRAY<BOOLEAN>)) IS NOT NULL\n" +
" partitionsRatio=0/1, tabletsRatio=0/0\n" +
" tabletList=\n" +
" actualRows=0, avgRowSize=3.0\n" +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ public void testLambda() throws Exception {
sql = "select array_filter((x,y) -> x<y, c3.d, c3.d) from test";
assertVerbosePlanContains(sql, "[/c3/d]");
sql = "select map_values(col_map), map_keys(col_map) from (select map_from_arrays([],[]) as col_map)A";
assertPlanContains(sql, "[], []");
assertPlanContains(sql, "map_from_arrays(5: cast, 5: cast)");
}

@Test
Expand Down
49 changes: 49 additions & 0 deletions test/sql/test_array/R/test_array
Original file line number Diff line number Diff line change
Expand Up @@ -205,4 +205,53 @@ insert into array_exprr SELECT generate_series, generate_series FROM TABLE(gener
select count([CAST(if(c2 is null, c1 + c2, 0) as DECIMAL128(38,0)) + if(c1 is null, c2 ,0)] is null) from array_exprr;
-- result:
13336
-- !result
-- name: testEmptyArray
with t0 as (
select c1 from (values([])) as t(c1)
)
select
array_concat(c1, [1])
from t0;
-- result:
[1]
-- !result
with t0 as (
select c1 from (values([])) as t(c1)
)
select
array_concat([1], c1)
from t0;
-- result:
[1]
-- !result
select array_concat(c1, [[]])
from (select c1 from (values([])) as t(c1)) t;
-- result:
[[]]
-- !result
select array_concat(c1, [[1]])
from (select c1 from (values([])) as t(c1)) t;
-- result:
[[1]]
-- !result
select array_concat(c1, [[1]])
from (select c1 from (values([[]])) as t(c1)) t;
-- result:
[[],[1]]
-- !result
select array_concat(c1, [map{'a':1}])
from (select c1 from (values([map()])) as t(c1)) t;
-- result:
[{},{"a":1}]
-- !result
select array_concat(c1, [map{'a':1}])
from (select c1 from (values([])) as t(c1)) t;
-- result:
[{"a":1}]
-- !result
select array_concat(c1, [named_struct('a', 1, 'b', 2, 'c', 3)])
from (select c1 from (values([])) as t(c1)) t;
-- result:
[{"a":1,"b":2,"c":3}]
-- !result
33 changes: 33 additions & 0 deletions test/sql/test_array/T/test_array
Original file line number Diff line number Diff line change
Expand Up @@ -98,3 +98,36 @@ CREATE TABLE array_exprr
insert into array_exprr SELECT generate_series, generate_series FROM TABLE(generate_series(1, 13336));

select count([CAST(if(c2 is null, c1 + c2, 0) as DECIMAL128(38,0)) + if(c1 is null, c2 ,0)] is null) from array_exprr;

-- name: testEmptyArray
with t0 as (
select c1 from (values([])) as t(c1)
)
select
array_concat(c1, [1])
from t0;

with t0 as (
select c1 from (values([])) as t(c1)
)
select
array_concat([1], c1)
from t0;

select array_concat(c1, [[]])
from (select c1 from (values([])) as t(c1)) t;

select array_concat(c1, [[1]])
from (select c1 from (values([])) as t(c1)) t;

select array_concat(c1, [[1]])
from (select c1 from (values([[]])) as t(c1)) t;

select array_concat(c1, [map{'a':1}])
from (select c1 from (values([map()])) as t(c1)) t;

select array_concat(c1, [map{'a':1}])
from (select c1 from (values([])) as t(c1)) t;

select array_concat(c1, [named_struct('a', 1, 'b', 2, 'c', 3)])
from (select c1 from (values([])) as t(c1)) t;
72 changes: 72 additions & 0 deletions test/sql/test_map/R/test_map
Original file line number Diff line number Diff line change
Expand Up @@ -72,4 +72,76 @@ select c2["key1"] from map_top_n;
2
12
7
-- !result
-- name: testEmptyMap
with t0 as (
select c1 from (values(map())) as t(c1)
)
select map_concat(map('a',1, 'b',2), c1)
from t0;
-- result:
{"a":1,"b":2}
-- !result
with t0 as (
select c1 from (values(map())) as t(c1)
)
select map_concat(c1, map('a',1, 'b',2))
from t0;
-- result:
{"a":1,"b":2}
-- !result
select map_concat(c1, map())
from (select c1 from (values(map())) as t(c1)) t;
-- result:
{}
-- !result
select map_concat(c1, map(1,2))
from (select c1 from (values(map())) as t(c1)) t;
-- result:
{1:2}
-- !result
select map_concat(c1, map(map(),map()))
from (select c1 from (values(map())) as t(c1)) t;
-- result:
{{}:{}}
-- !result
select map_concat(c1, map(map(),[]))
from (select c1 from (values(map())) as t(c1)) t;
-- result:
{{}:[]}
-- !result
select map_concat(c1, map([],map()))
from (select c1 from (values(map())) as t(c1)) t;
-- result:
{[]:{}}
-- !result
select map_concat(c1, map())
from (select c1 from (values(map([], []))) as t(c1)) t;
-- result:
{[]:[]}
-- !result
select map_concat(c1, map())
from (select c1 from (values(map([], [[]]))) as t(c1)) t;
-- result:
{[]:[[]]}
-- !result
select map_concat(c1, map([1], [2]))
from (select c1 from (values(map([], []))) as t(c1)) t;
-- result:
{[1]:[2],[]:[]}
-- !result
select map_concat(c1, map([1], [[2]]))
from (select c1 from (values(map([], [[]]))) as t(c1)) t;
-- result:
{[1]:[[2]],[]:[[]]}
-- !result
select map_concat(c1, map(named_struct('1',2), named_struct('A', [map()])))
from (select c1 from (values(map())) as t(c1)) t;
-- result:
{{"1":2}:{"A":[{}]}}
-- !result
select map_concat(c1, map(named_struct('1',2), named_struct('A', [map([],[[named_struct('C', [[map()]])]])])))
from (select c1 from (values(map())) as t(c1)) t;
-- result:
{{"1":2}:{"A":[{[]:[[{"C":[[{}]]}]]}]}}
-- !result
Loading

0 comments on commit 4a990fd

Please sign in to comment.