From b77a514ae9486043b010e6e466274c6153e80142 Mon Sep 17 00:00:00 2001 From: wjHuang Date: Tue, 24 Nov 2020 13:27:02 +0800 Subject: [PATCH] expression: fix a bug that DML using caseWhen may cause schema change (#19857) Signed-off-by: wjhuang2016 --- expression/aggregation/base_func.go | 19 ++++++++----------- expression/constant_fold.go | 2 +- expression/integration_test.go | 14 ++++++++++++++ 3 files changed, 23 insertions(+), 12 deletions(-) diff --git a/expression/aggregation/base_func.go b/expression/aggregation/base_func.go index 146b3d36867cb..7f8dbdbcc63bf 100644 --- a/expression/aggregation/base_func.go +++ b/expression/aggregation/base_func.go @@ -414,17 +414,14 @@ func (a *baseFuncDesc) WrapCastForAggArgs(ctx sessionctx.Context) { // function. // Note: If the `Tp` of argument is the same as the `Tp` of the // aggregation function, it will not wrap cast function on it - // internally. The reason of the special handling for `Column` is - // that the `RetType` of `Column` refers to the `infoschema`, so we - // need to set a new variable for it to avoid modifying the - // definition in `infoschema`. - if col, ok := a.Args[i].(*expression.Column); ok { - col.RetType = types.NewFieldType(col.RetType.Tp) + // internally. + switch x := a.Args[i].(type) { + case *expression.Column: + x.RetType = a.RetTp + case *expression.ScalarFunction: + x.RetType = a.RetTp + case *expression.CorrelatedColumn: + x.RetType = a.RetTp } - // originTp is used when the the `Tp` of column is TypeFloat32 while - // the type of the aggregation function is TypeFloat64. - originTp := a.Args[i].GetType().Tp - *(a.Args[i].GetType()) = *(a.RetTp) - a.Args[i].GetType().Tp = originTp } } diff --git a/expression/constant_fold.go b/expression/constant_fold.go index 1b4dcaf00c3ad..c5810c8570387 100644 --- a/expression/constant_fold.go +++ b/expression/constant_fold.go @@ -122,7 +122,7 @@ func caseWhenHandler(expr *ScalarFunction) (Expression, bool) { foldedExpr.GetType().Decimal = expr.GetType().Decimal return foldedExpr, isDeferredConst } - return BuildCastFunction(expr.GetCtx(), foldedExpr, foldedExpr.GetType()), isDeferredConst + return foldedExpr, isDeferredConst } } else { // for no-const, here should return directly, because the following branches are unknown to be run or not diff --git a/expression/integration_test.go b/expression/integration_test.go index d3aa3e56e4cfd..f8106020b24f8 100644 --- a/expression/integration_test.go +++ b/expression/integration_test.go @@ -139,6 +139,19 @@ func (s *testIntegrationSuite) Test19654(c *C) { tk.MustQuery("select /*+ inl_join(t2)*/ * from t1, t2 where t1.b=t2.b;").Check(testkit.Rows("a a")) } +func (s *testIntegrationSuite) Test19387(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("USE test;") + + tk.MustExec("drop table if exists t;") + tk.MustExec("create table t(a decimal(16, 2));") + tk.MustExec("select sum(case when 1 then a end) from t group by a;") + res := tk.MustQuery("show create table t") + c.Assert(len(res.Rows()), Equals, 1) + str := res.Rows()[0][1].(string) + c.Assert(strings.Contains(str, "decimal(16,2)"), IsTrue) +} + func (s *testIntegrationSuite) TestFuncREPEAT(c *C) { tk := testkit.NewTestKit(c, s.store) defer s.cleanEnv(c) @@ -3673,6 +3686,7 @@ func (s *testIntegrationSuite) TestAggregationBuiltin(c *C) { defer s.cleanEnv(c) tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") + tk.MustExec("drop table if exists t") tk.MustExec("create table t(a decimal(7, 6))") tk.MustExec("insert into t values(1.123456), (1.123456)") result := tk.MustQuery("select avg(a) from t")