Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

*: add privilege check for slow_query and cluster memory table #14451

Merged
merged 21 commits into from
Feb 12, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
*: add privelege check for slow_query and cluster memory table
  • Loading branch information
crazycs520 committed Jan 11, 2020
commit 1733ab8597d82700bd35f539ee35f6922826ac69
6 changes: 6 additions & 0 deletions distsql/distsql.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,12 @@ func Select(ctx context.Context, sctx sessionctx.Context, kvReq *kv.Request, fie
if !sctx.GetSessionVars().EnableStreaming {
kvReq.Streaming = false
}
if kvReq.StoreType == kv.TiDB && sctx.GetSessionVars().User != nil {
kvReq.User = &kv.UserIdentity{
Username: sctx.GetSessionVars().User.Username,
Hostname: sctx.GetSessionVars().User.Hostname,
}
}
resp := sctx.GetClient().Send(ctx, kvReq, sctx.GetSessionVars().KVVars)
if resp == nil {
err := errors.New("client returns nil response")
Expand Down
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,5 @@ require (
)

go 1.13

replace github.com/pingcap/kvproto => github.com/crazycs520/kvproto v0.0.0-20200111035535-c9a393af6414
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ github.com/coreos/go-systemd v0.0.0-20181031085051-9002847aa142/go.mod h1:F5haX7
github.com/coreos/pkg v0.0.0-20160727233714-3ac0863d7acf/go.mod h1:E3G3o1h8I7cfcXa63jLwjI0eiQQMgzzUDFVpN/nH/eA=
github.com/coreos/pkg v0.0.0-20180928190104-399ea9e2e55f h1:lBNOc5arjvs8E5mO2tbpBpLoyyu8B6e44T7hJy6potg=
github.com/coreos/pkg v0.0.0-20180928190104-399ea9e2e55f/go.mod h1:E3G3o1h8I7cfcXa63jLwjI0eiQQMgzzUDFVpN/nH/eA=
github.com/crazycs520/kvproto v0.0.0-20200111035535-c9a393af6414 h1:ylvuO7+y5tAG5GvzKMiDS3wqe95TQqH2hcqC/WV1NC4=
github.com/crazycs520/kvproto v0.0.0-20200111035535-c9a393af6414/go.mod h1:WWLmULLO7l8IOcQG+t+ItJ3fEcrL5FxF0Wu+HrMy26w=
github.com/creack/pty v1.1.7/go.mod h1:lj5s0c3V2DBrqTV7llrYr5NG6My20zk30Fl46Y7DoTY=
github.com/cznic/golex v0.0.0-20181122101858-9c343928389c/go.mod h1:+bmmJDNmKlhWNG+gwWCkaBoTy39Fs+bzRxVBzoTQbIc=
github.com/cznic/mathutil v0.0.0-20181122101859-297441e03548 h1:iwZdTE0PVqJCos1vaoKsclOGD3ADKpshg3SRtYBbwso=
Expand Down
54 changes: 41 additions & 13 deletions infoschema/slow_log.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ import (
"time"

"github.com/pingcap/errors"
"github.com/pingcap/parser/auth"
"github.com/pingcap/parser/mysql"
"github.com/pingcap/tidb/privilege"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/types"
Expand Down Expand Up @@ -83,12 +85,19 @@ var slowQueryCols = []columnInfo{
}

func dataForSlowLog(ctx sessionctx.Context) ([][]types.Datum, error) {
return parseSlowLogFile(ctx.GetSessionVars().Location(), ctx.GetSessionVars().SlowQueryFile)
var hasProcessPriv bool
if pm := privilege.GetPrivilegeManager(ctx); pm != nil {
if pm.RequestVerification(ctx.GetSessionVars().ActiveRoles, "", "", "", mysql.ProcessPriv) {
hasProcessPriv = true
}
}
return parseSlowLogFile(ctx.GetSessionVars().Location(), ctx.GetSessionVars().SlowQueryFile,
ctx.GetSessionVars().User, hasProcessPriv)
}

// parseSlowLogFile uses to parse slow log file.
// TODO: Support parse multiple log-files.
func parseSlowLogFile(tz *time.Location, filePath string) ([][]types.Datum, error) {
func parseSlowLogFile(tz *time.Location, filePath string, user *auth.UserIdentity, hasProcessPriv bool) ([][]types.Datum, error) {
file, err := os.Open(filePath)
if err != nil {
return nil, errors.Trace(err)
Expand All @@ -98,12 +107,21 @@ func parseSlowLogFile(tz *time.Location, filePath string) ([][]types.Datum, erro
logutil.BgLogger().Error("close slow log file failed.", zap.String("file", filePath), zap.Error(err))
}
}()
return ParseSlowLog(tz, bufio.NewReader(file))

checkValid := func(userName string) bool {
if !hasProcessPriv && user != nil && userName != user.Username {
return false
}
return true
}
return ParseSlowLog(tz, bufio.NewReader(file), checkValid)
}

type checkValidFunc func(string) bool

// ParseSlowLog exports for testing.
// TODO: optimize for parse huge log-file.
func ParseSlowLog(tz *time.Location, reader *bufio.Reader) ([][]types.Datum, error) {
func ParseSlowLog(tz *time.Location, reader *bufio.Reader, checkValid checkValidFunc) ([][]types.Datum, error) {
var rows [][]types.Datum
startFlag := false
var st *slowQueryTuple
Expand All @@ -121,11 +139,13 @@ func ParseSlowLog(tz *time.Location, reader *bufio.Reader) ([][]types.Datum, err
// Check slow log entry start flag.
if !startFlag && strings.HasPrefix(line, variable.SlowLogStartPrefixStr) {
st = &slowQueryTuple{}
err = st.setFieldValue(tz, variable.SlowLogTimeStr, line[len(variable.SlowLogStartPrefixStr):], lineNum)
valid, err := st.setFieldValue(tz, variable.SlowLogTimeStr, line[len(variable.SlowLogStartPrefixStr):], lineNum, checkValid)
if err != nil {
return rows, err
}
startFlag = true
if valid {
startFlag = true
}
continue
}

Expand All @@ -142,19 +162,24 @@ func ParseSlowLog(tz *time.Location, reader *bufio.Reader) ([][]types.Datum, err
if strings.HasSuffix(field, ":") {
field = field[:len(field)-1]
}
err = st.setFieldValue(tz, field, fieldValues[i+1], lineNum)
valid, err := st.setFieldValue(tz, field, fieldValues[i+1], lineNum, checkValid)
if err != nil {
return rows, err
}
if !valid {
startFlag = false
}
}
}
} else if strings.HasSuffix(line, variable.SlowLogSQLSuffixStr) {
// Get the sql string, and mark the start flag to false.
err = st.setFieldValue(tz, variable.SlowLogQuerySQLStr, string(hack.Slice(line)), lineNum)
_, err = st.setFieldValue(tz, variable.SlowLogQuerySQLStr, string(hack.Slice(line)), lineNum, checkValid)
if err != nil {
return rows, err
}
rows = append(rows, st.convertToDatumRow())
if checkValid == nil || checkValid(st.user) {
rows = append(rows, st.convertToDatumRow())
}
startFlag = false
} else {
startFlag = false
Expand Down Expand Up @@ -243,8 +268,8 @@ type slowQueryTuple struct {
planDigest string
}

func (st *slowQueryTuple) setFieldValue(tz *time.Location, field, value string, lineNum int) error {
var err error
func (st *slowQueryTuple) setFieldValue(tz *time.Location, field, value string, lineNum int, checkValid checkValidFunc) (valid bool, err error) {
valid = true
switch field {
case variable.SlowLogTimeStr:
st.time, err = ParseTime(value)
Expand All @@ -264,6 +289,9 @@ func (st *slowQueryTuple) setFieldValue(tz *time.Location, field, value string,
if len(field) > 1 {
st.host = fields[1]
}
if checkValid != nil {
valid = checkValid(st.user)
}
case variable.SlowLogConnIDStr:
st.connID, err = strconv.ParseUint(value, 10, 64)
case variable.SlowLogQueryTimeStr:
Expand Down Expand Up @@ -348,9 +376,9 @@ func (st *slowQueryTuple) setFieldValue(tz *time.Location, field, value string,
st.sql = value
}
if err != nil {
return errors.Wrap(err, "Parse slow log at line "+strconv.FormatInt(int64(lineNum), 10)+" failed. Field: `"+field+"`, error")
return valid, errors.Wrap(err, "Parse slow log at line "+strconv.FormatInt(int64(lineNum), 10)+" failed. Field: `"+field+"`, error")
}
return nil
return valid, err
}

func (st *slowQueryTuple) convertToDatumRow() []types.Datum {
Expand Down
28 changes: 16 additions & 12 deletions infoschema/slow_log_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import (
)

func (s *testSuite) TestParseSlowLogFile(c *C) {
slowLog := bytes.NewBufferString(
slowLogStr :=
`# Time: 2019-04-28T15:24:04.309074+08:00
# Txn_start_ts: 405888132465033227
# Query_time: 0.216905
Expand All @@ -40,11 +40,15 @@ func (s *testSuite) TestParseSlowLogFile(c *C) {
# Succ: false
# Plan_digest: 60e9378c746d9a2be1c791047e008967cf252eb6de9167ad3aa6098fa2d523f4
# Prev_stmt: update t set i = 1;
select * from t;`)
reader := bufio.NewReader(slowLog)
select * from t;`
reader := bufio.NewReader(bytes.NewBufferString(slowLogStr))
loc, err := time.LoadLocation("Asia/Shanghai")
c.Assert(err, IsNil)
rows, err := infoschema.ParseSlowLog(loc, reader)
rows, err := infoschema.ParseSlowLog(loc, reader, func(_ string) bool { return false })
c.Assert(err, IsNil)
c.Assert(len(rows), Equals, 0)
reader = bufio.NewReader(bytes.NewBufferString(slowLogStr))
rows, err = infoschema.ParseSlowLog(loc, reader, nil)
c.Assert(err, IsNil)
c.Assert(len(rows), Equals, 1)
recordString := ""
Expand All @@ -60,7 +64,7 @@ select * from t;`)
c.Assert(expectRecordString, Equals, recordString)

// fix sql contain '# ' bug
slowLog = bytes.NewBufferString(
slowLog := bytes.NewBufferString(
`# Time: 2019-04-28T15:24:04.309074+08:00
select a# from t;
# Time: 2019-01-24T22:32:29.313255+08:00
Expand All @@ -74,7 +78,7 @@ select a# from t;
select * from t;
`)
reader = bufio.NewReader(slowLog)
_, err = infoschema.ParseSlowLog(loc, reader)
_, err = infoschema.ParseSlowLog(loc, reader, nil)
c.Assert(err, IsNil)

// test for time format compatibility.
Expand All @@ -85,7 +89,7 @@ select * from t;
select * from t;
`)
reader = bufio.NewReader(slowLog)
rows, err = infoschema.ParseSlowLog(loc, reader)
rows, err = infoschema.ParseSlowLog(loc, reader, nil)
c.Assert(err, IsNil)
c.Assert(len(rows) == 2, IsTrue)
t0Str, err := rows[0][0].ToString()
Expand All @@ -106,13 +110,13 @@ select * from t;
sql := strings.Repeat("x", int(variable.MaxOfMaxAllowedPacket+1))
slowLog.WriteString(sql)
reader = bufio.NewReader(slowLog)
_, err = infoschema.ParseSlowLog(loc, reader)
_, err = infoschema.ParseSlowLog(loc, reader, nil)
c.Assert(err, NotNil)
c.Assert(err.Error(), Equals, "single line length exceeds limit: 65536")

variable.MaxOfMaxAllowedPacket = originValue
reader = bufio.NewReader(slowLog)
_, err = infoschema.ParseSlowLog(loc, reader)
_, err = infoschema.ParseSlowLog(loc, reader, nil)
c.Assert(err, IsNil)

// Add parse error check.
Expand All @@ -122,7 +126,7 @@ select * from t;
select * from t;
`)
reader = bufio.NewReader(slowLog)
_, err = infoschema.ParseSlowLog(loc, reader)
_, err = infoschema.ParseSlowLog(loc, reader, nil)
c.Assert(err, NotNil)
c.Assert(err.Error(), Equals, "Parse slow log at line 2 failed. Field: `Succ`, error: strconv.ParseBool: parsing \"abc\": invalid syntax")
}
Expand Down Expand Up @@ -172,7 +176,7 @@ select * from t;`)
scanner := bufio.NewReader(slowLog)
loc, err := time.LoadLocation("Asia/Shanghai")
c.Assert(err, IsNil)
_, err = infoschema.ParseSlowLog(loc, scanner)
_, err = infoschema.ParseSlowLog(loc, scanner, nil)
c.Assert(err, IsNil)

// Test parser error.
Expand All @@ -182,7 +186,7 @@ select * from t;`)
`)

scanner = bufio.NewReader(slowLog)
_, err = infoschema.ParseSlowLog(loc, scanner)
_, err = infoschema.ParseSlowLog(loc, scanner, nil)
c.Assert(err, NotNil)
c.Assert(err.Error(), Equals, "Parse slow log at line 2 failed. Field: `Txn_start_ts`, error: strconv.ParseUint: parsing \"405888132465033227#\": invalid syntax")

Expand Down
44 changes: 44 additions & 0 deletions infoschema/tables_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1067,6 +1067,50 @@ func (s *testClusterTableSuite) TestSelectClusterTable(c *C) {
}
}

func (s *testClusterTableSuite) TestSelectClusterTablePrivelege(c *C) {
tk := testkit.NewTestKit(c, s.store)
slowLogFileName := "tidb-slow.log"
f, err := os.OpenFile(slowLogFileName, os.O_CREATE|os.O_WRONLY, 0644)
c.Assert(err, IsNil)
_, err = f.Write([]byte(
`# Time: 2019-02-12T19:33:56.571953+08:00
# User: user1@127.0.0.1
select * from t1;
# Time: 2019-02-12T19:33:57.571953+08:00
# User: user2@127.0.0.1
select * from t2;
# Time: 2019-02-12T19:33:58.571953+08:00
# User: user2@127.0.0.1
select * from t3;
`))
c.Assert(f.Sync(), IsNil)
c.Assert(err, IsNil)
defer os.Remove(slowLogFileName)
tk.MustExec("use information_schema")
tk.MustQuery("select count(*) from `CLUSTER_SLOW_QUERY`").Check(testkit.Rows("3"))
tk.MustQuery("select count(*) from `CLUSTER_PROCESSLIST`").Check(testkit.Rows("1"))
reafans marked this conversation as resolved.
Show resolved Hide resolved
tk.MustQuery("select * from `CLUSTER_PROCESSLIST`").Check(testkit.Rows(":10080 1 root 127.0.0.1 <nil> Query 9223372036 0 <nil> 0 "))
tk.MustExec("create user user1")
tk.MustExec("create user user2")
user1 := testkit.NewTestKit(c, s.store)
user1.MustExec("use information_schema")
c.Assert(user1.Se.Auth(&auth.UserIdentity{
Username: "user1",
Hostname: "127.0.0.1",
}, nil, nil), IsTrue)
user1.MustQuery("select count(*) from `CLUSTER_SLOW_QUERY`").Check(testkit.Rows("1"))
user1.MustQuery("select user,query from `CLUSTER_SLOW_QUERY`").Check(testkit.Rows("user1 select * from t1;"))

user2 := testkit.NewTestKit(c, s.store)
user2.MustExec("use information_schema")
c.Assert(user2.Se.Auth(&auth.UserIdentity{
Username: "user2",
Hostname: "127.0.0.1",
}, nil, nil), IsTrue)
user2.MustQuery("select count(*) from `CLUSTER_SLOW_QUERY`").Check(testkit.Rows("2"))
crazycs520 marked this conversation as resolved.
Show resolved Hide resolved
user2.MustQuery("select user,query from `CLUSTER_SLOW_QUERY` order by query").Check(testkit.Rows("user2 select * from t2;", "user2 select * from t3;"))
}

func (s *testTableSuite) TestSelectHiddenColumn(c *C) {
tk := testkit.NewTestKit(c, s.store)
tk.MustExec("DROP DATABASE IF EXISTS `test_hidden`;")
Expand Down
8 changes: 8 additions & 0 deletions kv/kv.go
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,14 @@ type Request struct {
Cacheable bool
// SchemaVer is for any schema-ful storage to validate schema correctness if necessary.
SchemaVar int64
// User uses to do privilege check. It is only used in TiDB cluster memory table.
User *UserIdentity
}

// UserIdentity represents username and hostname.
type UserIdentity struct {
Username string
Hostname string
}

// ResultSubset represents a result subset from a single storage unit.
Expand Down
3 changes: 3 additions & 0 deletions privilege/privilege.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ type Manager interface {
// ConnectionVerification verifies user privilege for connection.
ConnectionVerification(user, host string, auth, salt []byte, tlsState *tls.ConnectionState) (string, string, bool)

// GetAuthWithoutVerification uses to get auth name without verification.
GetAuthWithoutVerification(user, host string) (string, string, bool)

// DBIsVisible returns true is the database is visible to current user.
DBIsVisible(activeRole []*auth.RoleIdentity, db string) bool

Expand Down
25 changes: 25 additions & 0 deletions privilege/privileges/privileges.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,31 @@ func (p *UserPrivileges) GetEncodedPassword(user, host string) string {
return pwd
}

// GetAuthWithoutVerification implements the Manager interface.
func (p *UserPrivileges) GetAuthWithoutVerification(user, host string) (u string, h string, success bool) {
if SkipWithGrant {
p.user = user
p.host = host
success = true
return
}

mysqlPriv := p.Handle.Get()
record := mysqlPriv.connectionVerification(user, host)
if record == nil {
logutil.BgLogger().Error("get user privilege record fail",
zap.String("user", user), zap.String("host", host))
return
}

u = record.User
h = record.Host
p.user = user
p.host = h
crazycs520 marked this conversation as resolved.
Show resolved Hide resolved
success = true
return
}

// ConnectionVerification implements the Manager interface.
func (p *UserPrivileges) ConnectionVerification(user, host string, authentication, salt []byte, tlsState *tls.ConnectionState) (u string, h string, success bool) {
if SkipWithGrant {
Expand Down
Loading