Skip to content

Commit

Permalink
planner: fix panic in extractSelectAndNormalizeDigest (pingcap#22333)
Browse files Browse the repository at this point in the history
  • Loading branch information
qw4990 authored Jan 11, 2021
1 parent 809a375 commit e234403
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 37 deletions.
11 changes: 11 additions & 0 deletions bindinfo/bind_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2038,3 +2038,14 @@ func (s *testSuite) TestCaptureWithZeroSlowLogThreshold(c *C) {
c.Assert(len(rows), Equals, 1)
c.Assert(rows[0][0], Equals, "select * from test . t")
}

func (s *testSuite) TestExplainTableStmts(c *C) {
tk := testkit.NewTestKit(c, s.store)
s.cleanBindingEnv(tk)
tk.MustExec("use test")
tk.MustExec("drop table if exists t")
tk.MustExec("create table t(id int, value decimal(5,2))")
tk.MustExec("table t")
tk.MustExec("explain table t")
tk.MustExec("desc table t")
}
8 changes: 8 additions & 0 deletions planner/core/preprocess.go
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,14 @@ func EraseLastSemicolon(stmt ast.StmtNode) {
}
}

// EraseLastSemicolonInSQL removes last semicolon of the SQL.
func EraseLastSemicolonInSQL(sql string) string {
if len(sql) > 0 && sql[len(sql)-1] == ';' {
return sql[:len(sql)-1]
}
return sql
}

