|
| 1 | +package format |
| 2 | + |
| 3 | +import ( |
| 4 | + "fmt" |
| 5 | + "strings" |
| 6 | + |
| 7 | + "github.com/sqlc-dev/doubleclick/ast" |
| 8 | +) |
| 9 | + |
| 10 | +// Expression formats an expression. |
| 11 | +func Expression(sb *strings.Builder, expr ast.Expression) { |
| 12 | + if expr == nil { |
| 13 | + return |
| 14 | + } |
| 15 | + |
| 16 | + switch e := expr.(type) { |
| 17 | + case *ast.Literal: |
| 18 | + formatLiteral(sb, e) |
| 19 | + case *ast.Identifier: |
| 20 | + formatIdentifier(sb, e) |
| 21 | + case *ast.TableIdentifier: |
| 22 | + formatTableIdentifier(sb, e) |
| 23 | + case *ast.FunctionCall: |
| 24 | + formatFunctionCall(sb, e) |
| 25 | + case *ast.BinaryExpr: |
| 26 | + formatBinaryExpr(sb, e) |
| 27 | + case *ast.UnaryExpr: |
| 28 | + formatUnaryExpr(sb, e) |
| 29 | + case *ast.Asterisk: |
| 30 | + formatAsterisk(sb, e) |
| 31 | + case *ast.AliasedExpr: |
| 32 | + formatAliasedExpr(sb, e) |
| 33 | + default: |
| 34 | + // Fallback for unhandled expressions |
| 35 | + sb.WriteString(fmt.Sprintf("%v", expr)) |
| 36 | + } |
| 37 | +} |
| 38 | + |
| 39 | +// formatLiteral formats a literal value. |
| 40 | +func formatLiteral(sb *strings.Builder, lit *ast.Literal) { |
| 41 | + switch lit.Type { |
| 42 | + case ast.LiteralString: |
| 43 | + sb.WriteString("'") |
| 44 | + // Escape single quotes in the string |
| 45 | + s := lit.Value.(string) |
| 46 | + s = strings.ReplaceAll(s, "'", "''") |
| 47 | + sb.WriteString(s) |
| 48 | + sb.WriteString("'") |
| 49 | + case ast.LiteralInteger: |
| 50 | + switch v := lit.Value.(type) { |
| 51 | + case int64: |
| 52 | + sb.WriteString(fmt.Sprintf("%d", v)) |
| 53 | + case uint64: |
| 54 | + sb.WriteString(fmt.Sprintf("%d", v)) |
| 55 | + default: |
| 56 | + sb.WriteString(fmt.Sprintf("%v", lit.Value)) |
| 57 | + } |
| 58 | + case ast.LiteralFloat: |
| 59 | + sb.WriteString(fmt.Sprintf("%v", lit.Value)) |
| 60 | + case ast.LiteralBoolean: |
| 61 | + if lit.Value.(bool) { |
| 62 | + sb.WriteString("true") |
| 63 | + } else { |
| 64 | + sb.WriteString("false") |
| 65 | + } |
| 66 | + case ast.LiteralNull: |
| 67 | + sb.WriteString("NULL") |
| 68 | + case ast.LiteralArray: |
| 69 | + formatArrayLiteral(sb, lit.Value) |
| 70 | + case ast.LiteralTuple: |
| 71 | + formatTupleLiteral(sb, lit.Value) |
| 72 | + default: |
| 73 | + sb.WriteString(fmt.Sprintf("%v", lit.Value)) |
| 74 | + } |
| 75 | +} |
| 76 | + |
| 77 | +// formatArrayLiteral formats an array literal. |
| 78 | +func formatArrayLiteral(sb *strings.Builder, val interface{}) { |
| 79 | + sb.WriteString("[") |
| 80 | + exprs, ok := val.([]ast.Expression) |
| 81 | + if ok { |
| 82 | + for i, e := range exprs { |
| 83 | + if i > 0 { |
| 84 | + sb.WriteString(", ") |
| 85 | + } |
| 86 | + Expression(sb, e) |
| 87 | + } |
| 88 | + } |
| 89 | + sb.WriteString("]") |
| 90 | +} |
| 91 | + |
| 92 | +// formatTupleLiteral formats a tuple literal. |
| 93 | +func formatTupleLiteral(sb *strings.Builder, val interface{}) { |
| 94 | + sb.WriteString("(") |
| 95 | + exprs, ok := val.([]ast.Expression) |
| 96 | + if ok { |
| 97 | + for i, e := range exprs { |
| 98 | + if i > 0 { |
| 99 | + sb.WriteString(", ") |
| 100 | + } |
| 101 | + Expression(sb, e) |
| 102 | + } |
| 103 | + } |
| 104 | + sb.WriteString(")") |
| 105 | +} |
| 106 | + |
| 107 | +// formatIdentifier formats an identifier. |
| 108 | +func formatIdentifier(sb *strings.Builder, id *ast.Identifier) { |
| 109 | + sb.WriteString(id.Name()) |
| 110 | +} |
| 111 | + |
| 112 | +// formatTableIdentifier formats a table identifier. |
| 113 | +func formatTableIdentifier(sb *strings.Builder, t *ast.TableIdentifier) { |
| 114 | + if t.Database != "" { |
| 115 | + sb.WriteString(t.Database) |
| 116 | + sb.WriteString(".") |
| 117 | + } |
| 118 | + sb.WriteString(t.Table) |
| 119 | +} |
| 120 | + |
| 121 | +// formatFunctionCall formats a function call. |
| 122 | +func formatFunctionCall(sb *strings.Builder, fn *ast.FunctionCall) { |
| 123 | + sb.WriteString(fn.Name) |
| 124 | + sb.WriteString("(") |
| 125 | + if fn.Distinct { |
| 126 | + sb.WriteString("DISTINCT ") |
| 127 | + } |
| 128 | + for i, arg := range fn.Arguments { |
| 129 | + if i > 0 { |
| 130 | + sb.WriteString(", ") |
| 131 | + } |
| 132 | + Expression(sb, arg) |
| 133 | + } |
| 134 | + sb.WriteString(")") |
| 135 | +} |
| 136 | + |
| 137 | +// formatBinaryExpr formats a binary expression. |
| 138 | +func formatBinaryExpr(sb *strings.Builder, expr *ast.BinaryExpr) { |
| 139 | + Expression(sb, expr.Left) |
| 140 | + sb.WriteString(" ") |
| 141 | + sb.WriteString(expr.Op) |
| 142 | + sb.WriteString(" ") |
| 143 | + Expression(sb, expr.Right) |
| 144 | +} |
| 145 | + |
| 146 | +// formatUnaryExpr formats a unary expression. |
| 147 | +func formatUnaryExpr(sb *strings.Builder, expr *ast.UnaryExpr) { |
| 148 | + sb.WriteString(expr.Op) |
| 149 | + Expression(sb, expr.Operand) |
| 150 | +} |
| 151 | + |
| 152 | +// formatAsterisk formats an asterisk. |
| 153 | +func formatAsterisk(sb *strings.Builder, a *ast.Asterisk) { |
| 154 | + if a.Table != "" { |
| 155 | + sb.WriteString(a.Table) |
| 156 | + sb.WriteString(".") |
| 157 | + } |
| 158 | + sb.WriteString("*") |
| 159 | +} |
| 160 | + |
| 161 | +// formatAliasedExpr formats an aliased expression. |
| 162 | +func formatAliasedExpr(sb *strings.Builder, a *ast.AliasedExpr) { |
| 163 | + Expression(sb, a.Expr) |
| 164 | + sb.WriteString(" AS ") |
| 165 | + sb.WriteString(a.Alias) |
| 166 | +} |
0 commit comments