Skip to content

Commit

Permalink
bindinfo: set default db for bindings correctly (#14077) (#14548)
Browse files Browse the repository at this point in the history
  • Loading branch information
alivxxx authored and sre-bot committed Jan 20, 2020
1 parent 8286441 commit 52dae01
Show file tree
Hide file tree
Showing 7 changed files with 61 additions and 2 deletions.
21 changes: 21 additions & 0 deletions bindinfo/bind_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -520,3 +520,24 @@ func (s *testSuite) TestBindingCache(c *C) {
res := tk.MustQuery("show global bindings")
c.Assert(len(res.Rows()), Equals, 2)
}

func (s *testSuite) TestDefaultDB(c *C) {
tk := testkit.NewTestKit(c, s.store)
s.cleanBindingEnv(tk)
tk.MustExec("use test")
tk.MustExec("create table t(a int, b int, index idx(a))")
tk.MustExec("create global binding for select * from test.t using select * from test.t use index(idx)")
tk.MustExec("use mysql")
tk.MustQuery("select * from test.t")
// Even in another database, we could still use the bindings.
c.Assert(tk.Se.GetSessionVars().StmtCtx.IndexNames[0], Equals, "t:idx")
tk.MustExec("drop global binding for select * from test.t")
tk.MustQuery("show global bindings").Check(testkit.Rows())

tk.MustExec("use test")
tk.MustExec("create session binding for select * from test.t using select * from test.t use index(idx)")
tk.MustExec("use mysql")
tk.MustQuery("select * from test.t")
// Even in another database, we could still use the bindings.
c.Assert(tk.Se.GetSessionVars().StmtCtx.IndexNames[0], Equals, "t:idx")
}
5 changes: 3 additions & 2 deletions executor/bind.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ type SQLBindExec struct {
bindSQL string
charset string
collation string
db string
isGlobal bool
bindAst ast.StmtNode
}
Expand All @@ -59,7 +60,7 @@ func (e *SQLBindExec) Next(ctx context.Context, req *chunk.Chunk) error {
func (e *SQLBindExec) dropSQLBind() error {
record := &bindinfo.BindRecord{
OriginalSQL: e.normdOrigSQL,
Db: e.ctx.GetSessionVars().CurrentDB,
Db: e.db,
}
if !e.isGlobal {
handle := e.ctx.Value(bindinfo.SessionBindInfoKeyType).(*bindinfo.SessionHandle)
Expand All @@ -73,7 +74,7 @@ func (e *SQLBindExec) createSQLBind() error {
record := &bindinfo.BindRecord{
OriginalSQL: e.normdOrigSQL,
BindSQL: e.bindSQL,
Db: e.ctx.GetSessionVars().CurrentDB,
Db: e.db,
Charset: e.charset,
Collation: e.collation,
Status: bindinfo.Using,
Expand Down
1 change: 1 addition & 0 deletions executor/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -2295,6 +2295,7 @@ func (b *executorBuilder) buildSQLBindExec(v *plannercore.SQLBindPlan) Executor
bindSQL: v.BindSQL,
charset: v.Charset,
collation: v.Collation,
db: v.Db,
isGlobal: v.IsGlobal,
bindAst: v.BindStmt,
}
Expand Down
3 changes: 3 additions & 0 deletions executor/compiler.go
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,9 @@ func addHint(ctx sessionctx.Context, stmtNode ast.StmtNode) ast.StmtNode {
func addHintForSelect(hash, normdOrigSQL string, ctx sessionctx.Context, stmt ast.StmtNode) ast.StmtNode {
sessionHandle := ctx.Value(bindinfo.SessionBindInfoKeyType).(*bindinfo.SessionHandle)
bindRecord := sessionHandle.GetBindRecord(normdOrigSQL, ctx.GetSessionVars().CurrentDB)
if bindRecord == nil {
bindRecord = sessionHandle.GetBindRecord(normdOrigSQL, "")
}
if bindRecord != nil {
if bindRecord.Status == bindinfo.Invalid {
return stmt
Expand Down
1 change: 1 addition & 0 deletions planner/core/common_plans.go
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,7 @@ type SQLBindPlan struct {
BindSQL string
IsGlobal bool
BindStmt ast.StmtNode
Db string
Charset string
Collation string
}
Expand Down
30 changes: 30 additions & 0 deletions planner/core/planbuilder.go
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,7 @@ func (b *PlanBuilder) buildDropBindPlan(v *ast.DropBindingStmt) (Plan, error) {
SQLBindOp: OpSQLBindDrop,
NormdOrigSQL: parser.Normalize(v.OriginSel.Text()),
IsGlobal: v.GlobalScope,
Db: getDefaultDB(b.ctx, v.OriginSel),
}
b.visitInfo = appendVisitInfo(b.visitInfo, mysql.SuperPriv, "", "", "", nil)
return p, nil
Expand All @@ -401,13 +402,42 @@ func (b *PlanBuilder) buildCreateBindPlan(v *ast.CreateBindingStmt) (Plan, error
BindSQL: v.HintedSel.Text(),
IsGlobal: v.GlobalScope,
BindStmt: v.HintedSel,
Db: getDefaultDB(b.ctx, v.OriginSel),
Charset: charSet,
Collation: collation,
}
b.visitInfo = appendVisitInfo(b.visitInfo, mysql.SuperPriv, "", "", "", nil)
return p, nil
}

func getDefaultDB(ctx sessionctx.Context, sel ast.StmtNode) string {
implicitDB := &implicitDatabase{}
sel.Accept(implicitDB)
if implicitDB.hasImplicit {
return ctx.GetSessionVars().CurrentDB
}
return ""
}

type implicitDatabase struct {
hasImplicit bool
}

func (i *implicitDatabase) Enter(in ast.Node) (out ast.Node, skipChildren bool) {
switch x := in.(type) {
case *ast.TableName:
if x.Schema.L == "" {
i.hasImplicit = true
}
return in, true
}
return in, false
}

func (i *implicitDatabase) Leave(in ast.Node) (out ast.Node, ok bool) {
return in, true
}

// detectSelectAgg detects an aggregate function or GROUP BY clause.
func (b *PlanBuilder) detectSelectAgg(sel *ast.SelectStmt) bool {
if sel.GroupBy != nil {
Expand Down
2 changes: 2 additions & 0 deletions planner/core/preprocess.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,10 @@ func (p *preprocessor) Enter(in ast.Node) (out ast.Node, skipChildren bool) {
EraseLastSemicolon(node.OriginSel)
EraseLastSemicolon(node.HintedSel)
p.checkBindGrammar(node)
return in, true
case *ast.DropBindingStmt:
EraseLastSemicolon(node.OriginSel)
return in, true
case *ast.RecoverTableStmt:
// The specified table in recover table statement maybe already been dropped.
// So skip check table name here, otherwise, recover table [table_name] syntax will return
Expand Down

0 comments on commit 52dae01

Please sign in to comment.