Skip to content

Commit 772c33d

Browse files
committed
add alias and other improvements
1 parent 58c5875 commit 772c33d

File tree

4 files changed

+198
-84
lines changed

4 files changed

+198
-84
lines changed

Sources/PreparedStatementsPostgresNIOClient/main.swift

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@ struct MyPreparedStatementOldWay: PostgresPreparedStatement {
2828
}
2929
}
3030

31+
@Statement("")
32+
struct MyOtherPreparedStatement {}
33+
3134
@available(macOS 14.0, *)
3235
@Observable
3336
final class MyObservable {}

Sources/PreparedStatementsPostgresNIOMacros/PreparedStatementsPostgresNIOMacro.swift

Lines changed: 44 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@ import SwiftSyntaxMacros
55
import Utility
66

77
public struct PreparedStatementsPostgresNIOMacro: ExtensionMacro, MemberMacro {
8+
private typealias Column = (name: String, type: TokenSyntax, alias: String?)
9+
private typealias Bind = (name: String, type: TokenSyntax)
10+
811
public static func expansion(
912
of node: AttributeSyntax,
1013
attachedTo declaration: some DeclGroupSyntax,
@@ -23,25 +26,26 @@ public struct PreparedStatementsPostgresNIOMacro: ExtensionMacro, MemberMacro {
2326
}
2427

2528
public static func expansion(of node: AttributeSyntax, providingMembersOf declaration: some DeclGroupSyntax, in context: some MacroExpansionContext) throws -> [DeclSyntax] {
26-
guard let elements = node.arguments?.as(LabeledExprListSyntax.self)?
27-
.first?.expression.as(StringLiteralExprSyntax.self)?.segments else {
28-
// TODO: Be more specific about this error
29-
// context.diagnose(Diagnostic(node: Syntax(node), message: PostgresNIODiagnostic.wrongArgument))
30-
return []
31-
}
29+
// It is fine to force unwrap here, because the compiler ensures we receive this exact syntax tree here.
30+
let elements = node
31+
.arguments!.as(LabeledExprListSyntax.self)!
32+
.first!.expression.as(StringLiteralExprSyntax.self)!.segments
3233

3334
var sql = ""
34-
var columns: [(String, TokenSyntax)] = []
35-
var binds: [(String, TokenSyntax)] = []
35+
var columns: [Column] = []
36+
var binds: [Bind] = []
3637
for element in elements {
3738
if let expression = element.as(ExpressionSegmentSyntax.self) {
3839
let interpolation = extractInterpolations(expression)
3940
switch interpolation {
40-
case .column(let name, let type):
41-
columns.append((name, type))
42-
sql.append(name)
43-
case .bind(let name, let type):
44-
binds.append((name, type))
41+
case .column(let column):
42+
columns.append(column)
43+
sql.append(column.name)
44+
if let alias = column.alias {
45+
sql.append(" AS \(alias)")
46+
}
47+
case .bind(let bind):
48+
binds.append(bind)
4549
sql.append("$\(binds.count)")
4650
}
4751
} else if let expression = element.as(StringSegmentSyntax.self) {
@@ -53,14 +57,14 @@ public struct PreparedStatementsPostgresNIOMacro: ExtensionMacro, MemberMacro {
5357
structKeyword: .keyword(.struct, trailingTrivia: .space),
5458
name: .identifier("Row", trailingTrivia: .space),
5559
memberBlockBuilder: {
56-
for (name, type) in columns {
60+
for (name, type, alias) in columns {
5761
MemberBlockItemSyntax(
5862
decl: VariableDeclSyntax(
5963
bindingSpecifier: .keyword(.let, trailingTrivia: .space),
6064
bindings: PatternBindingListSyntax(
6165
itemsBuilder: {
6266
PatternBindingSyntax(
63-
pattern: IdentifierPatternSyntax(identifier: .identifier(name)),
67+
pattern: IdentifierPatternSyntax(identifier: .identifier(alias ?? name)),
6468
typeAnnotation: TypeAnnotationSyntax(type: IdentifierTypeSyntax(name: type))
6569
)
6670
}
@@ -108,42 +112,35 @@ public struct PreparedStatementsPostgresNIOMacro: ExtensionMacro, MemberMacro {
108112
]
109113
}
110114

111-
enum Interpolation {
112-
case column(String, TokenSyntax)
113-
case bind(String, TokenSyntax)
115+
private enum Interpolation {
116+
case column(Column)
117+
case bind(Bind)
114118
}
115119
private static func extractInterpolations(_ node: ExpressionSegmentSyntax) -> Interpolation {
116120
let tupleElements = node.expressions
117-
guard tupleElements.count == 2 else {
118-
fatalError("Expected tuple with exactly two elements")
119-
}
121+
precondition(tupleElements.count >= 2, "Expected tuple with two or more elements, less are impossible as the compiler already checks for it")
120122

121123
// First element needs to be the column name
122124
var iterator = tupleElements.makeIterator()
123-
let identifier = iterator.next()! as LabeledExprSyntax // works as tuple contains exactly two elements
124-
guard let type = iterator.next()!.expression.as(MemberAccessExprSyntax.self)?.base?.as(DeclReferenceExprSyntax.self) else {
125-
fatalError("expected something")
126-
}
125+
let identifier = iterator.next()! as LabeledExprSyntax // works as tuple contains at least two elements
126+
// Type can be force-unwrapped as the compiler ensures it is there.
127+
let type = iterator.next()!.expression.as(MemberAccessExprSyntax.self)!
128+
.base!.as(DeclReferenceExprSyntax.self)!
129+
// Same thing as with type.
130+
let name = identifier.expression.as(StringLiteralExprSyntax.self)!
131+
.segments.first!.as(StringSegmentSyntax.self)!.content.text
127132
switch identifier.label?.identifier?.name {
128133
case "bind":
129-
guard let columnName = identifier.expression.as(StringLiteralExprSyntax.self)?
130-
.segments.first?.as(StringSegmentSyntax.self)?.content
131-
.text else {
132-
fatalError("Expected column name")
133-
}
134-
return .bind(columnName, type.baseName)
134+
return .bind((name: name, type: type.baseName))
135135
default:
136-
guard let columnName = identifier.expression.as(StringLiteralExprSyntax.self)?
137-
.segments.first?.as(StringSegmentSyntax.self)?.content
138-
.text else {
139-
fatalError("Expected column name")
140-
}
136+
let alias = iterator.next()?.expression.as(StringLiteralExprSyntax.self)?
137+
.segments.first?.as(StringSegmentSyntax.self)?.content.text
141138

142-
return .column(columnName, type.baseName)
139+
return .column((name: name, type: type.baseName, alias: alias))
143140
}
144141
}
145142

146-
private static func makeBindings(for binds: [(String, TokenSyntax)]) -> FunctionDeclSyntax {
143+
private static func makeBindings(for binds: [Bind]) -> FunctionDeclSyntax {
147144
FunctionDeclSyntax(
148145
name: .identifier("makeBindings"),
149146
signature: FunctionSignatureSyntax(
@@ -218,7 +215,7 @@ public struct PreparedStatementsPostgresNIOMacro: ExtensionMacro, MemberMacro {
218215
)
219216
}
220217

221-
private static func decodeRow(from columns: [(String, TokenSyntax)]) -> FunctionDeclSyntax {
218+
private static func decodeRow(from columns: [Column]) -> FunctionDeclSyntax {
222219
FunctionDeclSyntax(
223220
name: .identifier("decodeRow"),
224221
signature: FunctionSignatureSyntax(
@@ -240,8 +237,8 @@ public struct PreparedStatementsPostgresNIOMacro: ExtensionMacro, MemberMacro {
240237
bindings: [
241238
PatternBindingSyntax(
242239
pattern: TuplePatternSyntax(elementsBuilder: {
243-
for (column, _) in columns {
244-
TuplePatternElementSyntax(pattern: IdentifierPatternSyntax(identifier: .identifier(column)))
240+
for (column, _, alias) in columns {
241+
TuplePatternElementSyntax(pattern: IdentifierPatternSyntax(identifier: .identifier(alias ?? column)))
245242
}
246243
}),
247244
initializer: InitializerClauseSyntax(
@@ -256,7 +253,7 @@ public struct PreparedStatementsPostgresNIOMacro: ExtensionMacro, MemberMacro {
256253
argumentsBuilder: {
257254
LabeledExprSyntax(expression: MemberAccessExprSyntax(
258255
base: TupleExprSyntax(elementsBuilder: {
259-
for (_, column) in columns {
256+
for (_, column, _) in columns {
260257
LabeledExprSyntax(expression: DeclReferenceExprSyntax(baseName: column))
261258
}
262259
}),
@@ -276,8 +273,11 @@ public struct PreparedStatementsPostgresNIOMacro: ExtensionMacro, MemberMacro {
276273
leftParen: .leftParenToken(),
277274
rightParen: .rightParenToken(),
278275
argumentsBuilder: {
279-
for (column, _) in columns {
280-
LabeledExprSyntax(label: column, expression: DeclReferenceExprSyntax(baseName: .identifier(column)))
276+
for (column, _, alias) in columns {
277+
LabeledExprSyntax(
278+
label: alias ?? column,
279+
expression: DeclReferenceExprSyntax(baseName: .identifier(alias ?? column))
280+
)
281281
}
282282
}
283283
)))))

Sources/Utility/STMT.swift

Lines changed: 41 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -7,51 +7,52 @@
77

88
import PostgresNIO
99

10+
/// A parsable String literal for the `@Statement` macro. It doesn't store anything and is completely useless outside of the `@Statement` declaration.
11+
///
12+
/// ```swift
13+
/// @Statement("SELECT \("id", UUID.self), \("name", String.self), \("age", Int.self) FROM users")
14+
/// struct UsersStatement {}
15+
/// ```
1016
public struct _PostgresPreparedStatement: ExpressibleByStringInterpolation {
11-
public let sql: String
12-
let columns: [(name: String, type: String)]
13-
let binds: [(name: String, type: String)]
14-
15-
public init(stringLiteral value: String) {
16-
sql = value
17-
columns = []
18-
binds = []
19-
}
17+
public init(stringLiteral value: String) {}
2018

21-
public init(stringInterpolation: StringInterpolation) {
22-
self.sql = stringInterpolation.sql
23-
self.columns = stringInterpolation.columns
24-
self.binds = stringInterpolation.binds
25-
}
19+
public init(stringInterpolation: StringInterpolation) {}
2620

2721
public struct StringInterpolation: StringInterpolationProtocol {
2822
public typealias StringLiteralType = String
2923

30-
var sql: String
31-
var columns: [(name: String, type: String)]
32-
var binds: [(name: String, type: String)]
33-
34-
public init(literalCapacity: Int, interpolationCount: Int) {
35-
sql = ""
36-
sql.reserveCapacity(literalCapacity)
37-
columns = []
38-
columns.reserveCapacity(interpolationCount)
39-
binds = []
40-
binds.reserveCapacity(interpolationCount)
41-
}
42-
43-
public mutating func appendLiteral(_ literal: String) {
44-
sql.append(literal)
45-
}
46-
47-
public mutating func appendInterpolation<T: PostgresDecodable>(_ name: String, _ type: T.Type) {
48-
sql.append(name)
49-
columns.append((name, String(reflecting: type)))
50-
}
51-
52-
public mutating func appendInterpolation<T: PostgresDynamicTypeEncodable>(bind: String, _ type: T.Type) {
53-
binds.append((bind, String(reflecting: type)))
54-
sql.append("$\(binds.count)")
55-
}
24+
public init(literalCapacity: Int, interpolationCount: Int) {}
25+
26+
public mutating func appendLiteral(_ literal: String) {}
27+
28+
/// Adds a column, e.g. inside a `SELECT` statement.
29+
/// - Parameters:
30+
/// - name: The column name in SQL.
31+
/// - type: The type used to represent the column data in Swift.
32+
/// - as: An optional alias for the column. It will be used in as an alias in SQL and the declaration Swifts `Row` struct.
33+
///
34+
/// ```swift
35+
///"SELECT \("id", UUID.self) FROM users"
36+
///// SQL: SELECT id FROM users
37+
///// Swift: struct Row { let id: UUID }
38+
///
39+
///"SELECT \("user_id", UUID.self, as: userID) FROM users"
40+
///// SQL: SELECT id as userID FROM users
41+
///// SWIFT: struct Row { let userID: UUID }
42+
/// ```
43+
public mutating func appendInterpolation(
44+
_ name: String,
45+
_ type: (some PostgresDecodable).Type,
46+
as: String? = nil
47+
) {}
48+
49+
/// Adds a bind variable.
50+
/// - Parameters:
51+
/// - bind: The name of the bind variable in Swift.
52+
/// - type: The Swift type of the bind variable.
53+
public mutating func appendInterpolation(
54+
bind: String,
55+
_ type: (some PostgresDynamicTypeEncodable).Type
56+
) {}
5657
}
5758
}

Tests/PreparedStatementsPostgresNIOTests/PreparedStatementsPostgresNIOTests.swift

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,4 +142,114 @@ final class PreparedStatementsPostgresNIOTests: XCTestCase {
142142
throw XCTSkip("macros are only supported when running tests for the host platform")
143143
#endif
144144
}
145+
146+
func testMacroWithAliasInColumn() throws {
147+
#if canImport(PreparedStatementsPostgresNIOMacros)
148+
assertMacroExpansion(
149+
#"""
150+
@Statement("SELECT \("user_id", UUID.self, as: "userID"), \("name", String.self), \("age", Int.self) FROM users WHERE \(bind: "age", Int.self) > age")
151+
struct MyStatement {}
152+
"""#,
153+
expandedSource: """
154+
struct MyStatement {
155+
156+
struct Row {
157+
let userID: UUID
158+
let name: String
159+
let age: Int
160+
}
161+
162+
static let sql = "SELECT user_id AS userID, name, age FROM users WHERE $1 > age"
163+
164+
let age: Int
165+
166+
func makeBindings() throws -> PostgresBindings {
167+
var bindings = PostgresBindings(capacity: 1)
168+
bindings.append(age)
169+
return bindings
170+
}
171+
172+
func decodeRow(_ row: PostgresRow) throws -> Row {
173+
let (userID, name, age) = try row.decode((UUID, String, Int).self)
174+
return Row(userID: userID, name: name, age: age)
175+
}
176+
}
177+
178+
extension MyStatement: PostgresPreparedStatement {
179+
}
180+
""",
181+
macroSpecs: testMacros
182+
)
183+
#else
184+
throw XCTSkip("macros are only supported when running tests for the host platform")
185+
#endif
186+
}
187+
188+
func testMacroWithoutAnything() throws {
189+
#if canImport(PreparedStatementsPostgresNIOMacros)
190+
assertMacroExpansion(
191+
#"""
192+
@Statement("SELECT id, name, age FROM users")
193+
struct MyStatement {}
194+
"""#,
195+
expandedSource: """
196+
struct MyStatement {
197+
198+
struct Row {
199+
}
200+
201+
static let sql = "SELECT id, name, age FROM users"
202+
203+
func makeBindings() throws -> PostgresBindings {
204+
return PostgresBindings()
205+
}
206+
207+
func decodeRow(_ row: PostgresRow) throws -> Row {
208+
return Row()
209+
}
210+
}
211+
212+
extension MyStatement: PostgresPreparedStatement {
213+
}
214+
""",
215+
macroSpecs: testMacros
216+
)
217+
#else
218+
throw XCTSkip("macros are only supported when running tests for the host platform")
219+
#endif
220+
}
221+
222+
func testMacroWithEmptyString() throws {
223+
#if canImport(PreparedStatementsPostgresNIOMacros)
224+
assertMacroExpansion(
225+
#"""
226+
@Statement("")
227+
struct MyStatement {}
228+
"""#,
229+
expandedSource: """
230+
struct MyStatement {
231+
232+
struct Row {
233+
}
234+
235+
static let sql = ""
236+
237+
func makeBindings() throws -> PostgresBindings {
238+
return PostgresBindings()
239+
}
240+
241+
func decodeRow(_ row: PostgresRow) throws -> Row {
242+
return Row()
243+
}
244+
}
245+
246+
extension MyStatement: PostgresPreparedStatement {
247+
}
248+
""",
249+
macroSpecs: testMacros
250+
)
251+
#else
252+
throw XCTSkip("macros are only supported when running tests for the host platform")
253+
#endif
254+
}
145255
}

0 commit comments

Comments
 (0)