Skip to content

Commit

Permalink
*: adapt new api for the executor package (pingcap#22644)
Browse files Browse the repository at this point in the history
Signed-off-by: xhe <xw897002528@gmail.com>
  • Loading branch information
xhebox committed Mar 8, 2021
1 parent 31fd1f3 commit 697ca61
Show file tree
Hide file tree
Showing 21 changed files with 695 additions and 413 deletions.
2 changes: 1 addition & 1 deletion executor/analyze.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@ package executor
import (
"bytes"
"context"
"fmt"
"math"
"math/rand"
"runtime"
"sort"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
Expand Down
1 change: 1 addition & 0 deletions executor/brie.go
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,7 @@ func (gs *tidbGlueSession) CreateSession(store kv.Storage) (glue.Session, error)

// Execute implements glue.Session
func (gs *tidbGlueSession) Execute(ctx context.Context, sql string) error {
// FIXME: br relies on a deprecated API, it may be unsafe
_, err := gs.se.(sqlexec.SQLExecutor).Execute(ctx, sql)
return err
}
Expand Down
8 changes: 6 additions & 2 deletions executor/ddl.go
Original file line number Diff line number Diff line change
Expand Up @@ -309,8 +309,12 @@ func (e *DDLExec) dropTableObject(objects []*ast.TableName, obt objectType, ifEx
zap.String("database", fullti.Schema.O),
zap.String("table", fullti.Name.O),
)
sql := fmt.Sprintf("admin check table `%s`.`%s`", fullti.Schema.O, fullti.Name.O)
_, _, err = e.ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(sql)
exec := e.ctx.(sqlexec.RestrictedSQLExecutor)
stmt, err := exec.ParseWithParams(context.TODO(), "admin check table %n.%n", fullti.Schema.O, fullti.Name.O)
if err != nil {
return err
}
_, _, err = exec.ExecRestrictedStmt(context.TODO(), stmt)
if err != nil {
return err
}
Expand Down
342 changes: 147 additions & 195 deletions executor/grant.go

Large diffs are not rendered by default.

16 changes: 14 additions & 2 deletions executor/infoschema_reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,10 +154,16 @@ func (e *memtableRetriever) retrieve(ctx context.Context, sctx sessionctx.Contex
}

func getRowCountAllTable(ctx sessionctx.Context) (map[int64]uint64, error) {
rows, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL("select table_id, count from mysql.stats_meta")
exec := ctx.(sqlexec.RestrictedSQLExecutor)
stmt, err := exec.ParseWithParams(context.TODO(), "select table_id, count from mysql.stats_meta")
if err != nil {
return nil, err
}
rows, _, err := exec.ExecRestrictedStmt(context.TODO(), stmt)
if err != nil {
return nil, err
}

rowCountMap := make(map[int64]uint64, len(rows))
for _, row := range rows {
tableID := row.GetInt64(0)
Expand All @@ -173,10 +179,16 @@ type tableHistID struct {
}

func getColLengthAllTables(ctx sessionctx.Context) (map[tableHistID]uint64, error) {
rows, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL("select table_id, hist_id, tot_col_size from mysql.stats_histograms where is_index = 0")
exec := ctx.(sqlexec.RestrictedSQLExecutor)
stmt, err := exec.ParseWithParams(context.TODO(), "select table_id, hist_id, tot_col_size from mysql.stats_histograms where is_index = 0")
if err != nil {
return nil, err
}
rows, _, err := exec.ExecRestrictedStmt(context.TODO(), stmt)
if err != nil {
return nil, err
}

colLengthMap := make(map[tableHistID]uint64, len(rows))
for _, row := range rows {
tableID := row.GetInt64(0)
Expand Down
8 changes: 7 additions & 1 deletion executor/inspection_profile.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,13 @@ func (n *metricNode) getLabelValue(label string) *metricValue {
}

func (n *metricNode) queryRowsByLabel(pb *profileBuilder, query string, handleRowFn func(label string, v float64)) error {
rows, _, err := pb.sctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQLWithContext(context.Background(), query)
exec := pb.sctx.(sqlexec.RestrictedSQLExecutor)
stmt, err := exec.ParseWithParams(context.TODO(), query)
if err != nil {
return err
}

rows, _, err := pb.sctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedStmt(context.TODO(), stmt)
if err != nil {
return err
}
Expand Down
154 changes: 113 additions & 41 deletions executor/inspection_result.go

Large diffs are not rendered by default.

7 changes: 6 additions & 1 deletion executor/inspection_summary.go
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,12 @@ func (e *inspectionSummaryRetriever) retrieve(ctx context.Context, sctx sessionc
sql = fmt.Sprintf("select avg(value),min(value),max(value) from `%s`.`%s` %s",
util.MetricSchemaName.L, name, cond)
}
rows, _, err := sctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQLWithContext(ctx, sql)
exec := sctx.(sqlexec.RestrictedSQLExecutor)
stmt, err := exec.ParseWithParams(ctx, sql)
if err != nil {
return nil, errors.Errorf("execute '%s' failed: %v", sql, err)
}
rows, _, err := exec.ExecRestrictedStmt(ctx, stmt)
if err != nil {
return nil, errors.Errorf("execute '%s' failed: %v", sql, err)
}
Expand Down
16 changes: 13 additions & 3 deletions executor/metrics_reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ type MetricsSummaryRetriever struct {
retrieved bool
}

func (e *MetricsSummaryRetriever) retrieve(_ context.Context, sctx sessionctx.Context) ([][]types.Datum, error) {
func (e *MetricsSummaryRetriever) retrieve(ctx context.Context, sctx sessionctx.Context) ([][]types.Datum, error) {
if e.retrieved || e.extractor.SkipRequest {
return nil, nil
}
Expand Down Expand Up @@ -229,7 +229,12 @@ func (e *MetricsSummaryRetriever) retrieve(_ context.Context, sctx sessionctx.Co
name, util.MetricSchemaName.L, condition)
}

rows, _, err := sctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(sql)
exec := sctx.(sqlexec.RestrictedSQLExecutor)
stmt, err := exec.ParseWithParams(ctx, sql)
if err != nil {
return nil, errors.Errorf("execute '%s' failed: %v", sql, err)
}
rows, _, err := exec.ExecRestrictedStmt(ctx, stmt)
if err != nil {
return nil, errors.Errorf("execute '%s' failed: %v", sql, err)
}
Expand Down Expand Up @@ -306,7 +311,12 @@ func (e *MetricsSummaryByLabelRetriever) retrieve(ctx context.Context, sctx sess
sql = fmt.Sprintf("select sum(value),avg(value),min(value),max(value) from `%s`.`%s` %s",
util.MetricSchemaName.L, name, cond)
}
rows, _, err := sctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQLWithContext(ctx, sql)
exec := sctx.(sqlexec.RestrictedSQLExecutor)
stmt, err := exec.ParseWithParams(ctx, sql)
if err != nil {
return nil, errors.Errorf("execute '%s' failed: %v", sql, err)
}
rows, _, err := exec.ExecRestrictedStmt(ctx, stmt)
if err != nil {
return nil, errors.Errorf("execute '%s' failed: %v", sql, err)
}
Expand Down
8 changes: 6 additions & 2 deletions executor/opt_rule_blacklist.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,12 @@ func (e *ReloadOptRuleBlacklistExec) Next(ctx context.Context, _ *chunk.Chunk) e

// LoadOptRuleBlacklist loads the latest data from table mysql.opt_rule_blacklist.
func LoadOptRuleBlacklist(ctx sessionctx.Context) (err error) {
sql := "select HIGH_PRIORITY name from mysql.opt_rule_blacklist"
rows, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(sql)
exec := ctx.(sqlexec.RestrictedSQLExecutor)
stmt, err := exec.ParseWithParams(context.TODO(), "select HIGH_PRIORITY name from mysql.opt_rule_blacklist")
if err != nil {
return err
}
rows, _, err := exec.ExecRestrictedStmt(context.TODO(), stmt)
if err != nil {
return err
}
Expand Down
1 change: 1 addition & 0 deletions executor/prepared.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ func (e *PrepareExec) Next(ctx context.Context, req *chunk.Chunk) error {
err error
)
if sqlParser, ok := e.ctx.(sqlexec.SQLParser); ok {
// FIXME: ok... yet another parse API, may need some api interface clean.
stmts, err = sqlParser.ParseSQL(e.sqlText, charset, collation)
} else {
p := parser.New()
Expand Down
8 changes: 6 additions & 2 deletions executor/reload_expr_pushdown_blacklist.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,12 @@ func (e *ReloadExprPushdownBlacklistExec) Next(ctx context.Context, _ *chunk.Chu

// LoadExprPushdownBlacklist loads the latest data from table mysql.expr_pushdown_blacklist.
func LoadExprPushdownBlacklist(ctx sessionctx.Context) (err error) {
sql := "select HIGH_PRIORITY name, store_type from mysql.expr_pushdown_blacklist"
rows, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(sql)
exec := ctx.(sqlexec.RestrictedSQLExecutor)
stmt, err := exec.ParseWithParams(context.TODO(), "select HIGH_PRIORITY name, store_type from mysql.expr_pushdown_blacklist")
if err != nil {
return err
}
rows, _, err := exec.ExecRestrictedStmt(context.TODO(), stmt)
if err != nil {
return err
}
Expand Down
103 changes: 87 additions & 16 deletions executor/revoke.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ package executor

import (
"context"
"fmt"
"strings"

"github.com/pingcap/errors"
"github.com/pingcap/parser/ast"
Expand Down Expand Up @@ -73,15 +73,15 @@ func (e *RevokeExec) Next(ctx context.Context, req *chunk.Chunk) error {
}
defer func() {
if !isCommit {
_, err := internalSession.(sqlexec.SQLExecutor).Execute(context.Background(), "rollback")
_, err := internalSession.(sqlexec.SQLExecutor).ExecuteInternal(context.Background(), "rollback")
if err != nil {
logutil.BgLogger().Error("rollback error occur at grant privilege", zap.Error(err))
}
}
e.releaseSysSession(internalSession)
}()

_, err = internalSession.(sqlexec.SQLExecutor).Execute(context.Background(), "begin")
_, err = internalSession.(sqlexec.SQLExecutor).ExecuteInternal(context.Background(), "begin")
if err != nil {
return err
}
Expand All @@ -103,7 +103,7 @@ func (e *RevokeExec) Next(ctx context.Context, req *chunk.Chunk) error {
}
}

_, err = internalSession.(sqlexec.SQLExecutor).Execute(context.Background(), "commit")
_, err = internalSession.(sqlexec.SQLExecutor).ExecuteInternal(context.Background(), "commit")
if err != nil {
return err
}
Expand Down Expand Up @@ -166,12 +166,15 @@ func (e *RevokeExec) revokePriv(internalSession sessionctx.Context, priv *ast.Pr
}

func (e *RevokeExec) revokeGlobalPriv(internalSession sessionctx.Context, priv *ast.PrivElem, user, host string) error {
asgns, err := composeGlobalPrivUpdate(priv.Priv, "N")
sql := new(strings.Builder)
sqlexec.MustFormatSQL(sql, "UPDATE %n.%n SET ", mysql.SystemDB, mysql.UserTable)
err := composeGlobalPrivUpdate(sql, priv.Priv, "N")
if err != nil {
return err
}
sql := fmt.Sprintf(`UPDATE %s.%s SET %s WHERE User='%s' AND Host='%s'`, mysql.SystemDB, mysql.UserTable, asgns, user, host)
_, err = internalSession.(sqlexec.SQLExecutor).Execute(context.Background(), sql)
sqlexec.MustFormatSQL(sql, " WHERE User=%? AND Host=%?", user, host)

_, err = internalSession.(sqlexec.SQLExecutor).ExecuteInternal(context.Background(), sql.String())
return err
}

Expand All @@ -180,12 +183,16 @@ func (e *RevokeExec) revokeDBPriv(internalSession sessionctx.Context, priv *ast.
if len(dbName) == 0 {
dbName = e.ctx.GetSessionVars().CurrentDB
}
asgns, err := composeDBPrivUpdate(priv.Priv, "N")

sql := new(strings.Builder)
sqlexec.MustFormatSQL(sql, "UPDATE %n.%n SET ", mysql.SystemDB, mysql.DBTable)
err := composeDBPrivUpdate(sql, priv.Priv, "N")
if err != nil {
return err
}
sql := fmt.Sprintf(`UPDATE %s.%s SET %s WHERE User='%s' AND Host='%s' AND DB='%s';`, mysql.SystemDB, mysql.DBTable, asgns, userName, host, dbName)
_, err = internalSession.(sqlexec.SQLExecutor).Execute(context.Background(), sql)
sqlexec.MustFormatSQL(sql, " WHERE User=%? AND Host=%? AND DB=%?", userName, host, dbName)

_, err = internalSession.(sqlexec.SQLExecutor).ExecuteInternal(context.Background(), sql.String())
return err
}

Expand All @@ -194,12 +201,16 @@ func (e *RevokeExec) revokeTablePriv(internalSession sessionctx.Context, priv *a
if err != nil {
return err
}
asgns, err := composeTablePrivUpdateForRevoke(internalSession, priv.Priv, user, host, dbName, tbl.Meta().Name.O)

sql := new(strings.Builder)
sqlexec.MustFormatSQL(sql, "UPDATE %n.%n SET ", mysql.SystemDB, mysql.TablePrivTable)
err = composeTablePrivUpdateForRevoke(internalSession, sql, priv.Priv, user, host, dbName, tbl.Meta().Name.O)
if err != nil {
return err
}
sql := fmt.Sprintf(`UPDATE %s.%s SET %s WHERE User='%s' AND Host='%s' AND DB='%s' AND Table_name='%s';`, mysql.SystemDB, mysql.TablePrivTable, asgns, user, host, dbName, tbl.Meta().Name.O)
_, err = internalSession.(sqlexec.SQLExecutor).Execute(context.Background(), sql)
sqlexec.MustFormatSQL(sql, " WHERE User=%? AND Host=%? AND DB=%? AND Table_name=%?", user, host, dbName, tbl.Meta().Name.O)

_, err = internalSession.(sqlexec.SQLExecutor).ExecuteInternal(context.Background(), sql.String())
return err
}

Expand All @@ -208,20 +219,80 @@ func (e *RevokeExec) revokeColumnPriv(internalSession sessionctx.Context, priv *
if err != nil {
return err
}
sql := new(strings.Builder)
for _, c := range priv.Cols {
col := table.FindCol(tbl.Cols(), c.Name.L)
if col == nil {
return errors.Errorf("Unknown column: %s", c)
}
asgns, err := composeColumnPrivUpdateForRevoke(internalSession, priv.Priv, user, host, dbName, tbl.Meta().Name.O, col.Name.O)

sql.Reset()
sqlexec.MustFormatSQL(sql, "UPDATE %n.%n SET ", mysql.SystemDB, mysql.ColumnPrivTable)
err = composeColumnPrivUpdateForRevoke(internalSession, sql, priv.Priv, user, host, dbName, tbl.Meta().Name.O, col.Name.O)
if err != nil {
return err
}
sqlexec.MustFormatSQL(sql, " WHERE User=%? AND Host=%? AND DB=%? AND Table_name=%? AND Column_name=%?", user, host, dbName, tbl.Meta().Name.O, col.Name.O)

_, err = internalSession.(sqlexec.SQLExecutor).ExecuteInternal(context.Background(), sql.String())
if err != nil {
return err
}
}
return nil
}

func privUpdateForRevoke(cur []string, priv mysql.PrivilegeType) ([]string, error) {
p, ok := mysql.Priv2SetStr[priv]
if !ok {
return nil, errors.Errorf("Unknown priv: %v", priv)
}
cur = deleteFromSet(cur, p)
return cur, nil
}

func composeTablePrivUpdateForRevoke(ctx sessionctx.Context, sql *strings.Builder, priv mysql.PrivilegeType, name string, host string, db string, tbl string) error {
var newTablePriv, newColumnPriv []string

if priv != mysql.AllPriv {
currTablePriv, currColumnPriv, err := getTablePriv(ctx, name, host, db, tbl)
if err != nil {
return err
}

newTablePriv = setFromString(currTablePriv)
newTablePriv, err = privUpdateForRevoke(newTablePriv, priv)
if err != nil {
return err
}

newColumnPriv = setFromString(currColumnPriv)
newColumnPriv, err = privUpdateForRevoke(newColumnPriv, priv)
if err != nil {
return err
}
}

sqlexec.MustFormatSQL(sql, `Table_priv=%?, Column_priv=%?, Grantor=%?`, strings.Join(newTablePriv, ","), strings.Join(newColumnPriv, ","), ctx.GetSessionVars().User.String())
return nil
}

func composeColumnPrivUpdateForRevoke(ctx sessionctx.Context, sql *strings.Builder, priv mysql.PrivilegeType, name string, host string, db string, tbl string, col string) error {
var newColumnPriv []string

if priv != mysql.AllPriv {
currColumnPriv, err := getColumnPriv(ctx, name, host, db, tbl, col)
if err != nil {
return err
}
sql := fmt.Sprintf(`UPDATE %s.%s SET %s WHERE User='%s' AND Host='%s' AND DB='%s' AND Table_name='%s' AND Column_name='%s';`, mysql.SystemDB, mysql.ColumnPrivTable, asgns, user, host, dbName, tbl.Meta().Name.O, col.Name.O)
_, err = internalSession.(sqlexec.SQLExecutor).Execute(context.Background(), sql)

newColumnPriv = setFromString(currColumnPriv)
newColumnPriv, err = privUpdateForRevoke(newColumnPriv, priv)
if err != nil {
return err
}
}

sqlexec.MustFormatSQL(sql, `Column_priv=%?`, strings.Join(newColumnPriv, ","))
return nil
}
Loading

0 comments on commit 697ca61

Please sign in to comment.