diff --git a/planner/core/logical_plan_builder.go b/planner/core/logical_plan_builder.go index 88dbd0c7f21cf..3448c62786ec2 100644 --- a/planner/core/logical_plan_builder.go +++ b/planner/core/logical_plan_builder.go @@ -2096,7 +2096,7 @@ func (b *PlanBuilder) buildUpdate(update *ast.UpdateStmt) (Plan, error) { if dbName == "" { dbName = b.ctx.GetSessionVars().CurrentDB } - b.visitInfo = appendVisitInfo(b.visitInfo, mysql.UpdatePriv, dbName, t.Name.L, "") + b.visitInfo = appendVisitInfo(b.visitInfo, mysql.SelectPriv, dbName, t.Name.L, "") } if sel.Where != nil { @@ -2208,6 +2208,10 @@ func (b *PlanBuilder) buildUpdateLists(tableList []*ast.TableName, list []*ast.A p = np newList = append(newList, &expression.Assignment{Col: col, Expr: newExpr}) } + for _, assign := range newList { + col := assign.Col + b.visitInfo = appendVisitInfo(b.visitInfo, mysql.UpdatePriv, col.DBName.L, col.TblName.L, "") + } return newList, p, nil } diff --git a/session/session_test.go b/session/session_test.go index b8e46644cee9e..dcf734d7be818 100644 --- a/session/session_test.go +++ b/session/session_test.go @@ -15,6 +15,7 @@ package session_test import ( "fmt" + "strings" "sync" "sync/atomic" "time" @@ -2226,3 +2227,30 @@ func (s *testSessionSuite) TestSetGroupConcatMaxLen(c *C) { _, err = tk.Exec("set @@group_concat_max_len='hello'") c.Assert(terror.ErrorEqual(err, variable.ErrWrongTypeForVar), IsTrue, Commentf("err %v", err)) } + +func (s *testSessionSuite) TestUpdatePrivilege(c *C) { + tk := testkit.NewTestKitWithInit(c, s.store) + tk.MustExec("drop table if exists t1, t2;") + tk.MustExec("create table t1 (id int);") + tk.MustExec("create table t2 (id int);") + tk.MustExec("insert into t1 values (1);") + tk.MustExec("insert into t2 values (2);") + tk.MustExec("create user xxx;") + tk.MustExec("grant all on test.t1 to xxx;") + tk.MustExec("grant select on test.t2 to xxx;") + tk.MustExec("flush privileges;") + + tk1 := testkit.NewTestKitWithInit(c, s.store) + c.Assert(tk1.Se.Auth(&auth.UserIdentity{Username: "xxx", Hostname: "localhost"}, + []byte(""), + []byte("")), IsTrue) + + _, err := tk1.Exec("update t2 set id = 666 where id = 1;") + c.Assert(err, NotNil) + c.Assert(strings.Contains(err.Error(), "privilege check fail"), IsTrue) + + // Cover a bug that t1 and t2 both require update privilege. + // In fact, the privlege check for t1 should be update, and for t2 should be select. + _, err = tk1.Exec("update t1,t2 set t1.id = t2.id;") + c.Assert(err, IsNil) +}