Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion internal/diff/diff.go
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,8 @@ type policyDiff struct {
// rlsChange represents enabling/disabling Row Level Security on a table
type rlsChange struct {
Table *ir.Table
Enabled bool // true to enable, false to disable
Enabled *bool // nil = no change, true = enable, false = disable
Forced *bool // nil = no change, true = force, false = no force
}

// GenerateMigration compares two IR schemas and returns the SQL differences
Expand Down
53 changes: 38 additions & 15 deletions internal/diff/policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,27 +33,50 @@ func generateCreatePoliciesSQL(policies []*ir.RLSPolicy, targetSchema string, co
}
}

// generateRLSChangesSQL generates RLS enable/disable statements
// generateRLSChangesSQL generates RLS enable/disable and force statements
func generateRLSChangesSQL(changes []*rlsChange, targetSchema string, collector *diffCollector) {
for _, change := range changes {
var sql string
tableName := qualifyEntityName(change.Table.Schema, change.Table.Name, targetSchema)
if change.Enabled {
sql = fmt.Sprintf("ALTER TABLE %s ENABLE ROW LEVEL SECURITY;", tableName)
} else {
sql = fmt.Sprintf("ALTER TABLE %s DISABLE ROW LEVEL SECURITY;", tableName)
}

// Create context for this statement
context := &diffContext{
Type: DiffTypeTableRLS,
Operation: DiffOperationAlter,
Path: fmt.Sprintf("%s.%s", change.Table.Schema, change.Table.Name),
Source: change,
CanRunInTransaction: true,
// Handle ENABLE/DISABLE changes
if change.Enabled != nil {
var sql string
if *change.Enabled {
sql = fmt.Sprintf("ALTER TABLE %s ENABLE ROW LEVEL SECURITY;", tableName)
} else {
sql = fmt.Sprintf("ALTER TABLE %s DISABLE ROW LEVEL SECURITY;", tableName)
}

context := &diffContext{
Type: DiffTypeTableRLS,
Operation: DiffOperationAlter,
Path: fmt.Sprintf("%s.%s", change.Table.Schema, change.Table.Name),
Source: change,
CanRunInTransaction: true,
}

collector.collect(context, sql)
}

collector.collect(context, sql)
// Handle FORCE/NO FORCE changes
if change.Forced != nil {
var sql string
if *change.Forced {
sql = fmt.Sprintf("ALTER TABLE %s FORCE ROW LEVEL SECURITY;", tableName)
} else {
sql = fmt.Sprintf("ALTER TABLE %s NO FORCE ROW LEVEL SECURITY;", tableName)
}

context := &diffContext{
Type: DiffTypeTableRLS,
Operation: DiffOperationAlter,
Path: fmt.Sprintf("%s.%s", change.Table.Schema, change.Table.Name),
Source: change,
CanRunInTransaction: true,
}

collector.collect(context, sql)
}
}
}

Expand Down
91 changes: 67 additions & 24 deletions internal/diff/table.go
Original file line number Diff line number Diff line change
Expand Up @@ -271,12 +271,20 @@ func diffTables(oldTable, newTable *ir.Table, targetSchema string) *tableDiff {
}
}

