diff --git a/bindinfo/bind_test.go b/bindinfo/bind_test.go index c93c3116f70ef..7d05bf55c788a 100644 --- a/bindinfo/bind_test.go +++ b/bindinfo/bind_test.go @@ -520,3 +520,24 @@ func (s *testSuite) TestBindingCache(c *C) { res := tk.MustQuery("show global bindings") c.Assert(len(res.Rows()), Equals, 2) } + +func (s *testSuite) TestDefaultDB(c *C) { + tk := testkit.NewTestKit(c, s.store) + s.cleanBindingEnv(tk) + tk.MustExec("use test") + tk.MustExec("create table t(a int, b int, index idx(a))") + tk.MustExec("create global binding for select * from test.t using select * from test.t use index(idx)") + tk.MustExec("use mysql") + tk.MustQuery("select * from test.t") + // Even in another database, we could still use the bindings. + c.Assert(tk.Se.GetSessionVars().StmtCtx.IndexNames[0], Equals, "t:idx") + tk.MustExec("drop global binding for select * from test.t") + tk.MustQuery("show global bindings").Check(testkit.Rows()) + + tk.MustExec("use test") + tk.MustExec("create session binding for select * from test.t using select * from test.t use index(idx)") + tk.MustExec("use mysql") + tk.MustQuery("select * from test.t") + // Even in another database, we could still use the bindings. + c.Assert(tk.Se.GetSessionVars().StmtCtx.IndexNames[0], Equals, "t:idx") +} diff --git a/executor/bind.go b/executor/bind.go index 6849e4eec9b76..4b87bbd7de9c4 100644 --- a/executor/bind.go +++ b/executor/bind.go @@ -34,6 +34,7 @@ type SQLBindExec struct { bindSQL string charset string collation string + db string isGlobal bool bindAst ast.StmtNode } @@ -59,7 +60,7 @@ func (e *SQLBindExec) Next(ctx context.Context, req *chunk.Chunk) error { func (e *SQLBindExec) dropSQLBind() error { record := &bindinfo.BindRecord{ OriginalSQL: e.normdOrigSQL, - Db: e.ctx.GetSessionVars().CurrentDB, + Db: e.db, } if !e.isGlobal { handle := e.ctx.Value(bindinfo.SessionBindInfoKeyType).(*bindinfo.SessionHandle) @@ -73,7 +74,7 @@ func (e *SQLBindExec) createSQLBind() error { record := &bindinfo.BindRecord{ OriginalSQL: e.normdOrigSQL, BindSQL: e.bindSQL, - Db: e.ctx.GetSessionVars().CurrentDB, + Db: e.db, Charset: e.charset, Collation: e.collation, Status: bindinfo.Using, diff --git a/executor/builder.go b/executor/builder.go index ed78f5b734c7e..be57b771b302d 100644 --- a/executor/builder.go +++ b/executor/builder.go @@ -2295,6 +2295,7 @@ func (b *executorBuilder) buildSQLBindExec(v *plannercore.SQLBindPlan) Executor bindSQL: v.BindSQL, charset: v.Charset, collation: v.Collation, + db: v.Db, isGlobal: v.IsGlobal, bindAst: v.BindStmt, } diff --git a/executor/compiler.go b/executor/compiler.go index db8f211dd8f4d..aefebb1485d47 100644 --- a/executor/compiler.go +++ b/executor/compiler.go @@ -405,6 +405,9 @@ func addHint(ctx sessionctx.Context, stmtNode ast.StmtNode) ast.StmtNode { func addHintForSelect(hash, normdOrigSQL string, ctx sessionctx.Context, stmt ast.StmtNode) ast.StmtNode { sessionHandle := ctx.Value(bindinfo.SessionBindInfoKeyType).(*bindinfo.SessionHandle) bindRecord := sessionHandle.GetBindRecord(normdOrigSQL, ctx.GetSessionVars().CurrentDB) + if bindRecord == nil { + bindRecord = sessionHandle.GetBindRecord(normdOrigSQL, "") + } if bindRecord != nil { if bindRecord.Status == bindinfo.Invalid { return stmt diff --git a/planner/core/common_plans.go b/planner/core/common_plans.go index f1b4cfbe8d452..a9afde62b3c39 100644 --- a/planner/core/common_plans.go +++ b/planner/core/common_plans.go @@ -392,6 +392,7 @@ type SQLBindPlan struct { BindSQL string IsGlobal bool BindStmt ast.StmtNode + Db string Charset string Collation string } diff --git a/planner/core/planbuilder.go b/planner/core/planbuilder.go index a843022fdf557..bde5ea183e54e 100644 --- a/planner/core/planbuilder.go +++ b/planner/core/planbuilder.go @@ -388,6 +388,7 @@ func (b *PlanBuilder) buildDropBindPlan(v *ast.DropBindingStmt) (Plan, error) { SQLBindOp: OpSQLBindDrop, NormdOrigSQL: parser.Normalize(v.OriginSel.Text()), IsGlobal: v.GlobalScope, + Db: getDefaultDB(b.ctx, v.OriginSel), } b.visitInfo = appendVisitInfo(b.visitInfo, mysql.SuperPriv, "", "", "", nil) return p, nil @@ -401,6 +402,7 @@ func (b *PlanBuilder) buildCreateBindPlan(v *ast.CreateBindingStmt) (Plan, error BindSQL: v.HintedSel.Text(), IsGlobal: v.GlobalScope, BindStmt: v.HintedSel, + Db: getDefaultDB(b.ctx, v.OriginSel), Charset: charSet, Collation: collation, } @@ -408,6 +410,34 @@ func (b *PlanBuilder) buildCreateBindPlan(v *ast.CreateBindingStmt) (Plan, error return p, nil } +func getDefaultDB(ctx sessionctx.Context, sel ast.StmtNode) string { + implicitDB := &implicitDatabase{} + sel.Accept(implicitDB) + if implicitDB.hasImplicit { + return ctx.GetSessionVars().CurrentDB + } + return "" +} + +type implicitDatabase struct { + hasImplicit bool +} + +func (i *implicitDatabase) Enter(in ast.Node) (out ast.Node, skipChildren bool) { + switch x := in.(type) { + case *ast.TableName: + if x.Schema.L == "" { + i.hasImplicit = true + } + return in, true + } + return in, false +} + +func (i *implicitDatabase) Leave(in ast.Node) (out ast.Node, ok bool) { + return in, true +} + // detectSelectAgg detects an aggregate function or GROUP BY clause. func (b *PlanBuilder) detectSelectAgg(sel *ast.SelectStmt) bool { if sel.GroupBy != nil { diff --git a/planner/core/preprocess.go b/planner/core/preprocess.go index 7ec9a501253ec..4041a9dfb28f4 100644 --- a/planner/core/preprocess.go +++ b/planner/core/preprocess.go @@ -119,8 +119,10 @@ func (p *preprocessor) Enter(in ast.Node) (out ast.Node, skipChildren bool) { EraseLastSemicolon(node.OriginSel) EraseLastSemicolon(node.HintedSel) p.checkBindGrammar(node) + return in, true case *ast.DropBindingStmt: EraseLastSemicolon(node.OriginSel) + return in, true case *ast.RecoverTableStmt: // The specified table in recover table statement maybe already been dropped. // So skip check table name here, otherwise, recover table [table_name] syntax will return