diff --git a/bindinfo/bind.go b/bindinfo/bind.go new file mode 100644 index 0000000000000..9d70aadf198a4 --- /dev/null +++ b/bindinfo/bind.go @@ -0,0 +1,169 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package bindinfo + +import "github.com/pingcap/parser/ast" + +// BindHint will add hints for originStmt according to hintedStmt' hints. +func BindHint(originStmt, hintedStmt ast.StmtNode) ast.StmtNode { + switch x := originStmt.(type) { + case *ast.SelectStmt: + return selectBind(x, hintedStmt.(*ast.SelectStmt)) + default: + return originStmt + } +} + +func selectBind(originalNode, hintedNode *ast.SelectStmt) *ast.SelectStmt { + if hintedNode.TableHints != nil { + originalNode.TableHints = hintedNode.TableHints + } + if originalNode.From != nil { + originalNode.From.TableRefs = resultSetNodeBind(originalNode.From.TableRefs, hintedNode.From.TableRefs).(*ast.Join) + } + if originalNode.Where != nil { + originalNode.Where = exprBind(originalNode.Where, hintedNode.Where).(ast.ExprNode) + } + + if originalNode.Having != nil { + originalNode.Having.Expr = exprBind(originalNode.Having.Expr, hintedNode.Having.Expr) + } + + if originalNode.OrderBy != nil { + originalNode.OrderBy = orderByBind(originalNode.OrderBy, hintedNode.OrderBy) + } + + if originalNode.Fields != nil { + origFields := originalNode.Fields.Fields + hintFields := hintedNode.Fields.Fields + for idx := range origFields { + origFields[idx].Expr = exprBind(origFields[idx].Expr, hintFields[idx].Expr) + } + } + return originalNode +} + +func orderByBind(originalNode, hintedNode *ast.OrderByClause) *ast.OrderByClause { + for idx := 0; idx < len(originalNode.Items); idx++ { + originalNode.Items[idx].Expr = exprBind(originalNode.Items[idx].Expr, hintedNode.Items[idx].Expr) + } + return originalNode +} + +func exprBind(originalNode, hintedNode ast.ExprNode) ast.ExprNode { + switch v := originalNode.(type) { + case *ast.SubqueryExpr: + if v.Query != nil { + v.Query = resultSetNodeBind(v.Query, hintedNode.(*ast.SubqueryExpr).Query) + } + case *ast.ExistsSubqueryExpr: + if v.Sel != nil { + v.Sel.(*ast.SubqueryExpr).Query = resultSetNodeBind(v.Sel.(*ast.SubqueryExpr).Query, hintedNode.(*ast.ExistsSubqueryExpr).Sel.(*ast.SubqueryExpr).Query) + } + case *ast.PatternInExpr: + if v.Sel != nil { + v.Sel.(*ast.SubqueryExpr).Query = resultSetNodeBind(v.Sel.(*ast.SubqueryExpr).Query, hintedNode.(*ast.PatternInExpr).Sel.(*ast.SubqueryExpr).Query) + } + case *ast.BinaryOperationExpr: + if v.L != nil { + v.L = exprBind(v.L, hintedNode.(*ast.BinaryOperationExpr).L) + } + if v.R != nil { + v.R = exprBind(v.R, hintedNode.(*ast.BinaryOperationExpr).R) + } + case *ast.IsNullExpr: + if v.Expr != nil { + v.Expr = exprBind(v.Expr, hintedNode.(*ast.IsNullExpr).Expr) + } + case *ast.IsTruthExpr: + if v.Expr != nil { + v.Expr = exprBind(v.Expr, hintedNode.(*ast.IsTruthExpr).Expr) + } + case *ast.PatternLikeExpr: + if v.Pattern != nil { + v.Pattern = exprBind(v.Pattern, hintedNode.(*ast.PatternLikeExpr).Pattern) + } + case *ast.CompareSubqueryExpr: + if v.L != nil { + v.L = exprBind(v.L, hintedNode.(*ast.CompareSubqueryExpr).L) + } + if v.R != nil { + v.R = exprBind(v.R, hintedNode.(*ast.CompareSubqueryExpr).R) + } + case *ast.BetweenExpr: + if v.Left != nil { + v.Left = exprBind(v.Left, hintedNode.(*ast.BetweenExpr).Left) + } + if v.Right != nil { + v.Right = exprBind(v.Right, hintedNode.(*ast.BetweenExpr).Right) + } + case *ast.UnaryOperationExpr: + if v.V != nil { + v.V = exprBind(v.V, hintedNode.(*ast.UnaryOperationExpr).V) + } + case *ast.CaseExpr: + if v.Value != nil { + v.Value = exprBind(v.Value, hintedNode.(*ast.CaseExpr).Value) + } + if v.ElseClause != nil { + v.ElseClause = exprBind(v.ElseClause, hintedNode.(*ast.CaseExpr).ElseClause) + } + } + return originalNode +} + +func resultSetNodeBind(originalNode, hintedNode ast.ResultSetNode) ast.ResultSetNode { + switch x := originalNode.(type) { + case *ast.Join: + return joinBind(x, hintedNode.(*ast.Join)) + case *ast.TableSource: + ts, _ := hintedNode.(*ast.TableSource) + switch v := x.Source.(type) { + case *ast.SelectStmt: + x.Source = selectBind(v, ts.Source.(*ast.SelectStmt)) + case *ast.UnionStmt: + x.Source = unionSelectBind(v, hintedNode.(*ast.TableSource).Source.(*ast.UnionStmt)) + case *ast.TableName: + x.Source.(*ast.TableName).IndexHints = ts.Source.(*ast.TableName).IndexHints + } + return x + case *ast.SelectStmt: + return selectBind(x, hintedNode.(*ast.SelectStmt)) + case *ast.UnionStmt: + return unionSelectBind(x, hintedNode.(*ast.UnionStmt)) + default: + return x + } +} + +func joinBind(originalNode, hintedNode *ast.Join) *ast.Join { + if originalNode.Left != nil { + originalNode.Left = resultSetNodeBind(originalNode.Left, hintedNode.Left) + } + + if hintedNode.Right != nil { + originalNode.Right = resultSetNodeBind(originalNode.Right, hintedNode.Right) + } + + return originalNode +} + +func unionSelectBind(originalNode, hintedNode *ast.UnionStmt) ast.ResultSetNode { + selects := originalNode.SelectList.Selects + for i := len(selects) - 1; i >= 0; i-- { + originalNode.SelectList.Selects[i] = selectBind(selects[i], hintedNode.SelectList.Selects[i]) + } + + return originalNode +} diff --git a/bindinfo/bind_test.go b/bindinfo/bind_test.go index 9b764b9c6f12f..e37ae0f45eff3 100644 --- a/bindinfo/bind_test.go +++ b/bindinfo/bind_test.go @@ -280,3 +280,147 @@ func (s *testSuite) TestSessionBinding(c *C) { c.Check(bindData.OriginalSQL, Equals, "select * from t where i > ?") c.Check(bindData.Status, Equals, "deleted") } + +func (s *testSuite) TestGlobalAndSessionBindingBothExist(c *C) { + tk := testkit.NewTestKit(c, s.store) + s.cleanBindingEnv(tk) + tk.MustExec("use test") + tk.MustExec("drop table if exists t1") + tk.MustExec("drop table if exists t2") + tk.MustExec("create table t1(id int)") + tk.MustExec("create table t2(id int)") + + tk.MustQuery("explain SELECT * from t1,t2 where t1.id = t2.id").Check(testkit.Rows( + "HashLeftJoin_8 12487.50 root inner join, inner:TableReader_15, equal:[eq(test.t1.id, test.t2.id)]", + "├─TableReader_12 9990.00 root data:Selection_11", + "│ └─Selection_11 9990.00 cop not(isnull(test.t1.id))", + "│ └─TableScan_10 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo", + "└─TableReader_15 9990.00 root data:Selection_14", + " └─Selection_14 9990.00 cop not(isnull(test.t2.id))", + " └─TableScan_13 10000.00 cop table:t2, range:[-inf,+inf], keep order:false, stats:pseudo", + )) + + tk.MustQuery("explain SELECT /*+ TIDB_SMJ(t1, t2) */ * from t1,t2 where t1.id = t2.id").Check(testkit.Rows( + "MergeJoin_7 12487.50 root inner join, left key:test.t1.id, right key:test.t2.id", + "├─Sort_11 9990.00 root test.t1.id:asc", + "│ └─TableReader_10 9990.00 root data:Selection_9", + "│ └─Selection_9 9990.00 cop not(isnull(test.t1.id))", + "│ └─TableScan_8 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo", + "└─Sort_15 9990.00 root test.t2.id:asc", + " └─TableReader_14 9990.00 root data:Selection_13", + " └─Selection_13 9990.00 cop not(isnull(test.t2.id))", + " └─TableScan_12 10000.00 cop table:t2, range:[-inf,+inf], keep order:false, stats:pseudo", + )) + + tk.MustExec("create global binding for SELECT * from t1,t2 where t1.id = t2.id using SELECT /*+ TIDB_SMJ(t1, t2) */ * from t1,t2 where t1.id = t2.id") + + tk.MustQuery("explain SELECT * from t1,t2 where t1.id = t2.id").Check(testkit.Rows( + "MergeJoin_7 12487.50 root inner join, left key:test.t1.id, right key:test.t2.id", + "├─Sort_11 9990.00 root test.t1.id:asc", + "│ └─TableReader_10 9990.00 root data:Selection_9", + "│ └─Selection_9 9990.00 cop not(isnull(test.t1.id))", + "│ └─TableScan_8 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo", + "└─Sort_15 9990.00 root test.t2.id:asc", + " └─TableReader_14 9990.00 root data:Selection_13", + " └─Selection_13 9990.00 cop not(isnull(test.t2.id))", + " └─TableScan_12 10000.00 cop table:t2, range:[-inf,+inf], keep order:false, stats:pseudo", + )) + + tk.MustExec("drop global binding for SELECT * from t1,t2 where t1.id = t2.id") + + tk.MustQuery("explain SELECT * from t1,t2 where t1.id = t2.id").Check(testkit.Rows( + "HashLeftJoin_8 12487.50 root inner join, inner:TableReader_15, equal:[eq(test.t1.id, test.t2.id)]", + "├─TableReader_12 9990.00 root data:Selection_11", + "│ └─Selection_11 9990.00 cop not(isnull(test.t1.id))", + "│ └─TableScan_10 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo", + "└─TableReader_15 9990.00 root data:Selection_14", + " └─Selection_14 9990.00 cop not(isnull(test.t2.id))", + " └─TableScan_13 10000.00 cop table:t2, range:[-inf,+inf], keep order:false, stats:pseudo", + )) +} + +func (s *testSuite) TestExplain(c *C) { + tk := testkit.NewTestKit(c, s.store) + s.cleanBindingEnv(tk) + tk.MustExec("use test") + tk.MustExec("drop table if exists t1") + tk.MustExec("drop table if exists t2") + tk.MustExec("create table t1(id int)") + tk.MustExec("create table t2(id int)") + + tk.MustQuery("explain SELECT * from t1,t2 where t1.id = t2.id").Check(testkit.Rows( + "HashLeftJoin_8 12487.50 root inner join, inner:TableReader_15, equal:[eq(test.t1.id, test.t2.id)]", + "├─TableReader_12 9990.00 root data:Selection_11", + "│ └─Selection_11 9990.00 cop not(isnull(test.t1.id))", + "│ └─TableScan_10 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo", + "└─TableReader_15 9990.00 root data:Selection_14", + " └─Selection_14 9990.00 cop not(isnull(test.t2.id))", + " └─TableScan_13 10000.00 cop table:t2, range:[-inf,+inf], keep order:false, stats:pseudo", + )) + + tk.MustQuery("explain SELECT /*+ TIDB_SMJ(t1, t2) */ * from t1,t2 where t1.id = t2.id").Check(testkit.Rows( + "MergeJoin_7 12487.50 root inner join, left key:test.t1.id, right key:test.t2.id", + "├─Sort_11 9990.00 root test.t1.id:asc", + "│ └─TableReader_10 9990.00 root data:Selection_9", + "│ └─Selection_9 9990.00 cop not(isnull(test.t1.id))", + "│ └─TableScan_8 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo", + "└─Sort_15 9990.00 root test.t2.id:asc", + " └─TableReader_14 9990.00 root data:Selection_13", + " └─Selection_13 9990.00 cop not(isnull(test.t2.id))", + " └─TableScan_12 10000.00 cop table:t2, range:[-inf,+inf], keep order:false, stats:pseudo", + )) + + tk.MustExec("create global binding for SELECT * from t1,t2 where t1.id = t2.id using SELECT /*+ TIDB_SMJ(t1, t2) */ * from t1,t2 where t1.id = t2.id") + + tk.MustQuery("explain SELECT * from t1,t2 where t1.id = t2.id").Check(testkit.Rows( + "MergeJoin_7 12487.50 root inner join, left key:test.t1.id, right key:test.t2.id", + "├─Sort_11 9990.00 root test.t1.id:asc", + "│ └─TableReader_10 9990.00 root data:Selection_9", + "│ └─Selection_9 9990.00 cop not(isnull(test.t1.id))", + "│ └─TableScan_8 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo", + "└─Sort_15 9990.00 root test.t2.id:asc", + " └─TableReader_14 9990.00 root data:Selection_13", + " └─Selection_13 9990.00 cop not(isnull(test.t2.id))", + " └─TableScan_12 10000.00 cop table:t2, range:[-inf,+inf], keep order:false, stats:pseudo", + )) + + tk.MustExec("drop global binding for SELECT * from t1,t2 where t1.id = t2.id") +} + +func (s *testSuite) TestErrorBind(c *C) { + tk := testkit.NewTestKit(c, s.store) + s.cleanBindingEnv(tk) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("drop table if exists t1") + tk.MustExec("create table t(i int, s varchar(20))") + tk.MustExec("create table t1(i int, s varchar(20))") + tk.MustExec("create index index_t on t(i,s)") + + _, err := tk.Exec("create global binding for select * from t where i>100 using select * from t use index(index_t) where i>100") + c.Assert(err, IsNil, Commentf("err %v", err)) + + bindData := s.domain.BindHandle().GetBindRecord("select * from t where i > ?", "test") + c.Check(bindData, NotNil) + c.Check(bindData.OriginalSQL, Equals, "select * from t where i > ?") + c.Check(bindData.BindSQL, Equals, "select * from t use index(index_t) where i>100") + c.Check(bindData.Db, Equals, "test") + c.Check(bindData.Status, Equals, "using") + c.Check(bindData.Charset, NotNil) + c.Check(bindData.Collation, NotNil) + c.Check(bindData.CreateTime, NotNil) + c.Check(bindData.UpdateTime, NotNil) + + tk.MustExec("drop index index_t on t") + _, err = tk.Exec("select * from t where i > 10") + c.Check(err, IsNil) + + s.domain.BindHandle().DropInvalidBindRecord() + + rs, err := tk.Exec("show global bindings") + c.Assert(err, IsNil) + chk := rs.NewRecordBatch() + err = rs.Next(context.TODO(), chk) + c.Check(err, IsNil) + c.Check(chk.NumRows(), Equals, 0) +} diff --git a/bindinfo/cache.go b/bindinfo/cache.go index b8ef2583c0f69..a4c2785eb9c64 100644 --- a/bindinfo/cache.go +++ b/bindinfo/cache.go @@ -20,16 +20,18 @@ import ( ) const ( - // using is the bind info's in use status. - using = "using" + // Using is the bind info's in use status. + Using = "using" // deleted is the bind info's deleted status. deleted = "deleted" + // Invalid is the bind info's invalid status. + Invalid = "invalid" ) // BindMeta stores the basic bind info and bindSql astNode. type BindMeta struct { *BindRecord - ast ast.StmtNode //ast will be used to do query sql bind check + Ast ast.StmtNode //ast will be used to do query sql bind check } // cache is a k-v map, key is original sql, value is a slice of BindMeta. diff --git a/bindinfo/handle.go b/bindinfo/handle.go index f0ba865a779eb..fc33f2cd3442c 100644 --- a/bindinfo/handle.go +++ b/bindinfo/handle.go @@ -20,6 +20,7 @@ import ( "go.uber.org/zap" "sync" "sync/atomic" + "time" "github.com/pingcap/parser" "github.com/pingcap/parser/mysql" @@ -60,15 +61,28 @@ type BindHandle struct { atomic.Value } + // invalidBindRecordMap indicates the invalid bind records found during querying. + // A record will be deleted from this map, after 2 bind-lease, after it is dropped from the kv. + invalidBindRecordMap struct { + sync.Mutex + atomic.Value + } + parser *parser.Parser lastUpdateTime types.Time } +type invalidBindRecordMap struct { + bindRecord *BindRecord + droppedTime time.Time +} + // NewBindHandle creates a new BindHandle. func NewBindHandle(ctx sessionctx.Context, parser *parser.Parser) *BindHandle { handle := &BindHandle{parser: parser} handle.sctx.Context = ctx handle.bindInfo.Value.Store(make(cache, 32)) + handle.invalidBindRecordMap.Value.Store(make(map[string]*invalidBindRecordMap)) return handle } @@ -106,7 +120,7 @@ func (h *BindHandle) Update(fullLoad bool) (err error) { } newCache.removeStaleBindMetas(hash, meta) - if meta.Status == using { + if meta.Status == Using { newCache[hash] = append(newCache[hash], meta) } } @@ -163,7 +177,7 @@ func (h *BindHandle) AddBindRecord(record *BindRecord) (err error) { Fsp: 3, } record.UpdateTime = record.CreateTime - record.Status = using + record.Status = Using record.BindSQL = h.getEscapeCharacter(record.BindSQL) // insert the BindRecord to the storage. @@ -217,6 +231,44 @@ func (h *BindHandle) DropBindRecord(record *BindRecord) (err error) { return err } +// DropInvalidBindRecord execute the drop bindRecord task. +func (h *BindHandle) DropInvalidBindRecord() { + invalidBindRecordMap := copyInvalidBindRecordMap(h.invalidBindRecordMap.Load().(map[string]*invalidBindRecordMap)) + for key, invalidBindRecord := range invalidBindRecordMap { + if invalidBindRecord.droppedTime.IsZero() { + err := h.DropBindRecord(invalidBindRecord.bindRecord) + if err != nil { + logutil.Logger(context.Background()).Error("DropInvalidBindRecord failed", zap.Error(err)) + } + invalidBindRecord.droppedTime = time.Now() + continue + } + + if time.Since(invalidBindRecord.droppedTime) > 6*time.Second { + delete(invalidBindRecordMap, key) + } + } + h.invalidBindRecordMap.Store(invalidBindRecordMap) +} + +// AddDropInvalidBindTask add bindRecord to invalidBindRecordMap when the bindRecord need to be deleted. +func (h *BindHandle) AddDropInvalidBindTask(invalidBindRecord *BindRecord) { + key := invalidBindRecord.OriginalSQL + ":" + invalidBindRecord.Db + if _, ok := h.invalidBindRecordMap.Value.Load().(map[string]*invalidBindRecordMap)[key]; ok { + return + } + h.invalidBindRecordMap.Lock() + defer h.invalidBindRecordMap.Unlock() + if _, ok := h.invalidBindRecordMap.Value.Load().(map[string]*invalidBindRecordMap)[key]; ok { + return + } + newMap := copyInvalidBindRecordMap(h.invalidBindRecordMap.Value.Load().(map[string]*invalidBindRecordMap)) + newMap[key] = &invalidBindRecordMap{ + bindRecord: invalidBindRecord, + } + h.invalidBindRecordMap.Store(newMap) +} + // Size return the size of bind info cache. func (h *BindHandle) Size() int { size := 0 @@ -246,7 +298,7 @@ func (h *BindHandle) newBindMeta(record *BindRecord) (hash string, meta *BindMet if err != nil { return "", nil, err } - meta = &BindMeta{BindRecord: record, ast: stmtNodes[0]} + meta = &BindMeta{BindRecord: record, Ast: stmtNodes[0]} return hash, meta, nil } @@ -328,6 +380,14 @@ func (c cache) copy() cache { return newCache } +func copyInvalidBindRecordMap(oldMap map[string]*invalidBindRecordMap) map[string]*invalidBindRecordMap { + newMap := make(map[string]*invalidBindRecordMap, len(oldMap)) + for k, v := range oldMap { + newMap[k] = v + } + return newMap +} + func (c cache) getBindRecord(normdOrigSQL, db string) *BindMeta { hash := parser.DigestHash(normdOrigSQL) bindRecords := c[hash] diff --git a/bindinfo/session_handle.go b/bindinfo/session_handle.go index 90b4d8ac3c457..f343b3ca8e24d 100644 --- a/bindinfo/session_handle.go +++ b/bindinfo/session_handle.go @@ -48,13 +48,12 @@ func (h *SessionHandle) newBindMeta(record *BindRecord) (hash string, meta *Bind if err != nil { return "", nil, err } - meta = &BindMeta{BindRecord: record, ast: stmtNodes[0]} + meta = &BindMeta{BindRecord: record, Ast: stmtNodes[0]} return hash, meta, nil } // AddBindRecord new a BindRecord with BindMeta, add it to the cache. func (h *SessionHandle) AddBindRecord(record *BindRecord) error { - record.Status = using record.CreateTime = types.Time{ Time: types.FromGoTime(time.Now()), Type: mysql.TypeDatetime, diff --git a/domain/domain.go b/domain/domain.go index 45b8a1d5cec5a..008348cee7eff 100644 --- a/domain/domain.go +++ b/domain/domain.go @@ -791,6 +791,12 @@ func (do *Domain) LoadBindInfoLoop(ctx sessionctx.Context, parser *parser.Parser return err } + do.loadBindInfoLoop() + do.handleInvalidBindTaskLoop() + return nil +} + +func (do *Domain) loadBindInfoLoop() { duration := 3 * time.Second do.wg.Add(1) go func() { @@ -802,13 +808,29 @@ func (do *Domain) LoadBindInfoLoop(ctx sessionctx.Context, parser *parser.Parser return case <-time.After(duration): } - err = do.bindHandle.Update(false) + err := do.bindHandle.Update(false) if err != nil { logutil.Logger(context.Background()).Error("update bindinfo failed", zap.Error(err)) } } }() - return nil +} + +func (do *Domain) handleInvalidBindTaskLoop() { + handleInvalidTaskDuration := 3 * time.Second + do.wg.Add(1) + go func() { + defer do.wg.Done() + defer recoverInDomain("loadBindInfoLoop-dropInvalidBindInfo", false) + for { + select { + case <-do.exit: + return + case <-time.After(handleInvalidTaskDuration): + } + do.bindHandle.DropInvalidBindRecord() + } + }() } // StatsHandle returns the statistic handle. diff --git a/executor/bind.go b/executor/bind.go index 2e3ad0f8aa326..d2c2034851bde 100644 --- a/executor/bind.go +++ b/executor/bind.go @@ -76,6 +76,7 @@ func (e *SQLBindExec) createSQLBind() error { Db: e.ctx.GetSessionVars().CurrentDB, Charset: e.charset, Collation: e.collation, + Status: bindinfo.Using, } if !e.isGlobal { handle := e.ctx.Value(bindinfo.SessionBindInfoKeyType).(*bindinfo.SessionHandle) diff --git a/executor/compiler.go b/executor/compiler.go index 23b2239475ca9..f233bdfe833cd 100644 --- a/executor/compiler.go +++ b/executor/compiler.go @@ -16,10 +16,14 @@ package executor import ( "context" "fmt" + "strings" "github.com/opentracing/opentracing-go" + "github.com/pingcap/parser" "github.com/pingcap/parser/ast" + "github.com/pingcap/tidb/bindinfo" "github.com/pingcap/tidb/config" + "github.com/pingcap/tidb/domain" "github.com/pingcap/tidb/infoschema" "github.com/pingcap/tidb/metrics" "github.com/pingcap/tidb/planner" @@ -49,11 +53,24 @@ type Compiler struct { // Compile compiles an ast.StmtNode to a physical plan. func (c *Compiler) Compile(ctx context.Context, stmtNode ast.StmtNode) (*ExecStmt, error) { + return c.compile(ctx, stmtNode, false) +} + +// SkipBindCompile compiles an ast.StmtNode to a physical plan without SQL bind. +func (c *Compiler) SkipBindCompile(ctx context.Context, node ast.StmtNode) (*ExecStmt, error) { + return c.compile(ctx, node, true) +} + +func (c *Compiler) compile(ctx context.Context, stmtNode ast.StmtNode, skipBind bool) (*ExecStmt, error) { if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { span1 := span.Tracer().StartSpan("executor.Compile", opentracing.ChildOf(span.Context())) defer span1.Finish() } + if !skipBind { + stmtNode = addHint(c.Ctx, stmtNode) + } + infoSchema := GetInfoSchema(c.Ctx) if err := plannercore.Preprocess(c.Ctx, stmtNode, infoSchema); err != nil { return nil, err @@ -367,3 +384,46 @@ func GetInfoSchema(ctx sessionctx.Context) infoschema.InfoSchema { } return is } + +func addHint(ctx sessionctx.Context, stmtNode ast.StmtNode) ast.StmtNode { + if ctx.Value(bindinfo.SessionBindInfoKeyType) == nil { //when the domain is initializing, the bind will be nil. + return stmtNode + } + switch x := stmtNode.(type) { + case *ast.ExplainStmt: + switch x.Stmt.(type) { + case *ast.SelectStmt: + normalizeExplainSQL := parser.Normalize(x.Text()) + idx := strings.Index(normalizeExplainSQL, "select") + normalizeSQL := normalizeExplainSQL[idx:] + x.Stmt = addHintForSelect(normalizeSQL, ctx, x.Stmt) + } + return x + case *ast.SelectStmt: + return addHintForSelect(parser.Normalize(x.Text()), ctx, x) + default: + return stmtNode + } +} + +func addHintForSelect(normdOrigSQL string, ctx sessionctx.Context, stmt ast.StmtNode) ast.StmtNode { + sessionHandle := ctx.Value(bindinfo.SessionBindInfoKeyType).(*bindinfo.SessionHandle) + bindRecord := sessionHandle.GetBindRecord(normdOrigSQL, ctx.GetSessionVars().CurrentDB) + if bindRecord != nil { + if bindRecord.Status == bindinfo.Invalid { + return stmt + } + if bindRecord.Status == bindinfo.Using { + return bindinfo.BindHint(stmt, bindRecord.Ast) + } + } + globalHandle := domain.GetDomain(ctx).BindHandle() + bindRecord = globalHandle.GetBindRecord(normdOrigSQL, ctx.GetSessionVars().CurrentDB) + if bindRecord == nil { + bindRecord = globalHandle.GetBindRecord(normdOrigSQL, "") + } + if bindRecord != nil { + return bindinfo.BindHint(stmt, bindRecord.Ast) + } + return stmt +} diff --git a/session/session.go b/session/session.go index 5d297a7476367..1f504b9dd4ddd 100644 --- a/session/session.go +++ b/session/session.go @@ -991,8 +991,9 @@ func (s *session) execute(ctx context.Context, sql string) (recordSets []sqlexec sessionExecuteParseDurationGeneral.Observe(time.Since(startTS).Seconds()) } + var tempStmtNodes []ast.StmtNode compiler := executor.Compiler{Ctx: s} - for _, stmtNode := range stmtNodes { + for idx, stmtNode := range stmtNodes { s.PrepareTxnCtx(ctx) // Step2: Transform abstract syntax tree to a physical plan(stored in executor.ExecStmt). @@ -1003,11 +1004,22 @@ func (s *session) execute(ctx context.Context, sql string) (recordSets []sqlexec } stmt, err := compiler.Compile(ctx, stmtNode) if err != nil { - s.rollbackOnError(ctx) - logutil.Logger(ctx).Warn("compile sql error", - zap.Error(err), - zap.String("sql", sql)) - return nil, err + if tempStmtNodes == nil { + tempStmtNodes, warns, err = s.ParseSQL(ctx, sql, charsetInfo, collation) + if err != nil || warns != nil { + //just skip errcheck, because parse will not return an error. + } + } + stmtNode = tempStmtNodes[idx] + stmt, err = compiler.SkipBindCompile(ctx, stmtNode) + if err != nil { + s.rollbackOnError(ctx) + logutil.Logger(ctx).Warn("compile sql error", + zap.Error(err), + zap.String("sql", sql)) + return nil, err + } + s.handleInvalidBindRecord(ctx, stmtNode) } if isInternal { sessionExecuteCompileDurationInternal.Observe(time.Since(startTS).Seconds()) @@ -1033,6 +1045,59 @@ func (s *session) execute(ctx context.Context, sql string) (recordSets []sqlexec return recordSets, nil } +func (s *session) handleInvalidBindRecord(ctx context.Context, stmtNode ast.StmtNode) { + var normdOrigSQL string + switch x := stmtNode.(type) { + case *ast.ExplainStmt: + switch x.Stmt.(type) { + case *ast.SelectStmt: + normalizeExplainSQL := parser.Normalize(x.Text()) + idx := strings.Index(normalizeExplainSQL, "select") + normdOrigSQL = normalizeExplainSQL[idx:] + default: + return + } + case *ast.SelectStmt: + normdOrigSQL = parser.Normalize(x.Text()) + default: + return + } + sessionHandle := s.Value(bindinfo.SessionBindInfoKeyType).(*bindinfo.SessionHandle) + bindMeta := sessionHandle.GetBindRecord(normdOrigSQL, s.GetSessionVars().CurrentDB) + if bindMeta != nil { + bindMeta.Status = bindinfo.Invalid + return + } + + globalHandle := domain.GetDomain(s).BindHandle() + bindMeta = globalHandle.GetBindRecord(normdOrigSQL, s.GetSessionVars().CurrentDB) + if bindMeta == nil { + bindMeta = globalHandle.GetBindRecord(normdOrigSQL, "") + } + if bindMeta != nil { + record := &bindinfo.BindRecord{ + OriginalSQL: bindMeta.OriginalSQL, + BindSQL: bindMeta.BindSQL, + Db: s.GetSessionVars().CurrentDB, + Charset: bindMeta.Charset, + Collation: bindMeta.Collation, + Status: bindinfo.Invalid, + } + + err := sessionHandle.AddBindRecord(record) + if err != nil { + logutil.Logger(ctx).Warn("handleInvalidBindRecord failed", zap.Error(err)) + } + + globalHandle := domain.GetDomain(s).BindHandle() + dropBindRecord := &bindinfo.BindRecord{ + OriginalSQL: bindMeta.OriginalSQL, + Db: bindMeta.Db, + } + globalHandle.AddDropInvalidBindTask(dropBindRecord) + } +} + // rollbackOnError makes sure the next statement starts a new transaction with the latest InfoSchema. func (s *session) rollbackOnError(ctx context.Context) { if !s.sessionVars.InTxn() {