Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

privilege, executor: add SET ROLE and CURRENT_ROLE support #9581

Merged
merged 15 commits into from
Mar 21, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions executor/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ var (
ErrTableaccessDenied = terror.ClassExecutor.New(mysql.ErrTableaccessDenied, mysql.MySQLErrName[mysql.ErrTableaccessDenied])
ErrBadDB = terror.ClassExecutor.New(mysql.ErrBadDB, mysql.MySQLErrName[mysql.ErrBadDB])
ErrWrongObject = terror.ClassExecutor.New(mysql.ErrWrongObject, mysql.MySQLErrName[mysql.ErrWrongObject])
ErrRoleNotGranted = terror.ClassPrivilege.New(mysql.ErrRoleNotGranted, mysql.MySQLErrName[mysql.ErrRoleNotGranted])
)

func init() {
Expand Down
23 changes: 23 additions & 0 deletions executor/simple.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,34 @@ func (e *SimpleExec) Next(ctx context.Context, req *chunk.RecordBatch) (err erro
return nil
case *ast.DropStatsStmt:
err = e.executeDropStats(x)
case *ast.SetRoleStmt:
err = e.executeSetRole(x)
}
e.done = true
return errors.Trace(err)
}

func (e *SimpleExec) executeSetRole(s *ast.SetRoleStmt) error {
checkDup := make(map[string]*auth.RoleIdentity, len(s.RoleList))
// Check whether RoleNameList contain duplicate role name.
for _, r := range s.RoleList {
key := r.String()
checkDup[key] = r
}
roleList := make([]*auth.RoleIdentity, 0, 10)
for _, v := range checkDup {
roleList = append(roleList, v)
}

checker := privilege.GetPrivilegeManager(e.ctx)
ok, roleName := checker.ActiveRoles(e.ctx, roleList)
if !ok {
u := e.ctx.GetSessionVars().User
return ErrRoleNotGranted.GenWithStackByArgs(roleName, u.String())
}
return nil
}

func (e *SimpleExec) dbAccessDenied(dbname string) error {
user := e.ctx.GetSessionVars().User
u := user.Username
Expand Down
1 change: 1 addition & 0 deletions expression/builtin.go
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,7 @@ var funcs = map[string]functionClass{
// information functions
ast.ConnectionID: &connectionIDFunctionClass{baseFunctionClass{ast.ConnectionID, 0, 0}},
ast.CurrentUser: &currentUserFunctionClass{baseFunctionClass{ast.CurrentUser, 0, 0}},
ast.CurrentRole: &currentRoleFunctionClass{baseFunctionClass{ast.CurrentRole, 0, 0}},
ast.Database: &databaseFunctionClass{baseFunctionClass{ast.Database, 0, 0}},
// This function is a synonym for DATABASE().
// See http://dev.mysql.com/doc/refman/5.7/en/information-functions.html#function_schema
Expand Down
45 changes: 45 additions & 0 deletions expression/builtin_info.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ var (
_ functionClass = &databaseFunctionClass{}
_ functionClass = &foundRowsFunctionClass{}
_ functionClass = &currentUserFunctionClass{}
_ functionClass = &currentRoleFunctionClass{}
_ functionClass = &userFunctionClass{}
_ functionClass = &connectionIDFunctionClass{}
_ functionClass = &lastInsertIDFunctionClass{}
Expand Down Expand Up @@ -156,6 +157,50 @@ func (b *builtinCurrentUserSig) evalString(row chunk.Row) (string, bool, error)
return data.User.AuthIdentityString(), false, nil
}

type currentRoleFunctionClass struct {
baseFunctionClass
}

func (c *currentRoleFunctionClass) getFunction(ctx sessionctx.Context, args []Expression) (builtinFunc, error) {
if err := c.verifyArgs(args); err != nil {
return nil, err
}
bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETString)
bf.tp.Flen = 64
sig := &builtinCurrentRoleSig{bf}
return sig, nil
}

type builtinCurrentRoleSig struct {
baseBuiltinFunc
}

func (b *builtinCurrentRoleSig) Clone() builtinFunc {
newSig := &builtinCurrentRoleSig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
return newSig
}

// evalString evals a builtinCurrentUserSig.
// See https://dev.mysql.com/doc/refman/5.7/en/information-functions.html#function_current-user
func (b *builtinCurrentRoleSig) evalString(row chunk.Row) (string, bool, error) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is better to add some test cases for this built-in function.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok

data := b.ctx.GetSessionVars()
if data == nil || data.ActiveRoles == nil {
return "", true, errors.Errorf("Missing session variable when eval builtin")
}
if len(data.ActiveRoles) == 0 {
return "", false, nil
}
res := ""
for i, r := range data.ActiveRoles {
res += r.String()
if i != len(data.ActiveRoles)-1 {
res += ","
}
}
return res, false, nil
}

