Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 122 additions & 18 deletions internal/diff/function.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,18 +36,68 @@ func generateCreateFunctionsSQL(functions []*ir.Function, targetSchema string, c
// generateModifyFunctionsSQL generates ALTER FUNCTION statements
func generateModifyFunctionsSQL(diffs []*functionDiff, targetSchema string, collector *diffCollector) {
for _, diff := range diffs {
sql := generateFunctionSQL(diff.New, targetSchema)

// Create context for this statement
context := &diffContext{
Type: DiffTypeFunction,
Operation: DiffOperationAlter,
Path: fmt.Sprintf("%s.%s", diff.New.Schema, diff.New.Name),
Source: diff,
CanRunInTransaction: true,
oldFunc := diff.Old
newFunc := diff.New

// Check if only LEAKPROOF or PARALLEL attributes changed (not the function body/definition)
onlyAttributesChanged := functionsEqualExceptAttributes(oldFunc, newFunc)

if onlyAttributesChanged {
// Generate ALTER FUNCTION statements for attribute-only changes
// Check PARALLEL changes
if oldFunc.Parallel != newFunc.Parallel {
stmt := fmt.Sprintf("ALTER FUNCTION %s(%s) PARALLEL %s;",
qualifyEntityName(newFunc.Schema, newFunc.Name, targetSchema),
newFunc.GetArguments(),
newFunc.Parallel)

context := &diffContext{
Type: DiffTypeFunction,
Operation: DiffOperationAlter,
Path: fmt.Sprintf("%s.%s", newFunc.Schema, newFunc.Name),
Source: diff,
CanRunInTransaction: true,
}
collector.collect(context, stmt)
}

// Check LEAKPROOF changes
if oldFunc.IsLeakproof != newFunc.IsLeakproof {
var stmt string
if newFunc.IsLeakproof {
stmt = fmt.Sprintf("ALTER FUNCTION %s(%s) LEAKPROOF;",
qualifyEntityName(newFunc.Schema, newFunc.Name, targetSchema),
newFunc.GetArguments())
} else {
stmt = fmt.Sprintf("ALTER FUNCTION %s(%s) NOT LEAKPROOF;",
qualifyEntityName(newFunc.Schema, newFunc.Name, targetSchema),
newFunc.GetArguments())
}

context := &diffContext{
Type: DiffTypeFunction,
Operation: DiffOperationAlter,
Path: fmt.Sprintf("%s.%s", newFunc.Schema, newFunc.Name),
Source: diff,
CanRunInTransaction: true,
}
collector.collect(context, stmt)
}
} else {
// Function body or other attributes changed - use CREATE OR REPLACE
sql := generateFunctionSQL(newFunc, targetSchema)

// Create context for this statement
context := &diffContext{
Type: DiffTypeFunction,
Operation: DiffOperationAlter,
Path: fmt.Sprintf("%s.%s", newFunc.Schema, newFunc.Name),
Source: diff,
CanRunInTransaction: true,
}

collector.collect(context, sql)
}

collector.collect(context, sql)
}
}

Expand Down Expand Up @@ -120,13 +170,6 @@ func generateFunctionSQL(function *ir.Function, targetSchema string) string {
stmt.WriteString(fmt.Sprintf("\nLANGUAGE %s", function.Language))
}

// Add security definer/invoker - PostgreSQL default is INVOKER
if function.IsSecurityDefiner {
stmt.WriteString("\nSECURITY DEFINER")
} else {
stmt.WriteString("\nSECURITY INVOKER")
}

// Add volatility if not default
if function.Volatility != "" {
stmt.WriteString(fmt.Sprintf("\n%s", function.Volatility))
Expand All @@ -137,6 +180,25 @@ func generateFunctionSQL(function *ir.Function, targetSchema string) string {
stmt.WriteString("\nSTRICT")
}

// Add SECURITY DEFINER if true (INVOKER is default and not output)
if function.IsSecurityDefiner {
stmt.WriteString("\nSECURITY DEFINER")
}

// Add LEAKPROOF if true
if function.IsLeakproof {
stmt.WriteString("\nLEAKPROOF")
}
// Note: Don't output NOT LEAKPROOF (it's the default)

// Add PARALLEL if not default (UNSAFE)
if function.Parallel == "SAFE" {
stmt.WriteString("\nPARALLEL SAFE")
} else if function.Parallel == "RESTRICTED" {
stmt.WriteString("\nPARALLEL RESTRICTED")
}
// Note: Don't output PARALLEL UNSAFE (it's the default)

// Add the function body
if function.Definition != "" {
// Check if this uses RETURN clause syntax (PG14+)
Expand Down Expand Up @@ -232,6 +294,42 @@ func formatFunctionParameter(param *ir.Parameter, includeDefault bool, targetSch
return part
}

// functionsEqualExceptAttributes compares two functions ignoring LEAKPROOF and PARALLEL attributes
// Used to determine if ALTER FUNCTION can be used instead of CREATE OR REPLACE
func functionsEqualExceptAttributes(old, new *ir.Function) bool {
if old.Schema != new.Schema {
return false
}
if old.Name != new.Name {
return false
}
if old.Definition != new.Definition {
return false
}
if old.ReturnType != new.ReturnType {
return false
}
if old.Language != new.Language {
return false
}
if old.Volatility != new.Volatility {
return false
}
if old.IsStrict != new.IsStrict {
return false
}
if old.IsSecurityDefiner != new.IsSecurityDefiner {
return false
}
// Note: We intentionally do NOT compare IsLeakproof or Parallel here
// That's the whole point - we want to detect when only those attributes changed

// Compare using normalized Parameters array
oldInputParams := filterNonTableParameters(old.Parameters)
newInputParams := filterNonTableParameters(new.Parameters)
return parametersEqual(oldInputParams, newInputParams)
}

// functionsEqual compares two functions for equality
func functionsEqual(old, new *ir.Function) bool {
if old.Schema != new.Schema {
Expand All @@ -258,6 +356,12 @@ func functionsEqual(old, new *ir.Function) bool {
if old.IsSecurityDefiner != new.IsSecurityDefiner {
return false
}
if old.IsLeakproof != new.IsLeakproof {
return false
}
if old.Parallel != new.Parallel {
return false
}

// Compare using normalized Parameters array
// This ensures type aliases like "character varying" vs "varchar" are treated as equal
Expand Down
19 changes: 19 additions & 0 deletions ir/inspector.go
Original file line number Diff line number Diff line change
Expand Up @@ -913,6 +913,23 @@ func (i *Inspector) buildFunctions(ctx context.Context, schema *IR, targetSchema
// Handle security definer
isSecurityDefiner := fn.IsSecurityDefiner

// Handle leakproof
isLeakproof := fn.IsLeakproof

// Handle parallel mode
parallelMode := ""
proparallel := i.safeInterfaceToString(fn.ParallelMode)
switch proparallel {
case "s":
parallelMode = "SAFE"
case "r":
parallelMode = "RESTRICTED"
case "u":
parallelMode = "UNSAFE"
default:
parallelMode = "UNSAFE" // Defensive default
}

// Parse parameters from the complete signature provided by pg_get_function_arguments()
// This signature includes all parameter information including modes, names, types, and defaults
parameters := i.parseParametersFromSignature(signature, schemaName)
Expand All @@ -928,6 +945,8 @@ func (i *Inspector) buildFunctions(ctx context.Context, schema *IR, targetSchema
Volatility: volatility,
IsStrict: isStrict,
IsSecurityDefiner: isSecurityDefiner,
IsLeakproof: isLeakproof,
Parallel: parallelMode,
}

dbSchema.SetFunction(functionName, function)
Expand Down
2 changes: 2 additions & 0 deletions ir/ir.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,8 @@ type Function struct {
Volatility string `json:"volatility,omitempty"` // IMMUTABLE, STABLE, VOLATILE
IsStrict bool `json:"is_strict,omitempty"` // STRICT or null behavior
IsSecurityDefiner bool `json:"is_security_definer,omitempty"` // SECURITY DEFINER
IsLeakproof bool `json:"is_leakproof,omitempty"` // LEAKPROOF
Parallel string `json:"parallel,omitempty"` // SAFE, UNSAFE, RESTRICTED
}

// GetArguments returns the function arguments string (types only) for function identification.
Expand Down
4 changes: 3 additions & 1 deletion ir/queries/queries.sql
Original file line number Diff line number Diff line change
Expand Up @@ -813,7 +813,9 @@ SELECT
ELSE NULL
END AS volatility,
p.proisstrict AS is_strict,
p.prosecdef AS is_security_definer
p.prosecdef AS is_security_definer,
p.proleakproof AS is_leakproof,
p.proparallel AS parallel_mode
FROM information_schema.routines r
LEFT JOIN pg_proc p ON p.proname = r.routine_name
AND p.pronamespace = (SELECT oid FROM pg_namespace WHERE nspname = r.routine_schema)
Expand Down
8 changes: 7 additions & 1 deletion ir/queries/queries.sql.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

28 changes: 23 additions & 5 deletions testdata/diff/create_function/add_function/diff.sql
Original file line number Diff line number Diff line change
@@ -1,9 +1,25 @@
CREATE OR REPLACE FUNCTION days_since_special_date()
RETURNS SETOF timestamp with time zone
CREATE OR REPLACE FUNCTION calculate_tax(
amount numeric,
rate numeric
)
RETURNS numeric
LANGUAGE sql
IMMUTABLE
PARALLEL SAFE
AS $$
SELECT amount * rate;
$$;

CREATE OR REPLACE FUNCTION mask_sensitive_data(
input text
)
RETURNS text
LANGUAGE sql
SECURITY INVOKER
STABLE
RETURN generate_series((date_trunc('day'::text, '2025-01-01 00:00:00'::timestamp without time zone))::timestamp with time zone, date_trunc('day'::text, now()), '1 day'::interval);
LEAKPROOF
AS $$
SELECT '***' || substring(input from 4);
$$;

CREATE OR REPLACE FUNCTION process_order(
order_id integer,
Expand All @@ -16,9 +32,11 @@ CREATE OR REPLACE FUNCTION process_order(
)
RETURNS numeric
LANGUAGE plpgsql
SECURITY DEFINER
VOLATILE
STRICT
SECURITY DEFINER
LEAKPROOF
PARALLEL RESTRICTED
AS $$
DECLARE
total numeric;
Expand Down
28 changes: 23 additions & 5 deletions testdata/diff/create_function/add_function/new.sql
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
-- Complex function demonstrating all qualifiers
CREATE FUNCTION process_order(
order_id integer,
-- Simple numeric defaults
Expand All @@ -12,9 +13,11 @@ CREATE FUNCTION process_order(
)
RETURNS numeric
LANGUAGE plpgsql
SECURITY DEFINER
VOLATILE
STRICT
SECURITY DEFINER
LEAKPROOF
PARALLEL RESTRICTED
AS $$
DECLARE
total numeric;
Expand All @@ -24,7 +27,22 @@ BEGIN
END;
$$;

-- Table function with RETURN clause (bug report test case)
CREATE FUNCTION days_since_special_date() RETURNS SETOF timestamptz
LANGUAGE sql STABLE PARALLEL SAFE
RETURN generate_series(date_trunc('day', '2025-01-01'::timestamp), date_trunc('day', NOW()), '1 day'::interval);
-- Function testing PARALLEL SAFE only
CREATE FUNCTION calculate_tax(amount numeric, rate numeric)
RETURNS numeric
LANGUAGE sql
IMMUTABLE
PARALLEL SAFE
AS $$
SELECT amount * rate;
$$;

-- Function testing LEAKPROOF only
CREATE FUNCTION mask_sensitive_data(input text)
RETURNS text
LANGUAGE sql
STABLE
LEAKPROOF
AS $$
SELECT '***' || substring(input from 4);
$$;
14 changes: 10 additions & 4 deletions testdata/diff/create_function/add_function/plan.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"version": "1.0.0",
"pgschema_version": "1.4.0",
"pgschema_version": "1.4.3",
"created_at": "1970-01-01T00:00:00Z",
"source_fingerprint": {
"hash": "965b1131737c955e24c7f827c55bd78e4cb49a75adfd04229e0ba297376f5085"
Expand All @@ -9,13 +9,19 @@
{
"steps": [
{
"sql": "CREATE OR REPLACE FUNCTION days_since_special_date()\nRETURNS SETOF timestamp with time zone\nLANGUAGE sql\nSECURITY INVOKER\nSTABLE\nRETURN generate_series((date_trunc('day'::text, '2025-01-01 00:00:00'::timestamp without time zone))::timestamp with time zone, date_trunc('day'::text, now()), '1 day'::interval);",
"sql": "CREATE OR REPLACE FUNCTION calculate_tax(\n amount numeric,\n rate numeric\n)\nRETURNS numeric\nLANGUAGE sql\nIMMUTABLE\nPARALLEL SAFE\nAS $$\n SELECT amount * rate;\n$$;",
"type": "function",
"operation": "create",
"path": "public.days_since_special_date"
"path": "public.calculate_tax"
},
{
"sql": "CREATE OR REPLACE FUNCTION process_order(\n order_id integer,\n discount_percent numeric DEFAULT 0,\n priority_level integer DEFAULT 1,\n note varchar DEFAULT '',\n status text DEFAULT 'pending',\n apply_tax boolean DEFAULT true,\n is_priority boolean DEFAULT false\n)\nRETURNS numeric\nLANGUAGE plpgsql\nSECURITY DEFINER\nVOLATILE\nSTRICT\nAS $$\nDECLARE\n total numeric;\nBEGIN\n SELECT amount INTO total FROM orders WHERE id = order_id;\n RETURN total - (total * discount_percent / 100);\nEND;\n$$;",
"sql": "CREATE OR REPLACE FUNCTION mask_sensitive_data(\n input text\n)\nRETURNS text\nLANGUAGE sql\nSTABLE\nLEAKPROOF\nAS $$\n SELECT '***' || substring(input from 4);\n$$;",
"type": "function",
"operation": "create",
"path": "public.mask_sensitive_data"
},
{
"sql": "CREATE OR REPLACE FUNCTION process_order(\n order_id integer,\n discount_percent numeric DEFAULT 0,\n priority_level integer DEFAULT 1,\n note varchar DEFAULT '',\n status text DEFAULT 'pending',\n apply_tax boolean DEFAULT true,\n is_priority boolean DEFAULT false\n)\nRETURNS numeric\nLANGUAGE plpgsql\nVOLATILE\nSTRICT\nSECURITY DEFINER\nLEAKPROOF\nPARALLEL RESTRICTED\nAS $$\nDECLARE\n total numeric;\nBEGIN\n SELECT amount INTO total FROM orders WHERE id = order_id;\n RETURN total - (total * discount_percent / 100);\nEND;\n$$;",
"type": "function",
"operation": "create",
"path": "public.process_order"
Expand Down
Loading