Skip to content

Commit

Permalink
This is an automated cherry-pick of pingcap#40262
Browse files Browse the repository at this point in the history
Signed-off-by: ti-chi-bot <ti-community-prow-bot@tidb.io>
  • Loading branch information
djshow832 authored and ti-chi-bot committed Mar 28, 2023
1 parent 135aafd commit 63d0112
Show file tree
Hide file tree
Showing 3 changed files with 158 additions and 4 deletions.
33 changes: 33 additions & 0 deletions expression/expression.go
Original file line number Diff line number Diff line change
Expand Up @@ -1429,3 +1429,36 @@ func PropagateType(evalType types.EvalType, args ...Expression) {
}
}
}
<<<<<<< HEAD
=======

// Args2Expressions4Test converts these values to an expression list.
// This conversion is incomplete, so only use for test.
func Args2Expressions4Test(args ...interface{}) []Expression {
exprs := make([]Expression, len(args))
for i, v := range args {
d := types.NewDatum(v)
var ft *types.FieldType
switch d.Kind() {
case types.KindNull:
ft = types.NewFieldType(mysql.TypeNull)
case types.KindInt64:
ft = types.NewFieldType(mysql.TypeLong)
case types.KindUint64:
ft = types.NewFieldType(mysql.TypeLong)
ft.AddFlag(mysql.UnsignedFlag)
case types.KindFloat64:
ft = types.NewFieldType(mysql.TypeDouble)
case types.KindString:
ft = types.NewFieldType(mysql.TypeVarString)
case types.KindMysqlTime:
ft = types.NewFieldType(mysql.TypeTimestamp)
default:
exprs[i] = nil
continue
}
exprs[i] = &Constant{Value: d, RetType: ft}
}
return exprs
}
>>>>>>> 95f0dc547e9 (planner: support pushing down predicates to memory tables in prepared mode (#40262))
16 changes: 12 additions & 4 deletions planner/core/memtable_predicate_extractor.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,14 @@ func (helper extractHelper) extractColInConsExpr(extractCols map[int64]*types.Fi
results := make([]types.Datum, 0, len(args[1:]))
for _, arg := range args[1:] {
constant, ok := arg.(*expression.Constant)
if !ok || constant.DeferredExpr != nil || constant.ParamMarker != nil {
if !ok || constant.DeferredExpr != nil {
return "", nil
}
results = append(results, constant.Value)
v := constant.Value
if constant.ParamMarker != nil {
v = constant.ParamMarker.GetUserVar()
}
results = append(results, v)
}
return name.ColName.L, results
}
Expand Down Expand Up @@ -117,10 +121,14 @@ func (helper extractHelper) extractColBinaryOpConsExpr(extractCols map[int64]*ty
// SELECT * FROM t1 WHERE c='rhs'
// SELECT * FROM t1 WHERE 'lhs'=c
constant, ok := args[1-colIdx].(*expression.Constant)
if !ok || constant.DeferredExpr != nil || constant.ParamMarker != nil {
if !ok || constant.DeferredExpr != nil {
return "", nil
}
return name.ColName.L, []types.Datum{constant.Value}
v := constant.Value
if constant.ParamMarker != nil {
v = constant.ParamMarker.GetUserVar()
}
return name.ColName.L, []types.Datum{v}
}

// extract the OR expression, e.g:
Expand Down
113 changes: 113 additions & 0 deletions planner/core/memtable_predicate_extractor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,14 @@ import (

"github.com/pingcap/tidb/domain"
"github.com/pingcap/tidb/errno"
"github.com/pingcap/tidb/expression"
"github.com/pingcap/tidb/parser"
"github.com/pingcap/tidb/parser/ast"
"github.com/pingcap/tidb/planner"
plannercore "github.com/pingcap/tidb/planner/core"
"github.com/pingcap/tidb/session"
"github.com/pingcap/tidb/testkit"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/hint"
"github.com/pingcap/tidb/util/set"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -1742,3 +1746,112 @@ func TestTikvRegionStatusExtractor(t *testing.T) {
require.Equal(t, ca.tableIDs, tableids)
}
}

func TestExtractorInPreparedStmt(t *testing.T) {
store, dom := testkit.CreateMockStoreAndDomain(t)
tk := testkit.NewTestKit(t, store)

var cases = []struct {
prepared string
userVars []interface{}
params []interface{}
checker func(extractor plannercore.MemTablePredicateExtractor)
}{
{
prepared: "select * from information_schema.TIKV_REGION_STATUS where table_id = ?",
userVars: []interface{}{1},
params: []interface{}{1},
checker: func(extractor plannercore.MemTablePredicateExtractor) {
rse := extractor.(*plannercore.TiKVRegionStatusExtractor)
tableids := rse.GetTablesID()
slices.Sort(tableids)
require.Equal(t, []int64{1}, tableids)
},
},
{
prepared: "select * from information_schema.TIKV_REGION_STATUS where table_id = ? or table_id = ?",
userVars: []interface{}{1, 2},
params: []interface{}{1, 2},
checker: func(extractor plannercore.MemTablePredicateExtractor) {
rse := extractor.(*plannercore.TiKVRegionStatusExtractor)
tableids := rse.GetTablesID()
slices.Sort(tableids)
require.Equal(t, []int64{1, 2}, tableids)
},
},
{
prepared: "select * from information_schema.TIKV_REGION_STATUS where table_id in (?,?)",
userVars: []interface{}{1, 2},
params: []interface{}{1, 2},
checker: func(extractor plannercore.MemTablePredicateExtractor) {
rse := extractor.(*plannercore.TiKVRegionStatusExtractor)
tableids := rse.GetTablesID()
slices.Sort(tableids)
require.Equal(t, []int64{1, 2}, tableids)
},
},
{
prepared: "select * from information_schema.COLUMNS where table_name like ?",
userVars: []interface{}{`"a%"`},
params: []interface{}{"a%"},
checker: func(extractor plannercore.MemTablePredicateExtractor) {
rse := extractor.(*plannercore.ColumnsTableExtractor)
require.EqualValues(t, []string{"a%"}, rse.TableNamePatterns)
},
},
{
prepared: "select * from information_schema.tidb_hot_regions_history where update_time>=?",
userVars: []interface{}{"cast('2019-10-10 10:10:10' as datetime)"},
params: []interface{}{func() types.Time {
tt, err := types.ParseTimestamp(tk.Session().GetSessionVars().StmtCtx, "2019-10-10 10:10:10")
require.NoError(t, err)
return tt
}()},
checker: func(extractor plannercore.MemTablePredicateExtractor) {
rse := extractor.(*plannercore.HotRegionsHistoryTableExtractor)
require.Equal(t, timestamp(t, "2019-10-10 10:10:10"), rse.StartTime)
},
},
}

// text protocol
parser := parser.New()
for _, ca := range cases {
tk.MustExec(fmt.Sprintf("prepare stmt from '%s'", ca.prepared))
setStmt := "set "
exec := "execute stmt using "
for i, uv := range ca.userVars {
name := fmt.Sprintf("@a%d", i)
setStmt += fmt.Sprintf("%s=%v", name, uv)
exec += name
if i != len(ca.userVars)-1 {
setStmt += ","
exec += ","
}
}
tk.MustExec(setStmt)
stmt, err := parser.ParseOneStmt(exec, "", "")
require.NoError(t, err)
plan, _, err := planner.OptimizeExecStmt(context.Background(), tk.Session(), stmt.(*ast.ExecuteStmt), dom.InfoSchema())
require.NoError(t, err)
extractor := plan.(*plannercore.Execute).Plan.(*plannercore.PhysicalMemTable).Extractor
ca.checker(extractor)
}

// binary protocol
for _, ca := range cases {
id, _, _, err := tk.Session().PrepareStmt(ca.prepared)
require.NoError(t, err)
prepStmt, err := tk.Session().GetSessionVars().GetPreparedStmtByID(id)
require.NoError(t, err)
params := expression.Args2Expressions4Test(ca.params...)
execStmt := &ast.ExecuteStmt{
BinaryArgs: params,
PrepStmt: prepStmt,
}
plan, _, err := planner.OptimizeExecStmt(context.Background(), tk.Session(), execStmt, dom.InfoSchema())
require.NoError(t, err)
extractor := plan.(*plannercore.Execute).Plan.(*plannercore.PhysicalMemTable).Extractor
ca.checker(extractor)
}
}

0 comments on commit 63d0112

Please sign in to comment.