type userFunctionClass struct {
baseFunctionClass
}
Expand Down
16 changes: 16 additions & 0 deletions expression/builtin_info_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,22 @@ func (s *testEvaluatorSuite) TestCurrentUser(c *C) {
c.Assert(d.GetString(), Equals, "root@localhost")
}

func (s *testEvaluatorSuite) TestCurrentRole(c *C) {
defer testleak.AfterTest(c)()
ctx := mock.NewContext()
sessionVars := ctx.GetSessionVars()
sessionVars.ActiveRoles = make([]*auth.RoleIdentity, 0, 10)
sessionVars.ActiveRoles = append(sessionVars.ActiveRoles, &auth.RoleIdentity{Username: "r_1", Hostname: "%"})
sessionVars.ActiveRoles = append(sessionVars.ActiveRoles, &auth.RoleIdentity{Username: "r_2", Hostname: "localhost"})

fc := funcs[ast.CurrentRole]
f, err := fc.getFunction(ctx, nil)
c.Assert(err, IsNil)
d, err := evalBuiltinFunc(f, chunk.Row{})
c.Assert(err, IsNil)
c.Assert(d.GetString(), Equals, "`r_1`@`%`,`r_2`@`localhost`")
}

func (s *testEvaluatorSuite) TestConnectionID(c *C) {
defer testleak.AfterTest(c)()
ctx := mock.NewContext()
Expand Down
1 change: 1 addition & 0 deletions expression/function_traits.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
var UnCacheableFunctions = map[string]struct{}{
ast.Database: {},
ast.CurrentUser: {},
ast.CurrentRole: {},
ast.User: {},
ast.ConnectionID: {},
ast.LastInsertId: {},
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ require (
github.com/pingcap/goleveldb v0.0.0-20171020122428-b9ff6c35079e
github.com/pingcap/kvproto v0.0.0-20190215154024-7f2fc73ef562
github.com/pingcap/log v0.0.0-20190307075452-bd41d9273596
github.com/pingcap/parser v0.0.0-20190312024907-3f6280b08c8b
github.com/pingcap/parser v0.0.0-20190320053247-fe243e3280cf
github.com/pingcap/pd v2.1.0-rc.4+incompatible
github.com/pingcap/tidb-tools v2.1.3-0.20190116051332-34c808eef588+incompatible
github.com/pingcap/tipb v0.0.0-20190107072121-abbec73437b7
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,8 @@ github.com/pingcap/kvproto v0.0.0-20190215154024-7f2fc73ef562 h1:32oF1/8lVnBR2JV
github.com/pingcap/kvproto v0.0.0-20190215154024-7f2fc73ef562/go.mod h1:QMdbTAXCHzzygQzqcG9uVUgU2fKeSN1GmfMiykdSzzY=
github.com/pingcap/log v0.0.0-20190307075452-bd41d9273596 h1:t2OQTpPJnrPDGlvA+3FwJptMTt6MEPdzK1Wt99oaefQ=
github.com/pingcap/log v0.0.0-20190307075452-bd41d9273596/go.mod h1:WpHUKhNZ18v116SvGrmjkA9CBhYmuUTKL+p8JC9ANEw=
github.com/pingcap/parser v0.0.0-20190312024907-3f6280b08c8b h1:NlvTrxqezIJh6CD5Leky12IZ8E/GtpEEmzgNNb34wbw=
github.com/pingcap/parser v0.0.0-20190312024907-3f6280b08c8b/go.mod h1:1FNvfp9+J0wvc4kl8eGNh7Rqrxveg15jJoWo/a0uHwA=
github.com/pingcap/parser v0.0.0-20190320053247-fe243e3280cf h1:yxK78TmeSK3BIm8Z8SwdZLVzRpY80HZe1VMlA2dL648=
github.com/pingcap/parser v0.0.0-20190320053247-fe243e3280cf/go.mod h1:1FNvfp9+J0wvc4kl8eGNh7Rqrxveg15jJoWo/a0uHwA=
github.com/pingcap/pd v2.1.0-rc.4+incompatible h1:/buwGk04aHO5odk/+O8ZOXGs4qkUjYTJ2UpCJXna8NE=
github.com/pingcap/pd v2.1.0-rc.4+incompatible/go.mod h1:nD3+EoYes4+aNNODO99ES59V83MZSI+dFbhyr667a0E=
github.com/pingcap/tidb-tools v2.1.3-0.20190116051332-34c808eef588+incompatible h1:e9Gi/LP9181HT3gBfSOeSBA+5JfemuE4aEAhqNgoE4k=
Expand Down
4 changes: 2 additions & 2 deletions planner/core/planbuilder.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,8 +212,8 @@ func (b *PlanBuilder) Build(node ast.Node) (Plan, error) {
case *ast.AnalyzeTableStmt:
return b.buildAnalyze(x)
case *ast.BinlogStmt, *ast.FlushStmt, *ast.UseStmt,
*ast.BeginStmt, *ast.CommitStmt, *ast.RollbackStmt, *ast.CreateUserStmt, *ast.SetPwdStmt,
*ast.GrantStmt, *ast.DropUserStmt, *ast.AlterUserStmt, *ast.RevokeStmt, *ast.KillStmt, *ast.DropStatsStmt:
*ast.BeginStmt, *ast.CommitStmt, *ast.RollbackStmt, *ast.CreateUserStmt, *ast.SetPwdStmt, *ast.GrantStmt,
*ast.DropUserStmt, *ast.AlterUserStmt, *ast.RevokeStmt, *ast.KillStmt, *ast.DropStatsStmt, *ast.SetRoleStmt:
return b.buildSimple(node.(ast.StmtNode))
case ast.DDLNode:
return b.buildDDL(x)
Expand Down
4 changes: 4 additions & 0 deletions privilege/privilege.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ type Manager interface {

// UserPrivilegesTable provide data for INFORMATION_SCHEMA.USERS_PRIVILEGE table.
UserPrivilegesTable() [][]types.Datum

// ActiveRoles active roles for current session.
// The first illegal role will be returned.
ActiveRoles(ctx sessionctx.Context, roleList []*auth.RoleIdentity) (bool, string)
}

const key keyType = 0
Expand Down
74 changes: 74 additions & 0 deletions privilege/privileges/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (

"github.com/pingcap/errors"
"github.com/pingcap/parser/ast"
"github.com/pingcap/parser/auth"
"github.com/pingcap/parser/mysql"
"github.com/pingcap/parser/terror"
"github.com/pingcap/tidb/sessionctx"
Expand Down Expand Up @@ -104,12 +105,42 @@ type columnsPrivRecord struct {
patTypes []byte
}

// RoleGraphEdgesTable is used to cache relationship between and role.
type roleGraphEdgesTable struct {
roleList map[string]bool
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's better to use roleList map[string]struct{} here

roleList map[string]struct{}

roleList["xx"] = struct{}{}

if _, ok := roleList["xx"] {
    ...
}

}

// Find method is used to find role from table
func (g roleGraphEdgesTable) Find(user, host string) bool {
if host == "" {
host = "%"
}
key := user + "@" + host
if g.roleList == nil {
return false
}
_, ok := g.roleList[key]
return ok
}

// MySQLPrivilege is the in-memory cache of mysql privilege tables.
type MySQLPrivilege struct {
User []UserRecord
DB []dbRecord
TablesPriv []tablesPrivRecord
ColumnsPriv []columnsPrivRecord
RoleGraph map[string]roleGraphEdgesTable
}

// FindRole is used to detect whether there is edges between users and roles.
func (p *MySQLPrivilege) FindRole(user string, host string, role *auth.RoleIdentity) bool {
rec := p.matchUser(user, host)
r := p.matchUser(role.Username, role.Hostname)
if rec != nil && r != nil {
key := rec.User + "@" + rec.Host
return p.RoleGraph[key].Find(role.Username, role.Hostname)
}
return false
}

// LoadAll loads the tables from database to memory.
Expand Down Expand Up @@ -142,6 +173,14 @@ func (p *MySQLPrivilege) LoadAll(ctx sessionctx.Context) error {
}
log.Warn("mysql.columns_priv missing")
}

err = p.LoadRoleGraph(ctx)
if err != nil {
if !noSuchTable(err) {
return errors.Trace(err)
}
log.Warn("mysql.role_edges missing")
}
return nil
}

Expand All @@ -155,6 +194,16 @@ func noSuchTable(err error) bool {
return false
}

// LoadRoleGraph loads the mysql.role_edges table from database.
func (p *MySQLPrivilege) LoadRoleGraph(ctx sessionctx.Context) error {
p.RoleGraph = make(map[string]roleGraphEdgesTable)
err := p.loadTable(ctx, "select FROM_USER, FROM_HOST, TO_USER, TO_HOST from mysql.role_edges;", p.decodeRoleEdgesTable)
if err != nil {
return errors.Trace(err)
}
return nil
}

// LoadUserTable loads the mysql.user table from database.
func (p *MySQLPrivilege) LoadUserTable(ctx sessionctx.Context) error {
err := p.loadTable(ctx, "select HIGH_PRIORITY Host,User,Password,Select_priv,Insert_priv,Update_priv,Delete_priv,Create_priv,Drop_priv,Process_priv,Grant_priv,References_priv,Alter_priv,Show_db_priv,Super_priv,Execute_priv,Create_view_priv,Show_view_priv,Index_priv,Create_user_priv,Trigger_priv,Create_role_priv,Drop_role_priv,account_locked from mysql.user;", p.decodeUserTableRow)
Expand Down Expand Up @@ -381,6 +430,31 @@ func (p *MySQLPrivilege) decodeTablesPrivTableRow(row chunk.Row, fs []*ast.Resul
return nil
}

func (p *MySQLPrivilege) decodeRoleEdgesTable(row chunk.Row, fs []*ast.ResultField) error {
var fromUser, fromHost, toHost, toUser string
for i, f := range fs {
switch {
case f.ColumnAsName.L == "from_host":
fromHost = row.GetString(i)
case f.ColumnAsName.L == "from_user":
fromUser = row.GetString(i)
case f.ColumnAsName.L == "to_host":
toHost = row.GetString(i)
case f.ColumnAsName.L == "to_user":
toUser = row.GetString(i)
}
}
fromKey := fromUser + "@" + fromHost
toKey := toUser + "@" + toHost
roleGraph, ok := p.RoleGraph[toKey]
if !ok {
roleGraph = roleGraphEdgesTable{roleList: make(map[string]bool)}
p.RoleGraph[toKey] = roleGraph
}
roleGraph.roleList[fromKey] = true
return nil
}

func (p *MySQLPrivilege) decodeColumnsPrivTableRow(row chunk.Row, fs []*ast.ResultField) error {
var value columnsPrivRecord
for i, f := range fs {
Expand Down
30 changes: 30 additions & 0 deletions privilege/privileges/cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,36 @@ func (s *testCacheSuite) TestCaseInsensitive(c *C) {
c.Assert(p.RequestVerification("genius", "127.0.0.1", "tctrain", "tctrainorder", "", mysql.SelectPriv), IsTrue)
}

func (s *testCacheSuite) TestLoadRoleGraph(c *C) {
se, err := session.CreateSession4Test(s.store)
c.Assert(err, IsNil)
defer se.Close()
mustExec(c, se, "use mysql;")
mustExec(c, se, "truncate table user;")

var p privileges.MySQLPrivilege
err = p.LoadRoleGraph(se)
c.Assert(err, IsNil)
c.Assert(len(p.User), Equals, 0)

mustExec(c, se, `INSERT INTO mysql.role_edges (FROM_HOST, FROM_USER, TO_HOST, TO_USER) VALUES ("%", "r_1", "%", "user2")`)
mustExec(c, se, `INSERT INTO mysql.role_edges (FROM_HOST, FROM_USER, TO_HOST, TO_USER) VALUES ("%", "r_2", "%", "root")`)
mustExec(c, se, `INSERT INTO mysql.role_edges (FROM_HOST, FROM_USER, TO_HOST, TO_USER) VALUES ("%", "r_3", "%", "user1")`)
mustExec(c, se, `INSERT INTO mysql.role_edges (FROM_HOST, FROM_USER, TO_HOST, TO_USER) VALUES ("%", "r_4", "%", "root")`)

p = privileges.MySQLPrivilege{}
err = p.LoadRoleGraph(se)
c.Assert(err, IsNil)
graph := p.RoleGraph
c.Assert(graph["root@%"].Find("r_2", "%"), Equals, true)
c.Assert(graph["root@%"].Find("r_4", "%"), Equals, true)
c.Assert(graph["user2@%"].Find("r_1", "%"), Equals, true)
c.Assert(graph["user1@%"].Find("r_3", "%"), Equals, true)
_, ok := graph["illedal"]
c.Assert(ok, Equals, false)
c.Assert(graph["root@%"].Find("r_1", "%"), Equals, false)
}

func (s *testCacheSuite) TestAbnormalMySQLTable(c *C) {
store, err := mockstore.NewMockTikvStore()
c.Assert(err, IsNil)
Expand Down
16 changes: 16 additions & 0 deletions privilege/privileges/privileges.go
Original file line number Diff line number Diff line change
Expand Up @@ -186,3 +186,19 @@ func (p *UserPrivileges) ShowGrants(ctx sessionctx.Context, user *auth.UserIdent

return
}

// ActiveRoles implements privilege.Manager ActiveRoles interface.
func (p *UserPrivileges) ActiveRoles(ctx sessionctx.Context, roleList []*auth.RoleIdentity) (bool, string) {
mysqlPrivilege := p.Handle.Get()
u := p.user
h := p.host
for _, r := range roleList {
ok := mysqlPrivilege.FindRole(u, h, r)
if !ok {
log.Errorf("Role: %+v doesn't grant for user", r)
return false, r.String()
}
}
ctx.GetSessionVars().ActiveRoles = roleList
return true, ""
}
4 changes: 4 additions & 0 deletions sessionctx/variable/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,9 @@ type SessionVars struct {
// params for prepared statements
PreparedParams []types.Datum

// ActiveRoles stores active roles for current user
ActiveRoles []*auth.RoleIdentity

// retry information
RetryInfo *RetryInfo
// Should be reset on transaction finished.
Expand Down Expand Up @@ -350,6 +353,7 @@ func NewSessionVars() *SessionVars {
TxnCtx: &TransactionContext{},
KVVars: kv.NewVariables(),
RetryInfo: &RetryInfo{},
ActiveRoles: make([]*auth.RoleIdentity, 0, 10),
StrictSQLMode: true,
Status: mysql.ServerStatusAutocommit,
StmtCtx: new(stmtctx.StatementContext),
Expand Down