Skip to content
This repository was archived by the owner on Sep 30, 2024. It is now read-only.
Merged
45 changes: 38 additions & 7 deletions cmd/frontend/internal/bg/update_permissions.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,30 +18,61 @@ import (
func UpdatePermissions(ctx context.Context, logger log.Logger, db database.DB) {
scopedLog := logger.Scoped("permission_update", "Updates the permission in the database based on the rbac schema configuration.")
err := db.WithTransact(ctx, func(tx database.DB) error {
pstore := tx.Permissions()
permissionStore := tx.Permissions()
roleStore := tx.Roles()
rolePermissionStore := tx.RolePermissions()

dbPerms, err := pstore.FetchAll(ctx)
dbPerms, err := permissionStore.FetchAll(ctx)
if err != nil {
return errors.Wrap(err, "fetching permissions from database")
}

toBeAdded, toBeDeleted := rbac.ComparePermissions(dbPerms)
toBeAdded, toBeDeleted := rbac.ComparePermissions(dbPerms, rbac.RBACSchema)
scopedLog.Info("RBAC Permissions update", log.Int("added", len(toBeAdded)), log.Int("deleted", len(toBeDeleted)))

if len(toBeDeleted) > 0 {
// We delete all the permissions that need to be deleted from the database
err = pstore.BulkDelete(ctx, toBeDeleted)
// We delete all the permissions that need to be deleted from the database. The role <> permissions are
// automatically deleted: https://app.golinks.io/role_permissions-permission_id_cascade.
err = permissionStore.BulkDelete(ctx, toBeDeleted)
if err != nil {
return errors.Wrap(err, "deleting redundant permissions")
}
}

if len(toBeAdded) > 0 {
// Adding new permissions to the database
_, err = pstore.BulkCreate(ctx, toBeAdded)
// Adding new permissions to the database. This permissions will be assigned to the System roles
// (USER and SITE_ADMINISTRATOR).
permissions, err := permissionStore.BulkCreate(ctx, toBeAdded)
if err != nil {
return errors.Wrap(err, "creating new permissions")
}

// Currently, we have only two system roles so we can just list the first two. In the future,
// it might be worth creating a new method called `FetchAll` or `ListWithoutPagination` to
// retrieve all system roles, but since we know currently there won't be more than two system
// roles at any given point in time, then this works.
firstParam := 2
systemRoles, err := roleStore.List(ctx, database.RolesListOptions{
PaginationArgs: &database.PaginationArgs{
First: &firstParam,
},
System: true,
})
if err != nil {
return errors.Wrap(err, "fetching system roles")
}

for _, permission := range permissions {
for _, role := range systemRoles {
_, err := rolePermissionStore.Create(ctx, database.CreateRolePermissionOpts{
PermissionID: permission.ID,
RoleID: role.ID,
})
if err != nil {
return errors.Wrapf(err, "assigning permission to role: %s", role.Name)
}
}
}
}

return nil
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,32 +11,32 @@ import (
"github.com/sourcegraph/sourcegraph/internal/types"
)

type roleAssignmentMigrator struct {
type userRoleAssignmentMigrator struct {
store *basestore.Store
batchSize int
}

func NewRoleAssignmentMigrator(store *basestore.Store, batchSize int) *roleAssignmentMigrator {
return &roleAssignmentMigrator{
func NewUserRoleAssignmentMigrator(store *basestore.Store, batchSize int) *userRoleAssignmentMigrator {
return &userRoleAssignmentMigrator{
store: store,
batchSize: batchSize,
}
}

var _ oobmigration.Migrator = &roleAssignmentMigrator{}
var _ oobmigration.Migrator = &userRoleAssignmentMigrator{}

func (m *roleAssignmentMigrator) ID() int { return 19 }
func (m *roleAssignmentMigrator) Interval() time.Duration { return time.Second * 10 }
func (m *userRoleAssignmentMigrator) ID() int { return 19 }
func (m *userRoleAssignmentMigrator) Interval() time.Duration { return time.Second * 10 }

// Progress returns the percentage (ranged [0, 1]) of users who have a system role (USER or SITE_ADMINISTRATOR) assigned.
func (m *roleAssignmentMigrator) Progress(ctx context.Context, _ bool) (float64, error) {
progress, _, err := basestore.ScanFirstFloat(m.store.Query(ctx, sqlf.Sprintf(roleAssignmentMigratorProgressQuery)))
func (m *userRoleAssignmentMigrator) Progress(ctx context.Context, _ bool) (float64, error) {
progress, _, err := basestore.ScanFirstFloat(m.store.Query(ctx, sqlf.Sprintf(userRoleAssignmentMigratorProgressQuery)))
return progress, err
}

// This query checks the total number of user_roles in the database vs. the sum of the total number of users and the total number of users who are site_admin.
// We use a CTE here to only check for system roles (e.g USER and SITE_ADMINISTRATOR) since those are the two system roles that should be available on every instance.
const roleAssignmentMigratorProgressQuery = `
const userRoleAssignmentMigratorProgressQuery = `
WITH system_roles AS MATERIALIZED (
SELECT id FROM roles WHERE system
)
Expand All @@ -49,11 +49,11 @@ FROM
(SELECT COUNT(1) AS count FROM user_roles WHERE role_id IN (SELECT id FROM system_roles)) ur1
`

func (m *roleAssignmentMigrator) Up(ctx context.Context) (err error) {
func (m *userRoleAssignmentMigrator) Up(ctx context.Context) (err error) {
return m.store.Exec(ctx, sqlf.Sprintf(userRolesMigratorUpQuery, string(types.UserSystemRole), string(types.SiteAdministratorSystemRole), m.batchSize))
}

func (m *roleAssignmentMigrator) Down(ctx context.Context) error {
func (m *userRoleAssignmentMigrator) Down(ctx context.Context) error {
// non-destructive
return nil
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@ import (
"github.com/sourcegraph/sourcegraph/internal/types"
)

func TestRoleAssignmentMigrator(t *testing.T) {
func TestUserRoleAssignmentMigrator(t *testing.T) {
ctx := context.Background()
logger := logtest.Scoped(t)
db := database.NewDB(logger, dbtest.NewDB(logger, t))
store := basestore.NewWithHandle(db.Handle())

migrator := NewRoleAssignmentMigrator(store, 5)
migrator := NewUserRoleAssignmentMigrator(store, 5)
progress, err := migrator.Progress(ctx, false)
assert.NoError(t, err)

Expand Down
2 changes: 1 addition & 1 deletion internal/oobmigration/migrations/register.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ type migratorDependencies struct {
func registerOSSMigrators(runner *oobmigration.Runner, noDelay bool, deps migratorDependencies) error {
return RegisterAll(runner, noDelay, []TaggedMigrator{
batches.NewExternalServiceWebhookMigratorWithDB(deps.store, deps.keyring.ExternalServiceKey, 50),
batches.NewRoleAssignmentMigrator(deps.store, 500),
batches.NewUserRoleAssignmentMigrator(deps.store, 500),
})
}

Expand Down
4 changes: 3 additions & 1 deletion internal/oobmigration/oobmigrations.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -110,12 +110,14 @@
deprecated_version_minor: 4
- id: 19
team: batch-changes
component: db.user_roles
component: frontend-db.user_roles
description: Assigns roles to existing users
non_destructive: true
is_enterprise: false
introduced_version_major: 4
introduced_version_minor: 5
deprecated_version_major: 4
deprecated_version_minor: 6
- id: 20
team: code-intelligence
component: codeintel-db.lsif_data_*
Expand Down
31 changes: 17 additions & 14 deletions internal/rbac/permissions.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import (
//go:embed schema.yaml
var schema embed.FS

var schemaYaml = func() Schema {
var RBACSchema = func() Schema {
contents, err := schema.ReadFile("schema.yaml")
if err != nil {
panic(fmt.Sprintf("malformed rbac schema definition: %s", err.Error()))
Expand All @@ -29,18 +29,19 @@ var schemaYaml = func() Schema {

// ComparePermissions takes two slices of permissions (one from the database and another from the schema file)
// and extracts permissions that need to be added / deleted in the database based on those contained in the schema file.
func ComparePermissions(dbPerms []*types.Permission) (added []database.CreatePermissionOpts, deleted []database.DeletePermissionOpts) {
// Create map to hold the items in both arrays
func ComparePermissions(dbPerms []*types.Permission, schemaPerms Schema) (added []database.CreatePermissionOpts, deleted []database.DeletePermissionOpts) {
// Create map to hold the union of both permissions in the database and those in the schema file. `internal/rbac/schema.yaml`
ps := make(map[string]struct {
count int
id int32
})

// save all database permissions to the map
for _, p := range dbPerms {
// Since dbPerms contain an ID we save the ID which will be used
// if we need to delete
ps[p.Namespace+p.Action] = struct {
currentPerm := p.DisplayName()
// Since dbPerms contain an ID we save the ID which will be used to delete redundant permissions.
// This also ensures all permissions are unique and we never have duplicate permissions.
ps[currentPerm] = struct {
count int
id int32
}{
Expand All @@ -49,29 +50,31 @@ func ComparePermissions(dbPerms []*types.Permission) (added []database.CreatePer
}
}

var schemaPerms []*types.Permission
var parsedSchemaPerms []*types.Permission

for _, n := range schemaYaml.Namespaces {
for _, n := range schemaPerms.Namespaces {
for _, a := range n.Actions {
schemaPerms = append(schemaPerms, &types.Permission{
parsedSchemaPerms = append(parsedSchemaPerms, &types.Permission{
Namespace: n.Name,
Action: a,
})
}
}

// Check items in schema file to see which exists in the database
for _, p := range schemaPerms {
// If item is not in map, it means it doesn't exist in the database so we
// add it to the `added` slice.
if perm, ok := ps[p.Namespace+p.Action]; !ok {
for _, p := range parsedSchemaPerms {
currentPerm := p.DisplayName()

if perm, ok := ps[currentPerm]; !ok {
// If item is not in map, it means it doesn't exist in the database so we
// add it to the `added` slice.
added = append(added, database.CreatePermissionOpts{
Namespace: p.Namespace,
Action: p.Action,
})
} else {
// If item is in map, it means it already exist in the database
ps[p.Namespace+p.Action] = struct {
ps[currentPerm] = struct {
count int
id int32
}{
Expand Down
126 changes: 126 additions & 0 deletions internal/rbac/permissions_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
package rbac

import (
"testing"

"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/stretchr/testify/assert"

"github.com/sourcegraph/sourcegraph/internal/database"
"github.com/sourcegraph/sourcegraph/internal/types"
)

func TestComparePermissions(t *testing.T) {
dbPerms := []*types.Permission{
{ID: 1, Namespace: "TEST-NAMESPACE", Action: "READ"},
{ID: 2, Namespace: "TEST-NAMESPACE", Action: "WRITE"},
{ID: 3, Namespace: "TEST-NAMESPACE-2", Action: "READ"},
{ID: 4, Namespace: "TEST-NAMESPACE-2", Action: "WRITE"},
{ID: 5, Namespace: "TEST-NAMESPACE-3", Action: "READ"},
}

t.Run("no changes to permissions", func(t *testing.T) {
schemaPerms := Schema{
Namespaces: []Namespace{
{Name: "TEST-NAMESPACE", Actions: []string{"READ", "WRITE"}},
{Name: "TEST-NAMESPACE-2", Actions: []string{"READ", "WRITE"}},
{Name: "TEST-NAMESPACE-3", Actions: []string{"READ"}},
},
}

added, deleted := ComparePermissions(dbPerms, schemaPerms)

assert.Len(t, added, 0)
assert.Len(t, deleted, 0)
})

t.Run("permissions deleted", func(t *testing.T) {
schemaPerms := Schema{
Namespaces: []Namespace{
{Name: "TEST-NAMESPACE", Actions: []string{"READ", "WRITE"}},
{Name: "TEST-NAMESPACE-2", Actions: []string{"READ"}},
},
}

want := []database.DeletePermissionOpts{
{ID: int32(4)},
{ID: int32(5)},
}

added, deleted := ComparePermissions(dbPerms, schemaPerms)

assert.Len(t, added, 0)
assert.Len(t, deleted, 2)
if diff := cmp.Diff(want, deleted, cmpopts.SortSlices(sortDeletePermissionOptSlice)); diff != "" {
t.Error(diff)
}
})

t.Run("permissions added", func(t *testing.T) {
schemaPerms := Schema{
Namespaces: []Namespace{
{Name: "TEST-NAMESPACE", Actions: []string{"READ", "WRITE"}},
{Name: "TEST-NAMESPACE-2", Actions: []string{"READ", "WRITE", "EXECUTE"}},
{Name: "TEST-NAMESPACE-3", Actions: []string{"READ", "WRITE"}},
{Name: "TEST-NAMESPACE-4", Actions: []string{"READ", "WRITE"}},
},
}

want := []database.CreatePermissionOpts{
{Namespace: "TEST-NAMESPACE-2", Action: "EXECUTE"},
{Namespace: "TEST-NAMESPACE-3", Action: "WRITE"},
{Namespace: "TEST-NAMESPACE-4", Action: "READ"},
{Namespace: "TEST-NAMESPACE-4", Action: "WRITE"},
}

added, deleted := ComparePermissions(dbPerms, schemaPerms)

assert.Len(t, added, 4)
assert.Len(t, deleted, 0)
if diff := cmp.Diff(want, added); diff != "" {
t.Error(diff)
}
})

t.Run("permissions deleted and added", func(t *testing.T) {
schemaPerms := Schema{
Namespaces: []Namespace{
{Name: "TEST-NAMESPACE", Actions: []string{"READ"}},
{Name: "TEST-NAMESPACE-2", Actions: []string{"READ", "WRITE", "EXECUTE"}},
{Name: "TEST-NAMESPACE-3", Actions: []string{"WRITE"}},
{Name: "TEST-NAMESPACE-4", Actions: []string{"READ", "WRITE"}},
},
}

wantAdded := []database.CreatePermissionOpts{
{Namespace: "TEST-NAMESPACE-2", Action: "EXECUTE"},
{Namespace: "TEST-NAMESPACE-3", Action: "WRITE"},
{Namespace: "TEST-NAMESPACE-4", Action: "READ"},
{Namespace: "TEST-NAMESPACE-4", Action: "WRITE"},
}

wantDeleted := []database.DeletePermissionOpts{
// Represents TEST-NAMESPACE-3#READ
{ID: 5},
// Represents TEST-NAMESPACE#WRITE
{ID: 2},
}

// do stuff
added, deleted := ComparePermissions(dbPerms, schemaPerms)

assert.Len(t, added, 4)
if diff := cmp.Diff(wantAdded, added); diff != "" {
t.Error(diff)
}

assert.Len(t, deleted, 2)
less := func(a, b database.DeletePermissionOpts) bool { return a.ID < b.ID }
if diff := cmp.Diff(wantDeleted, deleted, cmpopts.SortSlices(less)); diff != "" {
t.Error(diff)
}
})
}

func sortDeletePermissionOptSlice(a, b database.DeletePermissionOpts) bool { return a.ID < b.ID }
2 changes: 1 addition & 1 deletion internal/rbac/schema.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
namespaces:
- name: BATCHCHANGES
- name: BATCH_CHANGES
actions:
- READ
- WRITE
7 changes: 7 additions & 0 deletions internal/types/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -834,6 +834,13 @@ type Permission struct {
CreatedAt time.Time
}

// DisplayName returns an human-readable string for permissions.
func (p *Permission) DisplayName() string {
// Based on the zanzibar representation for data relations:
// <namespace>:<object_id>#<relation>@<user_id | user_group>
return fmt.Sprintf("%s#%s", p.Namespace, p.Action)
}

type RolePermission struct {
RoleID int32
PermissionID int32
Expand Down