// Check for RLS enable/disable changes
if oldTable.RLSEnabled != newTable.RLSEnabled {
diff.RLSChanges = append(diff.RLSChanges, &rlsChange{
Table: newTable,
Enabled: newTable.RLSEnabled,
})
// Check for RLS enable/disable and force changes
if oldTable.RLSEnabled != newTable.RLSEnabled || oldTable.RLSForced != newTable.RLSForced {
change := &rlsChange{
Table: newTable,
}
if oldTable.RLSEnabled != newTable.RLSEnabled {
change.Enabled = &newTable.RLSEnabled
}
// Only track FORCE changes if RLS is not being disabled
// (disabling RLS implicitly clears FORCE, making NO FORCE redundant)
if oldTable.RLSForced != newTable.RLSForced && newTable.RLSEnabled {
change.Forced = &newTable.RLSForced
}
diff.RLSChanges = append(diff.RLSChanges, change)
}

// Check for table comment changes
Expand Down Expand Up @@ -417,9 +425,18 @@ func generateCreateTablesSQL(
}
generateCreateIndexesSQL(indexes, targetSchema, collector)

// Handle RLS enable changes (before creating policies) - only for diff scenarios
if table.RLSEnabled {
rlsChanges := []*rlsChange{{Table: table, Enabled: true}}
// Handle RLS enable/force changes (before creating policies) - only for diff scenarios
if table.RLSEnabled || table.RLSForced {
change := &rlsChange{Table: table}
if table.RLSEnabled {
enabled := true
change.Enabled = &enabled
}
if table.RLSForced {
forced := true
change.Forced = &forced
}
rlsChanges := []*rlsChange{change}
generateRLSChangesSQL(rlsChanges, targetSchema, collector)
}

Expand Down Expand Up @@ -894,24 +911,50 @@ func (td *tableDiff) generateAlterTableStatements(targetSchema string, collector
// Handle RLS changes
for _, rlsChange := range td.RLSChanges {
tableName := getTableNameWithSchema(td.Table.Schema, td.Table.Name, targetSchema)
var sql string
var operation DiffOperation
if rlsChange.Enabled {
sql = fmt.Sprintf("ALTER TABLE %s ENABLE ROW LEVEL SECURITY;", tableName)
operation = DiffOperationCreate
} else {
sql = fmt.Sprintf("ALTER TABLE %s DISABLE ROW LEVEL SECURITY;", tableName)
operation = DiffOperationDrop

// Handle ENABLE/DISABLE changes
if rlsChange.Enabled != nil {
var sql string
var operation DiffOperation
if *rlsChange.Enabled {
sql = fmt.Sprintf("ALTER TABLE %s ENABLE ROW LEVEL SECURITY;", tableName)
operation = DiffOperationCreate
} else {
sql = fmt.Sprintf("ALTER TABLE %s DISABLE ROW LEVEL SECURITY;", tableName)
operation = DiffOperationDrop
}

context := &diffContext{
Type: DiffTypeTableRLS,
Operation: operation,
Path: fmt.Sprintf("%s.%s", td.Table.Schema, td.Table.Name),
Source: rlsChange,
CanRunInTransaction: true,
}
collector.collect(context, sql)
}

context := &diffContext{
Type: DiffTypeTableRLS,
Operation: operation,
Path: fmt.Sprintf("%s.%s", td.Table.Schema, td.Table.Name),
Source: rlsChange,
CanRunInTransaction: true,
// Handle FORCE/NO FORCE changes
if rlsChange.Forced != nil {
var sql string
var operation DiffOperation
if *rlsChange.Forced {
sql = fmt.Sprintf("ALTER TABLE %s FORCE ROW LEVEL SECURITY;", tableName)
operation = DiffOperationAlter
} else {
sql = fmt.Sprintf("ALTER TABLE %s NO FORCE ROW LEVEL SECURITY;", tableName)
operation = DiffOperationAlter
}

context := &diffContext{
Type: DiffTypeTableRLS,
Operation: operation,
Path: fmt.Sprintf("%s.%s", td.Table.Schema, td.Table.Name),
Source: rlsChange,
CanRunInTransaction: true,
}
collector.collect(context, sql)
}
collector.collect(context, sql)
}

// Drop policies - already sorted by the Diff operation
Expand Down
19 changes: 6 additions & 13 deletions ir/inspector.go
Original file line number Diff line number Diff line change
Expand Up @@ -1518,25 +1518,18 @@ func (i *Inspector) decodeTriggerLevel(tgtype int16) TriggerLevel {

func (i *Inspector) buildRLSPolicies(ctx context.Context, schema *IR, targetSchema string) error {
// Get RLS enabled tables for the target schema
rlsTables, err := i.queries.GetRLSTablesForSchema(ctx, sql.NullString{String: targetSchema, Valid: true})
rlsTables, err := i.queries.GetRLSTablesForSchema(ctx, targetSchema)
if err != nil {
return err
}

// Mark tables as RLS enabled
// Mark tables as RLS enabled/forced
for _, rlsTable := range rlsTables {
schemaName := ""
if rlsTable.Schemaname.Valid {
schemaName = rlsTable.Schemaname.String
}
tableName := ""
if rlsTable.Tablename.Valid {
tableName = rlsTable.Tablename.String
}

dbSchema := schema.getOrCreateSchema(schemaName)
if table, exists := dbSchema.Tables[tableName]; exists {
dbSchema := schema.getOrCreateSchema(rlsTable.Schemaname)
if table, exists := dbSchema.Tables[rlsTable.Tablename]; exists {
// Query filters by rowsecurity = true, so this is always true
table.RLSEnabled = true
table.RLSForced = rlsTable.Rowforced
}
}

Expand Down
1 change: 1 addition & 0 deletions ir/ir.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ type Table struct {
Indexes map[string]*Index `json:"indexes"` // index_name -> Index
Triggers map[string]*Trigger `json:"triggers"` // trigger_name -> Trigger
RLSEnabled bool `json:"rls_enabled"`
RLSForced bool `json:"rls_forced"`
Policies map[string]*RLSPolicy `json:"policies"` // policy_name -> RLSPolicy
Dependencies []TableDependency `json:"dependencies"`
Comment string `json:"comment,omitempty"`
Expand Down
44 changes: 26 additions & 18 deletions ir/queries/queries.sql
Original file line number Diff line number Diff line change
Expand Up @@ -674,16 +674,20 @@ ORDER BY vtu.view_schema, vtu.view_name, vtu.table_schema, vtu.table_name;

-- GetRLSTables retrieves tables with row level security enabled
-- name: GetRLSTables :many
SELECT
schemaname,
tablename
FROM pg_tables
WHERE
schemaname NOT IN ('information_schema', 'pg_catalog', 'pg_toast')
AND schemaname NOT LIKE 'pg_temp_%'
AND schemaname NOT LIKE 'pg_toast_temp_%'
AND rowsecurity = true
ORDER BY schemaname, tablename;
SELECT
n.nspname AS schemaname,
c.relname AS tablename,
c.relrowsecurity AS rowsecurity,
c.relforcerowsecurity AS rowforced
FROM pg_catalog.pg_class c
JOIN pg_catalog.pg_namespace n ON c.relnamespace = n.oid
WHERE
n.nspname NOT IN ('information_schema', 'pg_catalog', 'pg_toast')
AND n.nspname NOT LIKE 'pg_temp_%'
AND n.nspname NOT LIKE 'pg_toast_temp_%'
AND c.relkind = 'r'
AND c.relrowsecurity = true
ORDER BY n.nspname, c.relname;

-- GetRLSPolicies retrieves all row level security policies
-- name: GetRLSPolicies :many
Expand All @@ -705,14 +709,18 @@ ORDER BY schemaname, tablename, policyname;

-- GetRLSTablesForSchema retrieves tables with row level security enabled for a specific schema
-- name: GetRLSTablesForSchema :many
SELECT
schemaname,
tablename
FROM pg_tables
WHERE
schemaname = $1
AND rowsecurity = true
ORDER BY schemaname, tablename;
SELECT
n.nspname AS schemaname,
c.relname AS tablename,
c.relrowsecurity AS rowsecurity,
c.relforcerowsecurity AS rowforced
FROM pg_catalog.pg_class c
JOIN pg_catalog.pg_namespace n ON c.relnamespace = n.oid
WHERE
n.nspname = $1
AND c.relkind = 'r'
AND c.relrowsecurity = true
ORDER BY n.nspname, c.relname;

-- GetRLSPoliciesForSchema retrieves all row level security policies for a specific schema
-- name: GetRLSPoliciesForSchema :many
Expand Down
74 changes: 48 additions & 26 deletions ir/queries/queries.sql.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading