Skip to content

Commit

Permalink
*: refactor the RestrictedSQLExecutor interface (pingcap#22579) (ping…
Browse files Browse the repository at this point in the history
…cap#22687)

Signed-off-by: ti-srebot <ti-srebot@pingcap.com>
Co-authored-by: tiancaiamao <tiancaiamao@gmail.com>
  • Loading branch information
ti-srebot and tiancaiamao authored Feb 22, 2021
1 parent 29947c3 commit 2e791d4
Show file tree
Hide file tree
Showing 8 changed files with 169 additions and 54 deletions.
9 changes: 6 additions & 3 deletions domain/domain.go
Original file line number Diff line number Diff line change
Expand Up @@ -1306,9 +1306,12 @@ func (do *Domain) NotifyUpdatePrivilege(ctx sessionctx.Context) {
}
}
// update locally
_, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(`FLUSH PRIVILEGES`)
if err != nil {
logutil.BgLogger().Error("unable to update privileges", zap.Error(err))
exec := ctx.(sqlexec.RestrictedSQLExecutor)
if stmt, err := exec.ParseWithParams(context.Background(), `FLUSH PRIVILEGES`); err == nil {
_, _, err := exec.ExecRestrictedStmt(context.Background(), stmt)
if err != nil {
logutil.BgLogger().Error("unable to update privileges", zap.Error(err))
}
}
}

