Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

privilege: fix atomic problem of GRANT and REVOKE (#14219) #15092

Merged
merged 2 commits into from
Mar 4, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
178 changes: 133 additions & 45 deletions executor/grant.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@ import (
"github.com/pingcap/tidb/table"
"github.com/pingcap/tidb/util"
"github.com/pingcap/tidb/util/chunk"
"github.com/pingcap/tidb/util/logutil"
"github.com/pingcap/tidb/util/sqlexec"
"go.uber.org/zap"
)

/***
Expand Down Expand Up @@ -90,9 +92,35 @@ func (e *GrantExec) Next(ctx context.Context, req *chunk.Chunk) error {
}
}

// Grant for each user
for idx, user := range e.Users {
// Check if user exists.
// Commit the old transaction, like DDL.
if err := e.ctx.NewTxn(ctx); err != nil {
return err
}
defer func() { e.ctx.GetSessionVars().SetStatusFlag(mysql.ServerStatusInTrans, false) }()

// Create internal session to start internal transaction.
isCommit := false
internalSession, err := e.getSysSession()
if err != nil {
return err
}
defer func() {
if !isCommit {
_, err := internalSession.(sqlexec.SQLExecutor).Execute(context.Background(), "rollback")
if err != nil {
logutil.Logger(context.Background()).Error("rollback error occur at grant privilege", zap.Error(err))
}
}
e.releaseSysSession(internalSession)
}()

_, err = internalSession.(sqlexec.SQLExecutor).Execute(context.Background(), "begin")
if err != nil {
return err
}

// Check which user is not exist.
for _, user := range e.Users {
exists, err := userExists(e.ctx, user.User.Username, user.User.Hostname)
if err != nil {
return err
Expand All @@ -106,31 +134,34 @@ func (e *GrantExec) Next(ctx context.Context, req *chunk.Chunk) error {
}
user := fmt.Sprintf(`('%s', '%s', '%s')`, user.User.Hostname, user.User.Username, pwd)
sql := fmt.Sprintf(`INSERT INTO %s.%s (Host, User, Password) VALUES %s;`, mysql.SystemDB, mysql.UserTable, user)
_, err := e.ctx.(sqlexec.SQLExecutor).Execute(ctx, sql)
_, err := internalSession.(sqlexec.SQLExecutor).Execute(ctx, sql)
if err != nil {
return err
}
}
}

// Grant for each user
for _, user := range e.Users {
// If there is no privilege entry in corresponding table, insert a new one.
// Global scope: mysql.global_priv
// DB scope: mysql.DB
// Table scope: mysql.Tables_priv
// Column scope: mysql.Columns_priv
if e.TLSOptions != nil {
err = checkAndInitGlobalPriv(e.ctx, user.User.Username, user.User.Hostname)
err = checkAndInitGlobalPriv(internalSession, user.User.Username, user.User.Hostname)
if err != nil {
return err
}
}
switch e.Level.Level {
case ast.GrantLevelDB:
err := checkAndInitDBPriv(e.ctx, dbName, e.is, user.User.Username, user.User.Hostname)
err := checkAndInitDBPriv(internalSession, dbName, e.is, user.User.Username, user.User.Hostname)
if err != nil {
return err
}
case ast.GrantLevelTable:
err := checkAndInitTablePriv(e.ctx, dbName, e.Level.TableName, e.is, user.User.Username, user.User.Hostname)
err := checkAndInitTablePriv(internalSession, dbName, e.Level.TableName, e.is, user.User.Username, user.User.Hostname)
if err != nil {
return err
}
Expand All @@ -140,15 +171,8 @@ func (e *GrantExec) Next(ctx context.Context, req *chunk.Chunk) error {
privs = append(privs, &ast.PrivElem{Priv: mysql.GrantPriv})
}

if idx == 0 {
// Commit the old transaction, like DDL.
if err := e.ctx.NewTxn(ctx); err != nil {
return err
}
defer func() { e.ctx.GetSessionVars().SetStatusFlag(mysql.ServerStatusInTrans, false) }()
}
// Grant global priv to user.
err = e.grantGlobalPriv(user)
err = e.grantGlobalPriv(internalSession, user)
if err != nil {
return err
}
Expand All @@ -157,17 +181,23 @@ func (e *GrantExec) Next(ctx context.Context, req *chunk.Chunk) error {
if len(priv.Cols) > 0 {
// Check column scope privilege entry.
// TODO: Check validity before insert new entry.
err := e.checkAndInitColumnPriv(user.User.Username, user.User.Hostname, priv.Cols)
err := e.checkAndInitColumnPriv(user.User.Username, user.User.Hostname, priv.Cols, internalSession)
if err != nil {
return err
}
}
err := e.grantLevelPriv(priv, user)
err := e.grantLevelPriv(priv, user, internalSession)
if err != nil {
return err
}
}
}

_, err = internalSession.(sqlexec.SQLExecutor).Execute(context.Background(), "commit")
if err != nil {
return err
}
isCommit = true
domain.GetDomain(e.ctx).NotifyUpdatePrivilege(e.ctx)
return nil
}
Expand Down Expand Up @@ -216,7 +246,7 @@ func checkAndInitTablePriv(ctx sessionctx.Context, dbName, tblName string, is in

// checkAndInitColumnPriv checks if column scope privilege entry exists in mysql.Columns_priv.
// If unexists, insert a new one.
func (e *GrantExec) checkAndInitColumnPriv(user string, host string, cols []*ast.ColumnName) error {
func (e *GrantExec) checkAndInitColumnPriv(user string, host string, cols []*ast.ColumnName, internalSession sessionctx.Context) error {
dbName, tbl, err := getTargetSchemaAndTable(e.ctx, e.Level.DBName, e.Level.TableName, e.is)
if err != nil {
return err
Expand All @@ -226,15 +256,15 @@ func (e *GrantExec) checkAndInitColumnPriv(user string, host string, cols []*ast
if col == nil {
return errors.Errorf("Unknown column: %s", c.Name.O)
}
ok, err := columnPrivEntryExists(e.ctx, user, host, dbName, tbl.Meta().Name.O, col.Name.O)
ok, err := columnPrivEntryExists(internalSession, user, host, dbName, tbl.Meta().Name.O, col.Name.O)
if err != nil {
return err
}
if ok {
continue
}
// Entry does not exist for user-host-db-tbl-col. Insert a new entry.
err = initColumnPrivEntry(e.ctx, user, host, dbName, tbl.Meta().Name.O, col.Name.O)
err = initColumnPrivEntry(internalSession, user, host, dbName, tbl.Meta().Name.O, col.Name.O)
if err != nil {
return err
}
Expand All @@ -245,33 +275,33 @@ func (e *GrantExec) checkAndInitColumnPriv(user string, host string, cols []*ast
// initGlobalPrivEntry inserts a new row into mysql.DB with empty privilege.
func initGlobalPrivEntry(ctx sessionctx.Context, user string, host string) error {
sql := fmt.Sprintf(`INSERT INTO %s.%s (Host, User, PRIV) VALUES ('%s', '%s', '%s')`, mysql.SystemDB, mysql.GlobalPrivTable, host, user, "{}")
_, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(ctx, sql)
_, err := ctx.(sqlexec.SQLExecutor).Execute(context.Background(), sql)
return err
}

// initDBPrivEntry inserts a new row into mysql.DB with empty privilege.
func initDBPrivEntry(ctx sessionctx.Context, user string, host string, db string) error {
sql := fmt.Sprintf(`INSERT INTO %s.%s (Host, User, DB) VALUES ('%s', '%s', '%s')`, mysql.SystemDB, mysql.DBTable, host, user, db)
_, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(ctx, sql)
_, err := ctx.(sqlexec.SQLExecutor).Execute(context.Background(), sql)
return err
}

// initTablePrivEntry inserts a new row into mysql.Tables_priv with empty privilege.
func initTablePrivEntry(ctx sessionctx.Context, user string, host string, db string, tbl string) error {
sql := fmt.Sprintf(`INSERT INTO %s.%s (Host, User, DB, Table_name, Table_priv, Column_priv) VALUES ('%s', '%s', '%s', '%s', '', '')`, mysql.SystemDB, mysql.TablePrivTable, host, user, db, tbl)
_, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(ctx, sql)
_, err := ctx.(sqlexec.SQLExecutor).Execute(context.Background(), sql)
return err
}

// initColumnPrivEntry inserts a new row into mysql.Columns_priv with empty privilege.
func initColumnPrivEntry(ctx sessionctx.Context, user string, host string, db string, tbl string, col string) error {
sql := fmt.Sprintf(`INSERT INTO %s.%s (Host, User, DB, Table_name, Column_name, Column_priv) VALUES ('%s', '%s', '%s', '%s', '%s', '')`, mysql.SystemDB, mysql.ColumnPrivTable, host, user, db, tbl, col)
_, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(ctx, sql)
_, err := ctx.(sqlexec.SQLExecutor).Execute(context.Background(), sql)
return err
}

// grantGlobalPriv grants priv to user in global scope.
func (e *GrantExec) grantGlobalPriv(user *ast.UserSpec) error {
func (e *GrantExec) grantGlobalPriv(ctx sessionctx.Context, user *ast.UserSpec) error {
if len(e.TLSOptions) == 0 {
return nil
}
Expand All @@ -280,7 +310,7 @@ func (e *GrantExec) grantGlobalPriv(user *ast.UserSpec) error {
return errors.Trace(err)
}
sql := fmt.Sprintf(`UPDATE %s.%s SET PRIV = '%s' WHERE User='%s' AND Host='%s'`, mysql.SystemDB, mysql.GlobalPrivTable, priv, user.User.Username, user.User.Hostname)
_, _, err = e.ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(e.ctx, sql)
_, err = ctx.(sqlexec.SQLExecutor).Execute(context.Background(), sql)
return err
}

Expand Down Expand Up @@ -356,24 +386,24 @@ func tlsOption2GlobalPriv(tlsOptions []*ast.TLSOption) (priv []byte, err error)
}

// grantLevelPriv grants priv to user in s.Level scope.
func (e *GrantExec) grantLevelPriv(priv *ast.PrivElem, user *ast.UserSpec) error {
func (e *GrantExec) grantLevelPriv(priv *ast.PrivElem, user *ast.UserSpec, internalSession sessionctx.Context) error {
switch e.Level.Level {
case ast.GrantLevelGlobal:
return e.grantGlobalLevel(priv, user)
return e.grantGlobalLevel(priv, user, internalSession)
case ast.GrantLevelDB:
return e.grantDBLevel(priv, user)
return e.grantDBLevel(priv, user, internalSession)
case ast.GrantLevelTable:
if len(priv.Cols) == 0 {
return e.grantTableLevel(priv, user)
return e.grantTableLevel(priv, user, internalSession)
}
return e.grantColumnLevel(priv, user)
return e.grantColumnLevel(priv, user, internalSession)
default:
return errors.Errorf("Unknown grant level: %#v", e.Level)
}
}

// grantGlobalLevel manipulates mysql.user table.
func (e *GrantExec) grantGlobalLevel(priv *ast.PrivElem, user *ast.UserSpec) error {
func (e *GrantExec) grantGlobalLevel(priv *ast.PrivElem, user *ast.UserSpec, internalSession sessionctx.Context) error {
if priv.Priv == 0 {
return nil
}
Expand All @@ -382,12 +412,12 @@ func (e *GrantExec) grantGlobalLevel(priv *ast.PrivElem, user *ast.UserSpec) err
return err
}
sql := fmt.Sprintf(`UPDATE %s.%s SET %s WHERE User='%s' AND Host='%s'`, mysql.SystemDB, mysql.UserTable, asgns, user.User.Username, user.User.Hostname)
_, _, err = e.ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(e.ctx, sql)
_, err = internalSession.(sqlexec.SQLExecutor).Execute(context.Background(), sql)
return err
}

// grantDBLevel manipulates mysql.db table.
func (e *GrantExec) grantDBLevel(priv *ast.PrivElem, user *ast.UserSpec) error {
func (e *GrantExec) grantDBLevel(priv *ast.PrivElem, user *ast.UserSpec, internalSession sessionctx.Context) error {
dbName := e.Level.DBName
if len(dbName) == 0 {
dbName = e.ctx.GetSessionVars().CurrentDB
Expand All @@ -397,28 +427,28 @@ func (e *GrantExec) grantDBLevel(priv *ast.PrivElem, user *ast.UserSpec) error {
return err
}
sql := fmt.Sprintf(`UPDATE %s.%s SET %s WHERE User='%s' AND Host='%s' AND DB='%s';`, mysql.SystemDB, mysql.DBTable, asgns, user.User.Username, user.User.Hostname, dbName)
_, _, err = e.ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(e.ctx, sql)
_, err = internalSession.(sqlexec.SQLExecutor).Execute(context.Background(), sql)
return err
}

// grantTableLevel manipulates mysql.tables_priv table.
func (e *GrantExec) grantTableLevel(priv *ast.PrivElem, user *ast.UserSpec) error {
func (e *GrantExec) grantTableLevel(priv *ast.PrivElem, user *ast.UserSpec, internalSession sessionctx.Context) error {
dbName := e.Level.DBName
if len(dbName) == 0 {
dbName = e.ctx.GetSessionVars().CurrentDB
}
tblName := e.Level.TableName
asgns, err := composeTablePrivUpdateForGrant(e.ctx, priv.Priv, user.User.Username, user.User.Hostname, dbName, tblName)
asgns, err := composeTablePrivUpdateForGrant(internalSession, priv.Priv, user.User.Username, user.User.Hostname, dbName, tblName)
if err != nil {
return err
}
sql := fmt.Sprintf(`UPDATE %s.%s SET %s WHERE User='%s' AND Host='%s' AND DB='%s' AND Table_name='%s';`, mysql.SystemDB, mysql.TablePrivTable, asgns, user.User.Username, user.User.Hostname, dbName, tblName)
_, _, err = e.ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(e.ctx, sql)
_, err = internalSession.(sqlexec.SQLExecutor).Execute(context.Background(), sql)
return err
}

// grantColumnLevel manipulates mysql.tables_priv table.
func (e *GrantExec) grantColumnLevel(priv *ast.PrivElem, user *ast.UserSpec) error {
func (e *GrantExec) grantColumnLevel(priv *ast.PrivElem, user *ast.UserSpec, internalSession sessionctx.Context) error {
dbName, tbl, err := getTargetSchemaAndTable(e.ctx, e.Level.DBName, e.Level.TableName, e.is)
if err != nil {
return err
Expand All @@ -429,12 +459,12 @@ func (e *GrantExec) grantColumnLevel(priv *ast.PrivElem, user *ast.UserSpec) err
if col == nil {
return errors.Errorf("Unknown column: %s", c)
}
asgns, err := composeColumnPrivUpdateForGrant(e.ctx, priv.Priv, user.User.Username, user.User.Hostname, dbName, tbl.Meta().Name.O, col.Name.O)
asgns, err := composeColumnPrivUpdateForGrant(internalSession, priv.Priv, user.User.Username, user.User.Hostname, dbName, tbl.Meta().Name.O, col.Name.O)
if err != nil {
return err
}
sql := fmt.Sprintf(`UPDATE %s.%s SET %s WHERE User='%s' AND Host='%s' AND DB='%s' AND Table_name='%s' AND Column_name='%s';`, mysql.SystemDB, mysql.ColumnPrivTable, asgns, user.User.Username, user.User.Hostname, dbName, tbl.Meta().Name.O, col.Name.O)
_, _, err = e.ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(e.ctx, sql)
_, err = internalSession.(sqlexec.SQLExecutor).Execute(context.Background(), sql)
if err != nil {
return err
}
Expand Down Expand Up @@ -610,7 +640,11 @@ func composeColumnPrivUpdateForRevoke(ctx sessionctx.Context, priv mysql.Privile

// recordExists is a helper function to check if the sql returns any row.
func recordExists(ctx sessionctx.Context, sql string) (bool, error) {
rows, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(ctx, sql)
recordSets, err := ctx.(sqlexec.SQLExecutor).Execute(context.Background(), sql)
if err != nil {
return false, err
}
rows, _, err := getRowsAndFields(ctx, recordSets)
if err != nil {
return false, err
}
Expand Down Expand Up @@ -645,14 +679,21 @@ func columnPrivEntryExists(ctx sessionctx.Context, name string, host string, db
// Return Table_priv and Column_priv.
func getTablePriv(ctx sessionctx.Context, name string, host string, db string, tbl string) (string, string, error) {
sql := fmt.Sprintf(`SELECT Table_priv, Column_priv FROM %s.%s WHERE User='%s' AND Host='%s' AND DB='%s' AND Table_name='%s';`, mysql.SystemDB, mysql.TablePrivTable, name, host, db, tbl)
rows, fields, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(ctx, sql)
rs, err := ctx.(sqlexec.SQLExecutor).Execute(context.Background(), sql)
if err != nil {
return "", "", err
}
if len(rows) < 1 {
if len(rs) < 1 {
return "", "", errors.Errorf("get table privilege fail for %s %s %s %s", name, host, db, tbl)
}
var tPriv, cPriv string
rows, fields, err := getRowsAndFields(ctx, rs)
if err != nil {
return "", "", err
}
if len(rows) < 1 {
return "", "", errors.Errorf("get table privilege fail for %s %s %s %s", name, host, db, tbl)
}
row := rows[0]
if fields[0].Column.Tp == mysql.TypeSet {
tablePriv := row.GetSet(0)
Expand All @@ -669,7 +710,14 @@ func getTablePriv(ctx sessionctx.Context, name string, host string, db string, t
// Return Column_priv.
func getColumnPriv(ctx sessionctx.Context, name string, host string, db string, tbl string, col string) (string, error) {
sql := fmt.Sprintf(`SELECT Column_priv FROM %s.%s WHERE User='%s' AND Host='%s' AND DB='%s' AND Table_name='%s' AND Column_name='%s';`, mysql.SystemDB, mysql.ColumnPrivTable, name, host, db, tbl, col)
rows, fields, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(ctx, sql)
rs, err := ctx.(sqlexec.SQLExecutor).Execute(context.Background(), sql)
if err != nil {
return "", err
}
if len(rs) < 1 {
return "", errors.Errorf("get column privilege fail for %s %s %s %s", name, host, db, tbl)
}
rows, fields, err := getRowsAndFields(ctx, rs)
if err != nil {
return "", err
}
Expand Down Expand Up @@ -699,3 +747,43 @@ func getTargetSchemaAndTable(ctx sessionctx.Context, dbName, tableName string, i
}
return dbName, tbl, nil
}

// getRowsAndFields is used to extract rows from record sets.
func getRowsAndFields(ctx sessionctx.Context, recordSets []sqlexec.RecordSet) ([]chunk.Row, []*ast.ResultField, error) {
var (
rows []chunk.Row
fields []*ast.ResultField
)

for i, rs := range recordSets {
tmp, err := getRowFromRecordSet(context.Background(), ctx, rs)
if err != nil {
return nil, nil, err
}
if err = rs.Close(); err != nil {
return nil, nil, err
}

if i == 0 {
rows = tmp
fields = rs.Fields()
}
}
return rows, fields, nil
}

func getRowFromRecordSet(ctx context.Context, se sessionctx.Context, rs sqlexec.RecordSet) ([]chunk.Row, error) {
var rows []chunk.Row
req := rs.NewChunk()
for {
err := rs.Next(ctx, req)
if err != nil || req.NumRows() == 0 {
return rows, err
}
iter := chunk.NewIterator4Chunk(req)
for r := iter.Begin(); r != iter.End(); r = iter.Next() {
rows = append(rows, r)
}
req = chunk.Renew(req, se.GetSessionVars().MaxChunkSize)
}
}
Loading