Skip to content
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions sql/mysql_db/fbs/mysql_db.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,17 @@ table PrivilegeSetTable {
privs:[int];
columns:[PrivilegeSetColumn];
}
table PrivilegeSetRoutine {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: space

name:string;
privs:[int];
is_proc:bool;
}

table PrivilegeSetDatabase {
name:string;
privs:[int];
tables:[PrivilegeSetTable];
routines:[PrivilegeSetRoutine];
}

table PrivilegeSet {
Expand Down
6 changes: 5 additions & 1 deletion sql/mysql_db/mysql_db.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,11 @@ type MySQLDb struct {

db *in_mem_table.MultiIndexedSetTable[*User]
tables_priv *in_mem_table.MultiIndexedSetTable[*User]
procs_priv *in_mem_table.MultiIndexedSetTable[*User]
global_grants *in_mem_table.MultiIndexedSetTable[*User]

//TODO: add the rest of these tables
//columns_priv *mysqlTable
//procs_priv *mysqlTable
//proxies_priv *mysqlTable
//default_roles *mysqlTable
//password_history *mysqlTable
Expand Down Expand Up @@ -120,6 +120,7 @@ func CreateEmptyMySQLDb() *MySQLDb {
// multi tables
mysqlDb.db = NewUserDBIndexedSetTable(userSet, lock, rlock)
mysqlDb.tables_priv = NewUserTablesIndexedSetTable(userSet, lock, rlock)
mysqlDb.procs_priv = NewUserProcsIndexedSetTable(userSet, lock, rlock)
mysqlDb.global_grants = NewUserGlobalGrantsIndexedSetTable(userSet, lock, rlock)

// Start the counter at 1, all new sessions will start at zero so this forces an update for any new session
Expand Down Expand Up @@ -673,6 +674,8 @@ func (db *MySQLDb) GetTableInsensitive(_ *sql.Context, tblName string) (sql.Tabl
return db.db, true, nil
case tablesPrivTblName:
return db.tables_priv, true, nil
case procsPrivTblName:
return db.procs_priv, true, nil
case replicaSourceInfoTblName:
return db.replica_source_info, true, nil
case helpTopicTableName:
Expand All @@ -694,6 +697,7 @@ func (db *MySQLDb) GetTableNames(ctx *sql.Context) ([]string, error) {
userTblName,
dbTblName,
tablesPrivTblName,
procsPrivTblName,
roleEdgesTblName,
replicaSourceInfoTblName,
helpTopicTableName,
Expand Down
26 changes: 23 additions & 3 deletions sql/mysql_db/mysql_db_load.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,14 @@ func loadTable(serialTable *serial.PrivilegeSetTable) *PrivilegeSetTable {
}
}

func loadRoutine(serialRoutine *serial.PrivilegeSetRoutine) *PrivilegeSetRoutine {
return &PrivilegeSetRoutine{
name: string(serialRoutine.Name()),
privs: loadPrivilegeTypes(serialRoutine.PrivsLength(), serialRoutine.Privs),
isProc: serialRoutine.IsProc(),
}
}

func loadDatabase(serialDatabase *serial.PrivilegeSetDatabase) *PrivilegeSetDatabase {
tables := make(map[string]PrivilegeSetTable, serialDatabase.TablesLength())
for i := 0; i < serialDatabase.TablesLength(); i++ {
Expand All @@ -67,10 +75,22 @@ func loadDatabase(serialDatabase *serial.PrivilegeSetDatabase) *PrivilegeSetData
tables[table.Name()] = *table
}

routines := make(map[routineKey]PrivilegeSetRoutine, serialDatabase.RoutinesLength())
for i := 0; i < serialDatabase.RoutinesLength(); i++ {
serialRoutine := new(serial.PrivilegeSetRoutine)
if !serialDatabase.Routines(serialRoutine, i) {
continue
}
routine := loadRoutine(serialRoutine)
key := routineKey{routine.RoutineName(), routine.isProc}
routines[key] = *routine
}

return &PrivilegeSetDatabase{
name: string(serialDatabase.Name()),
privs: loadPrivilegeTypes(serialDatabase.PrivsLength(), serialDatabase.Privs),
tables: tables,
name: string(serialDatabase.Name()),
privs: loadPrivilegeTypes(serialDatabase.PrivsLength(), serialDatabase.Privs),
tables: tables,
routines: routines,
}
}

Expand Down
20 changes: 20 additions & 0 deletions sql/mysql_db/mysql_db_serialize.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,24 @@ func serializeTables(b *flatbuffers.Builder, tables []PrivilegeSetTable) flatbuf
return serializeVectorOffsets(b, serial.PrivilegeSetDatabaseStartTablesVector, offsets)
}

func serializeRoutines(b *flatbuffers.Builder, routines []PrivilegeSetRoutine) flatbuffers.UOffsetT {

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: inconsistent spacing in some of these functions, some skip a line, many don't

offsets := make([]flatbuffers.UOffsetT, len(routines))
for i, routine := range routines {
name := b.CreateString(routine.RoutineName())
privs := serializePrivilegeTypes(b, serial.PrivilegeSetTableStartPrivsVector, routine.ToSlice())

serial.PrivilegeSetRoutineStart(b)
serial.PrivilegeSetRoutineAddName(b, name)
serial.PrivilegeSetRoutineAddPrivs(b, privs)
serial.PrivilegeSetRoutineAddIsProc(b, routine.isProc)

offsets[len(offsets)-i-1] = serial.PrivilegeSetRoutineEnd(b)
}

return serializeVectorOffsets(b, serial.PrivilegeSetDatabaseStartRoutinesVector, offsets)
}

// serializeDatabases writes the given Privilege Set Databases into the flatbuffer Builder, and returns the offset
func serializeDatabases(b *flatbuffers.Builder, databases []PrivilegeSetDatabase) flatbuffers.UOffsetT {
// Write database variables, and save offsets
Expand All @@ -105,11 +123,13 @@ func serializeDatabases(b *flatbuffers.Builder, databases []PrivilegeSetDatabase
name := b.CreateString(database.Name())
privs := serializePrivilegeTypes(b, serial.PrivilegeSetDatabaseStartPrivsVector, database.ToSlice())
tables := serializeTables(b, database.getTables())
routines := serializeRoutines(b, database.getRoutines())

serial.PrivilegeSetDatabaseStart(b)
serial.PrivilegeSetDatabaseAddName(b, name)
serial.PrivilegeSetDatabaseAddPrivs(b, privs)
serial.PrivilegeSetDatabaseAddTables(b, tables)
serial.PrivilegeSetDatabaseAddRoutines(b, routines)
offsets[len(offsets)-i-1] = serial.PrivilegeSetDatabaseEnd(b)
}

Expand Down
26 changes: 26 additions & 0 deletions sql/mysql_db/privilege_set.go
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,12 @@ func (ps PrivilegeSetDatabase) HasPrivileges() bool {
return true
}
}
for _, routineSet := range ps.routines {
if routineSet.HasPrivileges() {
return true
}
}

return false
}

Expand Down Expand Up @@ -512,6 +518,26 @@ func (ps PrivilegeSetDatabase) GetRoutines() []sql.PrivilegeSetRoutine {
return routineSets
}

func (ps PrivilegeSetDatabase) getRoutines() []PrivilegeSetRoutine {
if ps.routines == nil || len(ps.routines) == 0 {
return []PrivilegeSetRoutine{}
}

routineSets := make([]PrivilegeSetRoutine, 0, len(ps.routines))
for _, routine := range ps.routines {
routineSets = append(routineSets, routine)
}

sort.Slice(routineSets, func(i, j int) bool {
if routineSets[i].RoutineName() != routineSets[j].RoutineType() {
return routineSets[i].RoutineName() < routineSets[j].RoutineName()
}
return routineSets[i].RoutineType() < routineSets[j].RoutineType()
})

return routineSets
}

// Equals returns whether the given set of privileges is equivalent to the calling set.
func (ps PrivilegeSetDatabase) Equals(otherPsd sql.PrivilegeSetDatabase) bool {
otherPs := otherPsd.(PrivilegeSetDatabase)
Expand Down
188 changes: 188 additions & 0 deletions sql/mysql_db/procs_priv.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
// Copyright 2023 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package mysql_db

import (
"strings"
"sync"
"time"

"github.com/dolthub/vitess/go/sqltypes"

"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/expression"
"github.com/dolthub/go-mysql-server/sql/in_mem_table"
"github.com/dolthub/go-mysql-server/sql/types"
)

const procsPrivTblName = "procs_priv"

var procsPrivTblSchema = buildProcsPrivSchema()

func NewUserProcsIndexedSetTable(set in_mem_table.IndexedSet[*User], lock, rlock sync.Locker) *in_mem_table.MultiIndexedSetTable[*User] {
table := in_mem_table.NewMultiIndexedSetTable[*User](
procsPrivTblName,
procsPrivTblSchema,
sql.Collation_utf8mb3_bin,
set,
in_mem_table.MultiValueOps[*User]{
ToRows: UserToProcsPrivRows,
FromRow: UserFromProcsPrivRow,
AddRow: UserAddProcsPrivRow,
DeleteRow: UserRemoveProcsPrivRow,
},
lock,
rlock,
)
return table
}

func newEmptyRow(ctx *sql.Context) sql.Row {
row := make(sql.Row, len(procsPrivTblSchema))
var err error
for i, col := range procsPrivTblSchema {
row[i], err = col.Default.Eval(ctx, nil)
if err != nil {
panic(err) // Schema is static. New rows should never fail.
}
}
return row
}

func UserToProcsPrivRows(ctx *sql.Context, user *User) ([]sql.Row, error) {

var ans []sql.Row
for _, dbSet := range user.PrivilegeSet.GetDatabases() {
for _, routineSet := range dbSet.GetRoutines() {
if routineSet.Count() == 0 {
continue
}
row := newEmptyRow(ctx)

row[procsPrivTblColIndex_Host] = user.Host
row[procsPrivTblColIndex_Db] = dbSet.Name()
row[procsPrivTblColIndex_User] = user.User
row[procsPrivTblColIndex_RoutineName] = routineSet.RoutineName()
row[procsPrivTblColIndex_RoutineType] = routineSet.RoutineType()

var privs []string
for _, priv := range routineSet.ToSlice() {
switch priv {
case sql.PrivilegeType_Execute:
privs = append(privs, "Execute")
case sql.PrivilegeType_GrantOption:
privs = append(privs, "Grant") // MySQL prints just "Grant", and not "Grant Option"
case sql.PrivilegeType_AlterRoutine:
privs = append(privs, "Alter Routine")
}
}
privsStr := strings.Join(privs, ",")
row[procsPrivTblColIndex_ProcPriv] = privsStr

ans = append(ans, row)
}
}

return ans, nil
}

func UserFromProcsPrivRow(ctx *sql.Context, row sql.Row) (*User, error) {
panic("implement me") // Currently inaccessible code path.
}

func UserAddProcsPrivRow(ctx *sql.Context, row sql.Row, user *User) (*User, error) {
panic("implement me") // Currently inaccessible code path.
}

func UserRemoveProcsPrivRow(ctx *sql.Context, row sql.Row, user *User) (*User, error) {
panic("implement me") // Currently inaccessible code path.
}

// buildProcsPrivSchema builds the schema for the "procs_priv" Grant Table.
// MySQL Table for reference:
//
// mysql> show create table mysql.procs_priv:
//
// CREATE TABLE `procs_priv` (
//
// `Host` char(255) CHARACTER SET ascii COLLATE ascii_general_ci NOT NULL DEFAULT '',
// `Db` char(64) COLLATE utf8mb3_bin NOT NULL DEFAULT '',
// `User` char(32) COLLATE utf8mb3_bin NOT NULL DEFAULT '',
// `Routine_name` char(64) CHARACTER SET utf8mb3 COLLATE utf8mb3_general_ci NOT NULL DEFAULT '',
// `Routine_type` enum('FUNCTION','PROCEDURE') COLLATE utf8mb3_bin NOT NULL,
// `Grantor` varchar(288) COLLATE utf8mb3_bin NOT NULL DEFAULT '',
// `Proc_priv` set('Execute','Alter Routine','Grant') CHARACTER SET utf8mb3 COLLATE utf8mb3_general_ci NOT NULL DEFAULT '',
// `Timestamp` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
// PRIMARY KEY (`Host`,`User`,`Db`,`Routine_name`,`Routine_type`),
// KEY `Grantor` (`Grantor`
// )
func buildProcsPrivSchema() sql.Schema {
len255_asciii := types.MustCreateString(sqltypes.Char, 255, sql.Collation_ascii_general_ci)
len64_utf8_bin := types.MustCreateString(sqltypes.Char, 64, sql.Collation_utf8_bin)
len64_utf8_gen := types.MustCreateString(sqltypes.Char, 64, sql.Collation_utf8_general_ci)
len32_utf8 := types.MustCreateString(sqltypes.Char, 32, sql.Collation_utf8_bin)
routine_types_enum := types.MustCreateEnumType([]string{"FUNCTION", "PROCEDURE"}, sql.Collation_utf8_bin)
varchar288_utf8 := types.MustCreateString(sqltypes.VarChar, 288, sql.Collation_utf8_bin)
set_privs := types.MustCreateSetType([]string{"Execute", "Alter Routine", "Grant"}, sql.Collation_utf8_general_ci)

return sql.Schema{
columnTemplate("Host", procsPrivTblName, true, &sql.Column{
Type: len255_asciii,
Default: mustDefault(expression.NewLiteral("", len255_asciii), len255_asciii, true, false),
Nullable: false}),
columnTemplate("Db", procsPrivTblName, true, &sql.Column{
Type: len64_utf8_bin,
Default: mustDefault(expression.NewLiteral("", len64_utf8_bin), len64_utf8_bin, true, false),
Nullable: false}),
columnTemplate("User", procsPrivTblName, true, &sql.Column{
Type: len32_utf8,
Default: mustDefault(expression.NewLiteral("", len32_utf8), len32_utf8, true, false),
Nullable: false}),
columnTemplate("Routine_name", procsPrivTblName, true, &sql.Column{
Type: len64_utf8_gen,
Default: mustDefault(expression.NewLiteral("", len64_utf8_gen), len64_utf8_gen, true, false),
Nullable: false}),
columnTemplate("Routine_type", procsPrivTblName, true, &sql.Column{
Type: routine_types_enum,
Default: mustDefault(expression.NewLiteral("PROCEDURE", routine_types_enum), routine_types_enum, true, false),
Nullable: false}),
columnTemplate("Grantor", procsPrivTblName, false, &sql.Column{
Type: varchar288_utf8,
Default: mustDefault(expression.NewLiteral("", varchar288_utf8), varchar288_utf8, true, false),
Nullable: false}),
columnTemplate("Proc_priv", procsPrivTblName, false, &sql.Column{
Type: set_privs,
Default: mustDefault(expression.NewLiteral("", set_privs), set_privs, true, false),
Nullable: false}),
columnTemplate("Timestamp", tablesPrivTblName, false, &sql.Column{
Type: types.Timestamp,
Default: mustDefault(expression.NewLiteral(time.Unix(1, 0).UTC(), types.Timestamp), types.Timestamp, true, false),
Nullable: false}),
}
}

// The column indexes of the "procs_priv" Grant Table.
// https://dev.mysql.com/doc/refman/8.0/en/grant-tables.html#grant-tables-procs-priv
// https://mariadb.com/kb/en/mysqlprocs_priv-table/
const (
procsPrivTblColIndex_Host int = iota
procsPrivTblColIndex_Db
procsPrivTblColIndex_User
procsPrivTblColIndex_RoutineName
procsPrivTblColIndex_RoutineType
procsPrivTblColIndex_Grantor
procsPrivTblColIndex_ProcPriv
procsPrivTblColIndex_Timestamp
)
Loading