Expand Down
2 changes: 1 addition & 1 deletion server/sql_info_fetcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ func (sh *sqlInfoFetcher) zipInfoForSQL(w http.ResponseWriter, r *http.Request)
timeoutString := r.FormValue("timeout")
curDB := strings.ToLower(r.FormValue("current_db"))
if curDB != "" {
_, err = sh.s.ExecuteInternal(context.Background(), "use %n", curDB)
_, err = sh.s.ExecuteInternal(reqCtx, "use %n", curDB)
if err != nil {
serveError(w, http.StatusInternalServerError, fmt.Sprintf("use database %v failed, err: %v", curDB, err))
return
Expand Down
103 changes: 86 additions & 17 deletions session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,16 +107,6 @@ type Session interface {
ExecuteStmt(context.Context, ast.StmtNode) (sqlexec.RecordSet, error)
// Parse is deprecated, use ParseWithParams() instead.
Parse(ctx context.Context, sql string) ([]ast.StmtNode, error)
// ParseWithParams is the parameterized version of Parse: it will try to prevent injection under utf8mb4.
// It works like printf() in c, there are following format specifiers:
// 1. %?: automatic conversion by the type of arguments. E.g. []string -> ('s1','s2'..)
// 2. %%: output %
// 3. %n: for identifiers, for example ("use %n", db)
//
// Attention: it does not prevent you from doing parse("select '%?", ";SQL injection!;") => "select '';SQL injection!;'".
// One argument should be a standalone entity. It should not "concat" with other placeholders and characters.
// This function only saves you from processing potentially unsafe parameters.
ParseWithParams(ctx context.Context, sql string, args ...interface{}) ([]ast.StmtNode, error)
// ExecuteInternal is a helper around ParseWithParams() and ExecuteStmt(). It is not allowed to execute multiple statements.
ExecuteInternal(context.Context, string, ...interface{}) (sqlexec.RecordSet, error)
String() string // String is used to debug.
Expand Down Expand Up @@ -1148,15 +1138,12 @@ func (s *session) ExecuteInternal(ctx context.Context, sql string, args ...inter
logutil.Eventf(ctx, "execute: %s", sql)
}

stmtNodes, err := s.ParseWithParams(ctx, sql, args...)
stmtNode, err := s.ParseWithParams(ctx, sql, args...)
if err != nil {
return nil, err
}
if len(stmtNodes) != 1 {
return nil, errors.New("Executing multiple statements internally is not supported")
}

rs, err = s.ExecuteStmt(ctx, stmtNodes[0])
rs, err = s.ExecuteStmt(ctx, stmtNode)
if err != nil {
s.sessionVars.StmtCtx.AppendError(err)
}
Expand Down Expand Up @@ -1228,7 +1215,7 @@ func (s *session) Parse(ctx context.Context, sql string) ([]ast.StmtNode, error)
}

// ParseWithParams parses a query string, with arguments, to raw ast.StmtNode.
func (s *session) ParseWithParams(ctx context.Context, sql string, args ...interface{}) ([]ast.StmtNode, error) {
func (s *session) ParseWithParams(ctx context.Context, sql string, args ...interface{}) (ast.StmtNode, error) {
sql, err := EscapeSQL(sql, args...)
if err != nil {
return nil, err
Expand All @@ -1250,6 +1237,9 @@ func (s *session) ParseWithParams(ctx context.Context, sql string, args ...inter
parseStartTime = time.Now()
stmts, warns, err = s.ParseSQL(ctx, sql, charsetInfo, collation)
}
if len(stmts) != 1 {
err = errors.New("run multiple statements internally is not supported")
}
if err != nil {
s.rollbackOnError(ctx)
// Only print log message when this SQL is from the user.
Expand All @@ -1272,7 +1262,86 @@ func (s *session) ParseWithParams(ctx context.Context, sql string, args ...inter
for _, warn := range warns {
s.sessionVars.StmtCtx.AppendWarning(util.SyntaxWarn(warn))
}
return stmts, nil
return stmts[0], nil
}

// ExecRestrictedStmt implements RestrictedSQLExecutor interface.
func (s *session) ExecRestrictedStmt(ctx context.Context, stmtNode ast.StmtNode, opts ...sqlexec.OptionFuncAlias) (
[]chunk.Row, []*ast.ResultField, error) {
var execOption sqlexec.ExecOption
for _, opt := range opts {
opt(&execOption)
}
// Use special session to execute the sql.
tmp, err := s.sysSessionPool().Get()
if err != nil {
return nil, nil, err
}
defer s.sysSessionPool().Put(tmp)
se := tmp.(*session)

startTime := time.Now()
// The special session will share the `InspectionTableCache` with current session
// if the current session in inspection mode.
if cache := s.sessionVars.InspectionTableCache; cache != nil {
se.sessionVars.InspectionTableCache = cache
defer func() { se.sessionVars.InspectionTableCache = nil }()
}
if ok := s.sessionVars.OptimizerUseInvisibleIndexes; ok {
se.sessionVars.OptimizerUseInvisibleIndexes = true
defer func() { se.sessionVars.OptimizerUseInvisibleIndexes = false }()
}
prePruneMode := se.sessionVars.PartitionPruneMode.Load()
defer func() {
if !execOption.IgnoreWarning {
if se != nil && se.GetSessionVars().StmtCtx.WarningCount() > 0 {
warnings := se.GetSessionVars().StmtCtx.GetWarnings()
s.GetSessionVars().StmtCtx.AppendWarnings(warnings)
}
}
se.sessionVars.PartitionPruneMode.Store(prePruneMode)
}()

if execOption.SnapshotTS != 0 {
se.sessionVars.SnapshotInfoschema, err = domain.GetDomain(s).GetSnapshotInfoSchema(execOption.SnapshotTS)
if err != nil {
return nil, nil, err
}
if err := se.sessionVars.SetSystemVar(variable.TiDBSnapshot, strconv.FormatUint(execOption.SnapshotTS, 10)); err != nil {
return nil, nil, err
}
defer func() {
if err := se.sessionVars.SetSystemVar(variable.TiDBSnapshot, ""); err != nil {
logutil.BgLogger().Error("set tidbSnapshot error", zap.Error(err))
}
se.sessionVars.SnapshotInfoschema = nil
}()
}

// for analyze stmt we need let worker session follow user session that executing stmt.
se.sessionVars.PartitionPruneMode.Store(s.sessionVars.PartitionPruneMode.Load())
metrics.SessionRestrictedSQLCounter.Inc()

ctx = context.WithValue(ctx, execdetails.StmtExecDetailKey, &execdetails.StmtExecDetails{})
rs, err := se.ExecuteStmt(ctx, stmtNode)
if err != nil {
se.sessionVars.StmtCtx.AppendError(err)
}
if rs == nil {
return nil, nil, err
}
defer func() {
if closeErr := rs.Close(); closeErr != nil {
err = closeErr
}
}()
var rows []chunk.Row
rows, err = drainRecordSet(ctx, se, rs)
if err != nil {
return nil, nil, err
}
metrics.QueryDurationHistogram.WithLabelValues(metrics.LblInternal).Observe(time.Since(startTime).Seconds())
return rows, rs.Fields(), err
}

func (s *session) ExecuteStmt(ctx context.Context, stmtNode ast.StmtNode) (sqlexec.RecordSet, error) {
Expand Down
12 changes: 6 additions & 6 deletions session/session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3835,33 +3835,33 @@ func (s *testSessionSerialSuite) TestDefaultWeekFormat(c *C) {
func (s *testSessionSerialSuite) TestParseWithParams(c *C) {
tk := testkit.NewTestKitWithInit(c, s.store)
se := tk.Se
exec := se.(sqlexec.RestrictedSQLExecutor)

// test compatibility with ExcuteInternal
origin := se.GetSessionVars().InRestrictedSQL
se.GetSessionVars().InRestrictedSQL = true
defer func() {
se.GetSessionVars().InRestrictedSQL = origin
}()
_, err := se.ParseWithParams(context.Background(), "SELECT 4")
_, err := exec.ParseWithParams(context.Background(), "SELECT 4")
c.Assert(err, IsNil)

// test charset attack
stmts, err := se.ParseWithParams(context.Background(), "SELECT * FROM test WHERE name = %? LIMIT 1", "\xbf\x27 OR 1=1 /*")
stmts, err := exec.ParseWithParams(context.Background(), "SELECT * FROM test WHERE name = %? LIMIT 1", "\xbf\x27 OR 1=1 /*")
c.Assert(err, IsNil)
c.Assert(stmts, HasLen, 1)

var sb strings.Builder
ctx := format.NewRestoreCtx(0, &sb)
err = stmts[0].Restore(ctx)
err = stmts.Restore(ctx)
c.Assert(err, IsNil)
// FIXME: well... so the restore function is vulnerable...
c.Assert(sb.String(), Equals, "SELECT * FROM test WHERE name=_utf8mb4\xbf' OR 1=1 /* LIMIT 1")

// test invalid sql
_, err = se.ParseWithParams(context.Background(), "SELECT")
_, err = exec.ParseWithParams(context.Background(), "SELECT")
c.Assert(err, ErrorMatches, ".*You have an error in your SQL syntax.*")

// test invalid arguments to escape
_, err = se.ParseWithParams(context.Background(), "SELECT %?")
_, err = exec.ParseWithParams(context.Background(), "SELECT %?")
c.Assert(err, ErrorMatches, "missing arguments.*")
}
26 changes: 12 additions & 14 deletions store/tikv/gcworker/gc_worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ func (w *GCWorker) prepare() (bool, uint64, error) {
ctx := context.Background()
se := createSession(w.store)
defer se.Close()
_, err := se.Execute(ctx, "BEGIN")
_, err := se.ExecuteInternal(ctx, "BEGIN")
if err != nil {
return false, 0, errors.Trace(err)
}
Expand Down Expand Up @@ -1622,7 +1622,7 @@ func (w *GCWorker) checkLeader() (bool, error) {
defer se.Close()

ctx := context.Background()
_, err := se.Execute(ctx, "BEGIN")
_, err := se.ExecuteInternal(ctx, "BEGIN")
if err != nil {
return false, errors.Trace(err)
}
Expand All @@ -1647,7 +1647,7 @@ func (w *GCWorker) checkLeader() (bool, error) {

se.RollbackTxn(ctx)

_, err = se.Execute(ctx, "BEGIN")
_, err = se.ExecuteInternal(ctx, "BEGIN")
if err != nil {
return false, errors.Trace(err)
}
Expand Down Expand Up @@ -1755,16 +1755,13 @@ func (w *GCWorker) loadValueFromSysTable(key string) (string, error) {
ctx := context.Background()
se := createSession(w.store)
defer se.Close()
stmt := fmt.Sprintf(`SELECT HIGH_PRIORITY (variable_value) FROM mysql.tidb WHERE variable_name='%s' FOR UPDATE`, key)
rs, err := se.Execute(ctx, stmt)
if len(rs) > 0 {
defer terror.Call(rs[0].Close)
}
rs, err := se.ExecuteInternal(ctx, `SELECT HIGH_PRIORITY (variable_value) FROM mysql.tidb WHERE variable_name=%? FOR UPDATE`, key)
if err != nil {
return "", errors.Trace(err)
}
req := rs[0].NewChunk()
err = rs[0].Next(ctx, req)
defer terror.Call(rs.Close)
req := rs.NewChunk()
err = rs.Next(ctx, req)
if err != nil {
return "", errors.Trace(err)
}
Expand All @@ -1781,13 +1778,14 @@ func (w *GCWorker) loadValueFromSysTable(key string) (string, error) {
}

func (w *GCWorker) saveValueToSysTable(key, value string) error {
stmt := fmt.Sprintf(`INSERT HIGH_PRIORITY INTO mysql.tidb VALUES ('%[1]s', '%[2]s', '%[3]s')
const stmt = `INSERT HIGH_PRIORITY INTO mysql.tidb VALUES (%?, %?, %?)
ON DUPLICATE KEY
UPDATE variable_value = '%[2]s', comment = '%[3]s'`,
key, value, gcVariableComments[key])
UPDATE variable_value = %?, comment = %?`
se := createSession(w.store)
defer se.Close()
_, err := se.Execute(context.Background(), stmt)
_, err := se.ExecuteInternal(context.Background(), stmt,
key, value, gcVariableComments[key],
value, gcVariableComments[key])
logutil.BgLogger().Debug("[gc worker] save kv",
zap.String("key", key),
zap.String("value", value),
Expand Down
36 changes: 28 additions & 8 deletions util/admin/admin.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@ package admin
import (
"context"
"encoding/json"
"fmt"
"math"
"sort"
"time"

"github.com/pingcap/errors"
"github.com/pingcap/parser/ast"
"github.com/pingcap/parser/model"
"github.com/pingcap/parser/mysql"
"github.com/pingcap/tidb/errno"
Expand Down Expand Up @@ -290,13 +290,13 @@ type RecordData struct {
Values []types.Datum
}

func getCount(ctx sessionctx.Context, sql string) (int64, error) {
rows, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQLWithSnapshot(sql)
func getCount(exec sqlexec.RestrictedSQLExecutor, stmt ast.StmtNode, snapshot uint64) (int64, error) {
rows, _, err := exec.ExecRestrictedStmt(context.Background(), stmt, sqlexec.ExecOptionWithSnapshot(snapshot))
if err != nil {
return 0, errors.Trace(err)
}
if len(rows) != 1 {
return 0, errors.Errorf("can not get count, sql %s result rows %d", sql, len(rows))
return 0, errors.Errorf("can not get count, rows count = %d", len(rows))
}
return rows[0].GetInt64(0), nil
}
Expand All @@ -317,14 +317,34 @@ func CheckIndicesCount(ctx sessionctx.Context, dbName, tableName string, indices
// Here we need check all indexes, includes invisible index
ctx.GetSessionVars().OptimizerUseInvisibleIndexes = true
// Add `` for some names like `table name`.
sql := fmt.Sprintf("SELECT COUNT(*) FROM `%s`.`%s` USE INDEX()", dbName, tableName)
tblCnt, err := getCount(ctx, sql)
exec := ctx.(sqlexec.RestrictedSQLExecutor)
stmt, err := exec.ParseWithParams(context.Background(), "SELECT COUNT(*) FROM %n.%n USE INDEX()", dbName, tableName)
if err != nil {
return 0, 0, errors.Trace(err)
}

var snapshot uint64
txn, err := ctx.Txn(false)
if err != nil {
return 0, 0, err
}
if txn.Valid() {
snapshot = txn.StartTS()
}
if ctx.GetSessionVars().SnapshotTS != 0 {
snapshot = ctx.GetSessionVars().SnapshotTS
}

tblCnt, err := getCount(exec, stmt, snapshot)
if err != nil {
return 0, 0, errors.Trace(err)
}
for i, idx := range indices {
sql = fmt.Sprintf("SELECT COUNT(*) FROM `%s`.`%s` USE INDEX(`%s`)", dbName, tableName, idx)
idxCnt, err := getCount(ctx, sql)
stmt, err := exec.ParseWithParams(context.Background(), "SELECT COUNT(*) FROM %n.%n USE INDEX(%n)", dbName, tableName, idx)
if err != nil {
return 0, i, errors.Trace(err)
}
idxCnt, err := getCount(exec, stmt, snapshot)
if err != nil {
return 0, i, errors.Trace(err)
}
Expand Down
19 changes: 14 additions & 5 deletions util/gcutil/gcutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
package gcutil

import (
"context"
"fmt"

"github.com/pingcap/errors"
Expand All @@ -25,16 +26,20 @@ import (
)

const (
selectVariableValueSQL = `SELECT HIGH_PRIORITY variable_value FROM mysql.tidb WHERE variable_name='%s'`
selectVariableValueSQL = `SELECT HIGH_PRIORITY variable_value FROM mysql.tidb WHERE variable_name=%?`
insertVariableValueSQL = `INSERT HIGH_PRIORITY INTO mysql.tidb VALUES ('%[1]s', '%[2]s', '%[3]s')
ON DUPLICATE KEY
UPDATE variable_value = '%[2]s', comment = '%[3]s'`
)

// CheckGCEnable is use to check whether GC is enable.
func CheckGCEnable(ctx sessionctx.Context) (enable bool, err error) {
sql := fmt.Sprintf(selectVariableValueSQL, "tikv_gc_enable")
rows, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(sql)
exec := ctx.(sqlexec.RestrictedSQLExecutor)
stmt, err := exec.ParseWithParams(context.Background(), selectVariableValueSQL, "tikv_gc_enable")
if err != nil {
return false, errors.Trace(err)
}
rows, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedStmt(context.Background(), stmt)
if err != nil {
return false, errors.Trace(err)
}
Expand Down Expand Up @@ -80,8 +85,12 @@ func ValidateSnapshotWithGCSafePoint(snapshotTS, safePointTS uint64) error {

// GetGCSafePoint loads GC safe point time from mysql.tidb.
func GetGCSafePoint(ctx sessionctx.Context) (uint64, error) {
sql := fmt.Sprintf(selectVariableValueSQL, "tikv_gc_safe_point")
rows, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(sql)
exec := ctx.(sqlexec.RestrictedSQLExecutor)
stmt, err := exec.ParseWithParams(context.Background(), selectVariableValueSQL, "tikv_gc_safe_point")
if err != nil {
return 0, errors.Trace(err)
}
rows, _, err := exec.ExecRestrictedStmt(context.Background(), stmt)
if err != nil {
return 0, errors.Trace(err)
}
Expand Down
Loading

0 comments on commit 2e791d4

Please sign in to comment.