const (
// TypeInvalid for unexpected types.
TypeInvalid byte = iota
Expand Down
61 changes: 24 additions & 37 deletions planner/optimize.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,10 @@ func Optimize(ctx context.Context, sctx sessionctx.Context, node ast.Node, is in
if !ok {
return bestPlan, names, nil
}
bindRecord, scope := getBindRecord(sctx, stmtNode)
bindRecord, scope, err := getBindRecord(sctx, stmtNode)
if err != nil {
return nil, nil, err
}
if bindRecord == nil {
return bestPlan, names, nil
}
Expand Down Expand Up @@ -274,7 +277,7 @@ func optimize(ctx context.Context, sctx sessionctx.Context, node ast.Node, is in
return finalPlan, names, cost, err
}

func extractSelectAndNormalizeDigest(stmtNode ast.StmtNode, specifiledDB string) (ast.StmtNode, string, string) {
func extractSelectAndNormalizeDigest(stmtNode ast.StmtNode, specifiledDB string) (ast.StmtNode, string, string, error) {
switch x := stmtNode.(type) {
case *ast.ExplainStmt:
// This function is only used to find bind record.
Expand All @@ -283,35 +286,19 @@ func extractSelectAndNormalizeDigest(stmtNode ast.StmtNode, specifiledDB string)
// The difference between them is whether len(x.Text()) is empty. They cannot be distinguished by stmt.restore.
// For these cases, we need return "" as normalize SQL and hash.
if len(x.Text()) == 0 {
return x.Stmt, "", ""
return x.Stmt, "", "", nil
}
switch x.Stmt.(type) {
case *ast.SelectStmt, *ast.DeleteStmt, *ast.UpdateStmt, *ast.InsertStmt:
plannercore.EraseLastSemicolon(x)
var normalizeExplainSQL string
var normalizeSQL string
if specifiledDB != "" {
normalizeExplainSQL = parser.Normalize(utilparser.RestoreWithDefaultDB(x, specifiledDB))
normalizeSQL = parser.Normalize(utilparser.RestoreWithDefaultDB(x.Stmt, specifiledDB))
} else {
normalizeExplainSQL = parser.Normalize(x.Text())
normalizeSQL = parser.Normalize(x.Text())
}
idx := int(0)
switch n := x.Stmt.(type) {
case *ast.SelectStmt:
idx = strings.Index(normalizeExplainSQL, "select")
case *ast.DeleteStmt:
idx = strings.Index(normalizeExplainSQL, "delete")
case *ast.UpdateStmt:
idx = strings.Index(normalizeExplainSQL, "update")
case *ast.InsertStmt:
if n.IsReplace {
idx = strings.Index(normalizeExplainSQL, "replace")
} else {
idx = strings.Index(normalizeExplainSQL, "insert")
}
}
normalizeSQL := normalizeExplainSQL[idx:]
normalizeSQL = plannercore.EraseLastSemicolonInSQL(normalizeSQL)
hash := parser.DigestNormalized(normalizeSQL)
return x.Stmt, normalizeSQL, hash
return x.Stmt, normalizeSQL, hash, nil
case *ast.SetOprStmt:
plannercore.EraseLastSemicolon(x)
var normalizeExplainSQL string
Expand All @@ -327,7 +314,7 @@ func extractSelectAndNormalizeDigest(stmtNode ast.StmtNode, specifiledDB string)
}
normalizeSQL := normalizeExplainSQL[idx:]
hash := parser.DigestNormalized(normalizeSQL)
return x.Stmt, normalizeSQL, hash
return x.Stmt, normalizeSQL, hash, nil
}
case *ast.SelectStmt, *ast.SetOprStmt, *ast.DeleteStmt, *ast.UpdateStmt, *ast.InsertStmt:
plannercore.EraseLastSemicolon(x)
Expand All @@ -337,42 +324,42 @@ func extractSelectAndNormalizeDigest(stmtNode ast.StmtNode, specifiledDB string)
// The difference between them is whether len(x.Text()) is empty. They cannot be distinguished by stmt.restore.
// For these cases, we need return "" as normalize SQL and hash.
if len(x.Text()) == 0 {
return x, "", ""
return x, "", "", nil
}
var normalizedSQL, hash string
if specifiledDB != "" {
normalizedSQL, hash = parser.NormalizeDigest(utilparser.RestoreWithDefaultDB(x, specifiledDB))
} else {
normalizedSQL, hash = parser.NormalizeDigest(x.Text())
}
return x, normalizedSQL, hash
return x, normalizedSQL, hash, nil
}
return nil, "", ""
return nil, "", "", nil
}

func getBindRecord(ctx sessionctx.Context, stmt ast.StmtNode) (*bindinfo.BindRecord, string) {
func getBindRecord(ctx sessionctx.Context, stmt ast.StmtNode) (*bindinfo.BindRecord, string, error) {
// When the domain is initializing, the bind will be nil.
if ctx.Value(bindinfo.SessionBindInfoKeyType) == nil {
return nil, ""
return nil, "", nil
}
stmtNode, normalizedSQL, hash := extractSelectAndNormalizeDigest(stmt, ctx.GetSessionVars().CurrentDB)
if stmtNode == nil {
return nil, ""
stmtNode, normalizedSQL, hash, err := extractSelectAndNormalizeDigest(stmt, ctx.GetSessionVars().CurrentDB)
if err != nil || stmtNode == nil {
return nil, "", err
}
sessionHandle := ctx.Value(bindinfo.SessionBindInfoKeyType).(*bindinfo.SessionHandle)
bindRecord := sessionHandle.GetBindRecord(normalizedSQL, "")
if bindRecord != nil {
if bindRecord.HasUsingBinding() {
return bindRecord, metrics.ScopeSession
return bindRecord, metrics.ScopeSession, nil
}
return nil, ""
return nil, "", nil
}
globalHandle := domain.GetDomain(ctx).BindHandle()
if globalHandle == nil {
return nil, ""
return nil, "", nil
}
bindRecord = globalHandle.GetBindRecord(hash, normalizedSQL, "")
return bindRecord, metrics.ScopeGlobal
return bindRecord, metrics.ScopeGlobal, nil
}

func handleInvalidBindRecord(ctx context.Context, sctx sessionctx.Context, level string, bindRecord bindinfo.BindRecord) {
Expand Down

0 comments on commit e234403

Please sign in to comment.