Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

executor: fix get var expr when session var is hex literal (#23241) #23373

Merged
merged 4 commits into from
Aug 25, 2021
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion expression/builtin_control_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ func (s *testEvaluatorSuite) TestIfNull(c *C) {
{tm, nil, tm, false, false},
{nil, duration, duration, false, false},
{nil, types.NewDecFromFloatForTest(123.123), types.NewDecFromFloatForTest(123.123), false, false},
{nil, types.NewBinaryLiteralFromUint(0x01, -1), uint64(1), false, false},
{nil, types.NewBinaryLiteralFromUint(0x01, -1), "\x01", false, false},
{nil, types.Set{Value: 1, Name: "abc"}, "abc", false, false},
{nil, jsonInt.GetMysqlJSON(), jsonInt.GetMysqlJSON(), false, false},
{"abc", nil, "abc", false, false},
Expand Down
2 changes: 1 addition & 1 deletion expression/builtin_string_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1149,7 +1149,7 @@ func (s *testEvaluatorSuite) TestHexFunc(c *C) {
{-1, false, false, "FFFFFFFFFFFFFFFF"},
{-12.3, false, false, "FFFFFFFFFFFFFFF4"},
{-12.8, false, false, "FFFFFFFFFFFFFFF3"},
{types.NewBinaryLiteralFromUint(0xC, -1), false, false, "C"},
{types.NewBinaryLiteralFromUint(0xC, -1), false, false, "0C"},
{0x12, false, false, "12"},
{nil, true, false, ""},
{errors.New("must err"), false, true, ""},
Expand Down
4 changes: 2 additions & 2 deletions expression/constant_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -254,8 +254,8 @@ func (*testExpressionSuite) TestDeferredParamNotNull(c *C) {
c.Assert(mysql.TypeTimestamp, Equals, cstTime.GetType().Tp)
c.Assert(mysql.TypeDuration, Equals, cstDuration.GetType().Tp)
c.Assert(mysql.TypeBlob, Equals, cstBytes.GetType().Tp)
c.Assert(mysql.TypeBit, Equals, cstBinary.GetType().Tp)
c.Assert(mysql.TypeBit, Equals, cstBit.GetType().Tp)
c.Assert(mysql.TypeVarString, Equals, cstBinary.GetType().Tp)
c.Assert(mysql.TypeVarString, Equals, cstBit.GetType().Tp)
c.Assert(mysql.TypeFloat, Equals, cstFloat32.GetType().Tp)
c.Assert(mysql.TypeDouble, Equals, cstFloat64.GetType().Tp)
c.Assert(mysql.TypeEnum, Equals, cstEnum.GetType().Tp)
Expand Down
22 changes: 22 additions & 0 deletions planner/core/common_plans.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,21 @@ type Execute struct {
Plan Plan
}

// Check if result of GetVar expr is BinaryLiteral
// Because GetVar use String to represent BinaryLiteral, here we need to convert string back to BinaryLiteral.
func isGetVarBinaryLiteral(sctx sessionctx.Context, expr expression.Expression) (res bool) {
scalarFunc, ok := expr.(*expression.ScalarFunction)
if ok && scalarFunc.FuncName.L == ast.GetVar {
name, isNull, err := scalarFunc.GetArgs()[0].EvalString(sctx, chunk.Row{})
if err != nil || isNull {
res = false
} else if dt, ok2 := sctx.GetSessionVars().Users[name]; ok2 {
res = (dt.Kind() == types.KindBinaryLiteral)
}
}
return res
}

// OptimizePreparedPlan optimizes the prepared statement.
func (e *Execute) OptimizePreparedPlan(ctx context.Context, sctx sessionctx.Context, is infoschema.InfoSchema) error {
vars := sctx.GetSessionVars()
Expand Down Expand Up @@ -228,6 +243,13 @@ func (e *Execute) OptimizePreparedPlan(ctx context.Context, sctx sessionctx.Cont
return err
}
param := prepared.Params[i].(*driver.ParamMarkerExpr)
if isGetVarBinaryLiteral(sctx, usingVar) {
binVal, convErr := val.ToBytes()
if convErr != nil {
return convErr
}
val.SetBinaryLiteral(types.BinaryLiteral(binVal))
}
param.Datum = val
param.InExecute = true
vars.PreparedParams = append(vars.PreparedParams, val)
Expand Down
3 changes: 2 additions & 1 deletion planner/core/expression_rewriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -1382,7 +1382,8 @@ func (er *expressionRewriter) inToExpression(lLen int, not bool, tp *types.Field
er.ctxStackAppend(expression.NewNull(), types.EmptyName)
return
}
if leftEt == types.ETInt {
containMut := expression.ContainMutableConst(er.sctx, args)
if !containMut && leftEt == types.ETInt {
for i := 1; i < len(args); i++ {
if c, ok := args[i].(*expression.Constant); ok {
var isExceptional bool
Expand Down
70 changes: 70 additions & 0 deletions planner/core/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3552,6 +3552,76 @@ func (s *testIntegrationSuite) TestIssue23839(c *C) {
tk.Exec("explain SELECT OUTR . col2 AS X FROM (SELECT INNR . col1 as col1, SUM( INNR . col2 ) as col2 FROM (SELECT INNR . `col_int_not_null` + 1 as col1, INNR . `pk` as col2 FROM BB AS INNR) AS INNR GROUP BY col1) AS OUTR2 INNER JOIN (SELECT INNR . col1 as col1, MAX( INNR . col2 ) as col2 FROM (SELECT INNR . `col_int_not_null` + 1 as col1, INNR . `pk` as col2 FROM BB AS INNR) AS INNR GROUP BY col1) AS OUTR ON OUTR2.col1 = OUTR.col1 GROUP BY OUTR . col1, OUTR2 . col1 HAVING X <> 'b'")
}

// #22949: test HexLiteral Used in GetVar expr
func (s *testIntegrationSuite) TestGetVarExprWithHexLiteral(c *C) {
tk := testkit.NewTestKit(c, s.store)
tk.MustExec("use test;")
tk.MustExec("drop table if exists t1_no_idx;")
tk.MustExec("create table t1_no_idx(id int, col_bit bit(16));")
tk.MustExec("insert into t1_no_idx values(1, 0x3135);")
tk.MustExec("insert into t1_no_idx values(2, 0x0f);")

tk.MustExec("prepare stmt from 'select id from t1_no_idx where col_bit = ?';")
tk.MustExec("set @a = 0x3135;")
tk.MustQuery("execute stmt using @a;").Check(testkit.Rows("1"))
tk.MustExec("set @a = 0x0F;")
tk.MustQuery("execute stmt using @a;").Check(testkit.Rows("2"))

// same test, but use IN expr
tk.MustExec("prepare stmt from 'select id from t1_no_idx where col_bit in (?)';")
tk.MustExec("set @a = 0x3135;")
tk.MustQuery("execute stmt using @a;").Check(testkit.Rows("1"))
tk.MustExec("set @a = 0x0F;")
tk.MustQuery("execute stmt using @a;").Check(testkit.Rows("2"))

// same test, but use table with index on col_bit
tk.MustExec("drop table if exists t2_idx;")
tk.MustExec("create table t2_idx(id int, col_bit bit(16), key(col_bit));")
tk.MustExec("insert into t2_idx values(1, 0x3135);")
tk.MustExec("insert into t2_idx values(2, 0x0f);")

tk.MustExec("prepare stmt from 'select id from t2_idx where col_bit = ?';")
tk.MustExec("set @a = 0x3135;")
tk.MustQuery("execute stmt using @a;").Check(testkit.Rows("1"))
tk.MustExec("set @a = 0x0F;")
tk.MustQuery("execute stmt using @a;").Check(testkit.Rows("2"))

// same test, but use IN expr
tk.MustExec("prepare stmt from 'select id from t2_idx where col_bit in (?)';")
tk.MustExec("set @a = 0x3135;")
tk.MustQuery("execute stmt using @a;").Check(testkit.Rows("1"))
tk.MustExec("set @a = 0x0F;")
tk.MustQuery("execute stmt using @a;").Check(testkit.Rows("2"))

// test col varchar with GetVar
tk.MustExec("drop table if exists t_varchar;")
tk.MustExec("create table t_varchar(id int, col_varchar varchar(100), key(col_varchar));")
tk.MustExec("insert into t_varchar values(1, '15');")
tk.MustExec("prepare stmt from 'select id from t_varchar where col_varchar = ?';")
tk.MustExec("set @a = 0x3135;")
tk.MustQuery("execute stmt using @a;").Check(testkit.Rows("1"))
}

// test BitLiteral used with GetVar
func (s *testIntegrationSuite) TestGetVarExprWithBitLiteral(c *C) {
tk := testkit.NewTestKit(c, s.store)
tk.MustExec("use test;")
tk.MustExec("drop table if exists t1_no_idx;")
tk.MustExec("create table t1_no_idx(id int, col_bit bit(16));")
tk.MustExec("insert into t1_no_idx values(1, 0x3135);")
tk.MustExec("insert into t1_no_idx values(2, 0x0f);")

tk.MustExec("prepare stmt from 'select id from t1_no_idx where col_bit = ?';")
// 0b11000100110101 is 0x3135
tk.MustExec("set @a = 0b11000100110101;")
tk.MustQuery("execute stmt using @a;").Check(testkit.Rows("1"))

// same test, but use IN expr
tk.MustExec("prepare stmt from 'select id from t1_no_idx where col_bit in (?)';")
tk.MustExec("set @a = 0b11000100110101;")
tk.MustQuery("execute stmt using @a;").Check(testkit.Rows("1"))
}

func (s *testIntegrationSuite) TestIssue26559(c *C) {
tk := testkit.NewTestKit(c, s.store)
tk.MustExec("use test")
Expand Down
12 changes: 12 additions & 0 deletions planner/core/prepare_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -674,6 +674,18 @@ func (s *testPrepareSerialSuite) TestConstPropAndPPDWithCache(c *C) {
tk.MustQuery("execute stmt using @p0").Check(testkit.Rows(
"0",
))

// Need to check if contain mutable before RefineCompareConstant() in inToExpression().
// Otherwise may hit wrong plan.
tk.MustExec("use test;")
tk.MustExec("drop table if exists t1;")
tk.MustExec("create table t1(c1 tinyint unsigned);")
tk.MustExec("insert into t1 values(111);")
tk.MustExec("prepare stmt from 'select 1 from t1 where c1 in (?)';")
tk.MustExec("set @a = '1.1';")
tk.MustQuery("execute stmt using @a;").Check(testkit.Rows())
tk.MustExec("set @a = '111';")
tk.MustQuery("execute stmt using @a;").Check(testkit.Rows("1"))
}

func (s *testPlanSerialSuite) TestPlanCacheUnionScan(c *C) {
Expand Down
2 changes: 1 addition & 1 deletion planner/core/testdata/ordered_result_mode_suite_in.json
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@
"select b from t1 where b in (1, 2, 3, 4)",
"select * from t1 where a > 10 union all select * from t2 where b > 20",
"select * from t1 where a > 10 union distinct select * from t2 where b > 20",
"select row_number() over(partition by a) as row_no, sum(b) over(partition by a) as sum_b from t1",
"select sum(b) over(partition by a) as sum_b from t1",
"select min(a), max(b), sum(c) from t1 group by d",
"select min(a), max(b), sum(c) from t1 group by d having max(b) < 20",
"select case when a=1 then 'a1' when a=2 then 'a2' else 'ax' end from t1 "
Expand Down
11 changes: 5 additions & 6 deletions planner/core/testdata/ordered_result_mode_suite_out.json
Original file line number Diff line number Diff line change
Expand Up @@ -378,12 +378,11 @@
},
{
"Plan": [
"Projection_10 10000.00 root Column#8, Column#7",
"└─Sort_11 10000.00 root test.t1.a, Column#7, Column#8",
" └─Window_13 10000.00 root row_number()->Column#8 over(partition by test.t1.a)",
" └─Window_14 10000.00 root sum(cast(test.t1.b, decimal(32,0) BINARY))->Column#7 over(partition by test.t1.a)",
" └─TableReader_17 10000.00 root data:TableFullScan_16",
" └─TableFullScan_16 10000.00 cop[tikv] table:t1 keep order:true, stats:pseudo"
"Projection_8 10000.00 root Column#6",
"└─Sort_9 10000.00 root test.t1.b, test.t1.a, Column#6",
" └─Window_11 10000.00 root sum(cast(test.t1.b, decimal(32,0) BINARY))->Column#6 over(partition by test.t1.a)",
" └─TableReader_13 10000.00 root data:TableFullScan_12",
" └─TableFullScan_12 10000.00 cop[tikv] table:t1 keep order:true, stats:pseudo"
]
},
{
Expand Down
4 changes: 2 additions & 2 deletions types/field_type.go
Original file line number Diff line number Diff line change
Expand Up @@ -260,8 +260,8 @@ func DefaultTypeForValue(value interface{}, tp *FieldType, char string, collate
tp.Flag |= mysql.UnsignedFlag
SetBinChsClnFlag(tp)
case BinaryLiteral:
tp.Tp = mysql.TypeBit
tp.Flen = len(x) * 8
tp.Tp = mysql.TypeVarString
tp.Flen = len(x)
tp.Decimal = 0
SetBinChsClnFlag(tp)
tp.Flag &= ^mysql.BinaryFlag
Expand Down