Skip to content

Commit

Permalink
expression: fix a bug that DML using caseWhen may cause schema change (
Browse files Browse the repository at this point in the history
…#19857)

Signed-off-by: wjhuang2016 <huangwenjun1997@gmail.com>
  • Loading branch information
wjhuang2016 authored Nov 24, 2020
1 parent ca247e8 commit b77a514
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 12 deletions.
19 changes: 8 additions & 11 deletions expression/aggregation/base_func.go
Original file line number Diff line number Diff line change
Expand Up @@ -414,17 +414,14 @@ func (a *baseFuncDesc) WrapCastForAggArgs(ctx sessionctx.Context) {
// function.
// Note: If the `Tp` of argument is the same as the `Tp` of the
// aggregation function, it will not wrap cast function on it
// internally. The reason of the special handling for `Column` is
// that the `RetType` of `Column` refers to the `infoschema`, so we
// need to set a new variable for it to avoid modifying the
// definition in `infoschema`.
if col, ok := a.Args[i].(*expression.Column); ok {
col.RetType = types.NewFieldType(col.RetType.Tp)
// internally.
switch x := a.Args[i].(type) {
case *expression.Column:
x.RetType = a.RetTp
case *expression.ScalarFunction:
x.RetType = a.RetTp
case *expression.CorrelatedColumn:
x.RetType = a.RetTp
}
// originTp is used when the the `Tp` of column is TypeFloat32 while
// the type of the aggregation function is TypeFloat64.
originTp := a.Args[i].GetType().Tp
*(a.Args[i].GetType()) = *(a.RetTp)
a.Args[i].GetType().Tp = originTp
}
}
2 changes: 1 addition & 1 deletion expression/constant_fold.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ func caseWhenHandler(expr *ScalarFunction) (Expression, bool) {
foldedExpr.GetType().Decimal = expr.GetType().Decimal
return foldedExpr, isDeferredConst
}
return BuildCastFunction(expr.GetCtx(), foldedExpr, foldedExpr.GetType()), isDeferredConst
return foldedExpr, isDeferredConst
}
} else {
// for no-const, here should return directly, because the following branches are unknown to be run or not
Expand Down
14 changes: 14 additions & 0 deletions expression/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,19 @@ func (s *testIntegrationSuite) Test19654(c *C) {
tk.MustQuery("select /*+ inl_join(t2)*/ * from t1, t2 where t1.b=t2.b;").Check(testkit.Rows("a a"))
}

func (s *testIntegrationSuite) Test19387(c *C) {
tk := testkit.NewTestKit(c, s.store)
tk.MustExec("USE test;")

tk.MustExec("drop table if exists t;")
tk.MustExec("create table t(a decimal(16, 2));")
tk.MustExec("select sum(case when 1 then a end) from t group by a;")
res := tk.MustQuery("show create table t")
c.Assert(len(res.Rows()), Equals, 1)
str := res.Rows()[0][1].(string)
c.Assert(strings.Contains(str, "decimal(16,2)"), IsTrue)
}

func (s *testIntegrationSuite) TestFuncREPEAT(c *C) {
tk := testkit.NewTestKit(c, s.store)
defer s.cleanEnv(c)
Expand Down Expand Up @@ -3673,6 +3686,7 @@ func (s *testIntegrationSuite) TestAggregationBuiltin(c *C) {
defer s.cleanEnv(c)
tk := testkit.NewTestKit(c, s.store)
tk.MustExec("use test")
tk.MustExec("drop table if exists t")
tk.MustExec("create table t(a decimal(7, 6))")
tk.MustExec("insert into t values(1.123456), (1.123456)")
result := tk.MustQuery("select avg(a) from t")
Expand Down

0 comments on commit b77a514

Please sign in to comment.