Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 65 additions & 63 deletions ir/formatter.go
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,15 @@ func (f *postgreSQLFormatter) formatRangeSubselect(subselect *pg_query.RangeSubs
}

// formatExpression formats a general expression
//
// NOTE: Two important expression types for array operations:
// 1. A_Expr: Appears when parsing SQL files directly (e.g., "value = ANY(ARRAY[...])")
// 2. ScalarArrayOpExpr: Appears when fetching view definitions from PostgreSQL via pg_get_viewdef()
//
// PostgreSQL internally converts "IN (...)" to "= ANY(ARRAY[...])" when storing views.
// When we fetch the view definition back via pg_get_viewdef(), it returns ScalarArrayOpExpr nodes.
// Both formatAExpr and formatScalarArrayOpExpr convert "= ANY" back to the cleaner "IN" syntax,
// while preserving other operators (>, <, <>) with ANY/ALL syntax.
func (f *postgreSQLFormatter) formatExpression(expr *pg_query.Node) {
switch {
case expr.GetColumnRef() != nil:
Expand Down Expand Up @@ -312,16 +321,7 @@ func (f *postgreSQLFormatter) formatAExpr(expr *pg_query.A_Expr) {
if isEqualityAny && expr.Rexpr != nil {
if aArrayExpr := expr.Rexpr.GetAArrayExpr(); aArrayExpr != nil {
// Convert "column = ANY(ARRAY[...])" to "column IN (...)"
f.formatExpressionStripCast(expr.Lexpr)
f.buffer.WriteString(" IN (")
for i, elem := range aArrayExpr.Elements {
if i > 0 {
f.buffer.WriteString(", ")
}
// Strip type casts from constants in IN list
f.formatExpressionStripCast(elem)
}
f.buffer.WriteString(")")
f.formatArrayAsIN(expr.Lexpr, aArrayExpr.Elements)
return
}
}
Expand Down Expand Up @@ -367,18 +367,8 @@ func (f *postgreSQLFormatter) formatAExpr(expr *pg_query.A_Expr) {
if len(expr.Name) == 1 && expr.Rexpr != nil {
if str := expr.Name[0].GetString_(); str != nil && str.Sval == "=" {
if aArrayExpr := expr.Rexpr.GetAArrayExpr(); aArrayExpr != nil {
// Direct array comparison: column = ARRAY[...]
// Convert to IN syntax, stripping unnecessary type casts from constants
f.formatExpressionStripCast(expr.Lexpr)
f.buffer.WriteString(" IN (")
for i, elem := range aArrayExpr.Elements {
if i > 0 {
f.buffer.WriteString(", ")
}
// Strip type casts from constants in IN list
f.formatExpressionStripCast(elem)
}
f.buffer.WriteString(")")
// Direct array comparison: column = ARRAY[...] → column IN (...)
f.formatArrayAsIN(expr.Lexpr, aArrayExpr.Elements)
return
}
}
Expand Down Expand Up @@ -689,13 +679,40 @@ func (f *postgreSQLFormatter) formatAArrayExpr(arrayExpr *pg_query.A_ArrayExpr)
f.buffer.WriteString("]")
}

// formatScalarArrayOpExpr formats scalar array operations like "column = ANY (ARRAY[...])"
// and converts them to the simpler "column IN (...)" syntax
func (f *postgreSQLFormatter) formatScalarArrayOpExpr(arrayOp *pg_query.ScalarArrayOpExpr) {
// Check if this is a simple = ANY pattern that can be converted to IN
// UseOr means ANY (disjunction), !UseOr means ALL (conjunction)
// IMPORTANT: We must also verify the operator is equality (=), not other operators like >, <, <>
// formatArrayAsIN is a helper to format "column IN (values)" syntax
// Used by both formatAExpr and formatScalarArrayOpExpr to convert "= ANY(ARRAY[...])" to "IN (...)"
func (f *postgreSQLFormatter) formatArrayAsIN(leftExpr *pg_query.Node, arrayElements []*pg_query.Node) {
// Format left side (the column/expression)
f.formatExpressionStripCast(leftExpr)

f.buffer.WriteString(" IN (")

// Format array elements as comma-separated list, stripping unnecessary type casts
for i, elem := range arrayElements {
if i > 0 {
f.buffer.WriteString(", ")
}
f.formatExpressionStripCast(elem)
}

f.buffer.WriteString(")")
}

// formatScalarArrayOpExpr formats ScalarArrayOpExpr nodes (PostgreSQL's internal array operation representation).
//
// CONTEXT: This function handles a narrow case - formatting view definitions fetched from PostgreSQL
// via pg_get_viewdef(). When PostgreSQL stores views, it converts "IN (...)" to "= ANY(ARRAY[...])"
// internally. When we fetch views back, we get ScalarArrayOpExpr nodes instead of the original A_Expr.
//
// This function converts "= ANY" back to the cleaner "IN (...)" syntax, while preserving
// other operators (>, <, <>, etc.) with their original ANY/ALL syntax.
//
// Example transformations:
// - "value = ANY (ARRAY[1, 2, 3])" → "value IN (1, 2, 3)" (converted)
// - "value > ANY (ARRAY[1, 2, 3])" → "value > ANY (ARRAY[1, 2, 3])" (preserved)
// - "value = ALL (ARRAY[1, 2, 3])" → "value = ALL (ARRAY[1, 2, 3])" (preserved)
func (f *postgreSQLFormatter) formatScalarArrayOpExpr(arrayOp *pg_query.ScalarArrayOpExpr) {
// Validate Args structure
if len(arrayOp.Args) != 2 {
// Malformed expression, use deparse fallback
if deparseResult, err := f.deparseNode(&pg_query.Node{Node: &pg_query.Node_ScalarArrayOpExpr{ScalarArrayOpExpr: arrayOp}}); err == nil {
Expand All @@ -704,69 +721,54 @@ func (f *postgreSQLFormatter) formatScalarArrayOpExpr(arrayOp *pg_query.ScalarAr
return
}

// Get the operator name by deparsing
// Deparse once to extract the operator name
// We need to deparse because ScalarArrayOpExpr doesn't directly expose the operator name
deparsed, err := f.deparseNode(&pg_query.Node{Node: &pg_query.Node_ScalarArrayOpExpr{ScalarArrayOpExpr: arrayOp}})
if err != nil {
// If deparse fails, just return empty (shouldn't happen in practice)
// If deparse fails, silently return (shouldn't happen in practice)
return
}

// Extract the operator once to avoid redundant string parsing
// Extract operator from deparsed string (e.g., "value > ANY (...)" → ">")
opName := extractOperator(deparsed)

// Check if operator is = (equality)
isEqualityOp := opName == "="

// Only convert to IN syntax if it's "= ANY"
if arrayOp.UseOr && isEqualityOp {
// Args[0] is the left side (column), Args[1] is the right side (array)
// Format as "column IN (values)"
// Format left side (the column)
f.formatExpression(arrayOp.Args[0])

f.buffer.WriteString(" IN (")

// Extract values from the array
// Check if this is "= ANY" which can be converted to cleaner "IN" syntax
// - UseOr == true means ANY (disjunction/OR semantics)
// - UseOr == false means ALL (conjunction/AND semantics)
// - Only convert equality with ANY, not other operators or ALL
if arrayOp.UseOr && opName == "=" {
// Convert "column = ANY (ARRAY[...])" → "column IN (...)"
if arrayExpr := arrayOp.Args[1].GetArrayExpr(); arrayExpr != nil {
// Format array elements as comma-separated list
for i, elem := range arrayExpr.Elements {
if i > 0 {
f.buffer.WriteString(", ")
}
f.formatExpression(elem)
}
} else {
// Fallback: format the right expression as-is
f.formatExpression(arrayOp.Args[1])
// Use the shared helper to format as IN syntax
f.formatArrayAsIN(arrayOp.Args[0], arrayExpr.Elements)
return
}

f.buffer.WriteString(")")
return
}

// For other operations (like <> ANY, > ANY, = ALL), format manually
// Format: <left_expr> <op> <ANY|ALL> (<array_expr>)
// For all other operations (<> ANY, > ANY, < ANY, = ALL, etc.), preserve original syntax
// Format: <left_expr> <operator> <ANY|ALL> (<array_expr>)

// Format left side
// Format left side (the column/expression)
f.formatExpression(arrayOp.Args[0])

// Use the already-extracted operator
// Format operator
if opName != "" {
f.buffer.WriteString(" ")
f.buffer.WriteString(opName)
f.buffer.WriteString(" ")
} else {
// Shouldn't happen, but provide fallback
f.buffer.WriteString(" <unknown> ")
}

// Format ANY or ALL
// Format ANY or ALL keyword
if arrayOp.UseOr {
f.buffer.WriteString("ANY (")
} else {
f.buffer.WriteString("ALL (")
}

// Format right side (the array)
// Format right side (the array expression)
f.formatExpression(arrayOp.Args[1])

f.buffer.WriteString(")")
Expand Down