Skip to content

Commit 09d8534

Browse files
author
Nathan Fallet
authored
Merge pull request #881 from kasei/custom-aggs
2 parents 9af51e2 + 474f966 commit 09d8534

File tree

4 files changed

+342
-0
lines changed

4 files changed

+342
-0
lines changed

SQLite.xcodeproj/project.pbxproj

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,9 @@
8080
19A17FB80B94E882050AA908 /* FoundationTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 19A1794CC4D7827E997E32A7 /* FoundationTests.swift */; };
8181
19A17FDA323BAFDEC627E76F /* fixtures in Resources */ = {isa = PBXBuildFile; fileRef = 19A17E2695737FAB5D6086E3 /* fixtures */; };
8282
19A17FF4A10B44D3937C8CAC /* Errors.swift in Sources */ = {isa = PBXBuildFile; fileRef = 19A1710E73A46D5AC721CDA9 /* Errors.swift */; };
83+
3717F908221F5D8800B9BD3D /* CustomAggregationTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 3717F907221F5D7C00B9BD3D /* CustomAggregationTests.swift */; };
84+
3717F909221F5D8900B9BD3D /* CustomAggregationTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 3717F907221F5D7C00B9BD3D /* CustomAggregationTests.swift */; };
85+
3717F90A221F5D8900B9BD3D /* CustomAggregationTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 3717F907221F5D7C00B9BD3D /* CustomAggregationTests.swift */; };
8386
3D67B3E61DB2469200A4F4C6 /* libsqlite3.tbd in Frameworks */ = {isa = PBXBuildFile; fileRef = 3D67B3E51DB2469200A4F4C6 /* libsqlite3.tbd */; };
8487
3D67B3E71DB246BA00A4F4C6 /* Blob.swift in Sources */ = {isa = PBXBuildFile; fileRef = EE247AEE1C3F06E900AE3E12 /* Blob.swift */; };
8588
3D67B3E81DB246BA00A4F4C6 /* Connection.swift in Sources */ = {isa = PBXBuildFile; fileRef = EE247AEF1C3F06E900AE3E12 /* Connection.swift */; };
@@ -228,6 +231,7 @@
228231
19A17B93B48B5560E6E51791 /* Fixtures.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = Fixtures.swift; sourceTree = "<group>"; };
229232
19A17BA55DABB480F9020C8A /* DateAndTimeFunctions.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = DateAndTimeFunctions.swift; sourceTree = "<group>"; };
230233
19A17E2695737FAB5D6086E3 /* fixtures */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = folder; path = fixtures; sourceTree = "<group>"; };
234+
3717F907221F5D7C00B9BD3D /* CustomAggregationTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = CustomAggregationTests.swift; sourceTree = "<group>"; };
231235
3D67B3E51DB2469200A4F4C6 /* libsqlite3.tbd */ = {isa = PBXFileReference; lastKnownFileType = "sourcecode.text-based-dylib-definition"; name = libsqlite3.tbd; path = Platforms/WatchOS.platform/Developer/SDKs/WatchOS3.0.sdk/usr/lib/libsqlite3.tbd; sourceTree = DEVELOPER_DIR; };
232236
3DDC112E26CDBA0200CE369F /* SQLiteObjc.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; name = SQLiteObjc.h; path = ../SQLiteObjc/include/SQLiteObjc.h; sourceTree = "<group>"; };
233237
49EB68C31F7B3CB400D89D40 /* Coding.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Coding.swift; sourceTree = "<group>"; };
@@ -405,6 +409,7 @@
405409
EE247B1D1C3F137700AE3E12 /* ConnectionTests.swift */,
406410
EE247B1E1C3F137700AE3E12 /* CoreFunctionsTests.swift */,
407411
EE247B1F1C3F137700AE3E12 /* CustomFunctionsTests.swift */,
412+
3717F907221F5D7C00B9BD3D /* CustomAggregationTests.swift */,
408413
EE247B201C3F137700AE3E12 /* ExpressionTests.swift */,
409414
EE247B211C3F137700AE3E12 /* FTS4Tests.swift */,
410415
EE247B2A1C3F141E00AE3E12 /* OperatorsTests.swift */,
@@ -838,6 +843,7 @@
838843
03A65E921C6BB3030062603F /* SetterTests.swift in Sources */,
839844
03A65E891C6BB3030062603F /* ConnectionTests.swift in Sources */,
840845
03A65E8A1C6BB3030062603F /* CoreFunctionsTests.swift in Sources */,
846+
3717F90A221F5D8900B9BD3D /* CustomAggregationTests.swift in Sources */,
841847
03A65E931C6BB3030062603F /* StatementTests.swift in Sources */,
842848
03A65E911C6BB3030062603F /* SchemaTests.swift in Sources */,
843849
03A65E8D1C6BB3030062603F /* FTS4Tests.swift in Sources */,
@@ -927,6 +933,7 @@
927933
EE247B271C3F137700AE3E12 /* CustomFunctionsTests.swift in Sources */,
928934
EE247B341C3F142E00AE3E12 /* StatementTests.swift in Sources */,
929935
EE247B301C3F141E00AE3E12 /* RTreeTests.swift in Sources */,
936+
3717F908221F5D8800B9BD3D /* CustomAggregationTests.swift in Sources */,
930937
EE247B231C3F137700AE3E12 /* BlobTests.swift in Sources */,
931938
EE247B351C3F142E00AE3E12 /* ValueTests.swift in Sources */,
932939
EE247B2F1C3F141E00AE3E12 /* QueryTests.swift in Sources */,
@@ -986,6 +993,7 @@
986993
EE247B5F1C3F3FC700AE3E12 /* StatementTests.swift in Sources */,
987994
EE247B5C1C3F3FC700AE3E12 /* RTreeTests.swift in Sources */,
988995
EE247B571C3F3FC700AE3E12 /* CustomFunctionsTests.swift in Sources */,
996+
3717F909221F5D8900B9BD3D /* CustomAggregationTests.swift in Sources */,
989997
EE247B601C3F3FC700AE3E12 /* ValueTests.swift in Sources */,
990998
EE247B551C3F3FC700AE3E12 /* ConnectionTests.swift in Sources */,
991999
EE247B611C3F3FC700AE3E12 /* TestHelpers.swift in Sources */,

Sources/SQLite/Core/Connection.swift

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -636,8 +636,122 @@ public final class Connection {
636636
if functions[function] == nil { self.functions[function] = [:] }
637637
functions[function]?[argc] = box
638638
}
639+
640+
/// Creates or redefines a custom SQL aggregate.
641+
///
642+
/// - Parameters:
643+
///
644+
/// - aggregate: The name of the aggregate to create or redefine.
645+
///
646+
/// - argumentCount: The number of arguments that the aggregate takes. If
647+
/// `nil`, the aggregate may take any number of arguments.
648+
///
649+
/// Default: `nil`
650+
///
651+
/// - deterministic: Whether or not the aggregate is deterministic (_i.e._
652+
/// the aggregate always returns the same result for a given input).
653+
///
654+
/// Default: `false`
655+
///
656+
/// - step: A block of code to run for each row of an aggregation group.
657+
/// The block is called with an array of raw SQL values mapped to the
658+
/// aggregate’s parameters, and an UnsafeMutablePointer to a state
659+
/// variable.
660+
///
661+
/// - final: A block of code to run after each row of an aggregation group
662+
/// is processed. The block is called with an UnsafeMutablePointer to a
663+
/// state variable, and should return a raw SQL value (or nil).
664+
///
665+
/// - state: A block of code to run to produce a fresh state variable for
666+
/// each aggregation group. The block should return an
667+
/// UnsafeMutablePointer to the fresh state variable.
668+
public func createAggregation<T>(
669+
_ aggregate: String,
670+
argumentCount: UInt? = nil,
671+
deterministic: Bool = false,
672+
step: @escaping ([Binding?], UnsafeMutablePointer<T>) -> (),
673+
final: @escaping (UnsafeMutablePointer<T>) -> Binding?,
674+
state: @escaping () -> UnsafeMutablePointer<T>) {
675+
676+
677+
let argc = argumentCount.map { Int($0) } ?? -1
678+
let box : Aggregate = { (stepFlag: Int, context: OpaquePointer?, argc: Int32, argv: UnsafeMutablePointer<OpaquePointer?>?) in
679+
let ptr = sqlite3_aggregate_context(context, 64)! // needs to be at least as large as uintptr_t; better way to do this?
680+
let p = ptr.assumingMemoryBound(to: UnsafeMutableRawPointer.self)
681+
if stepFlag > 0 {
682+
let arguments: [Binding?] = (0..<Int(argc)).map { idx in
683+
let value = argv![idx]
684+
switch sqlite3_value_type(value) {
685+
case SQLITE_BLOB:
686+
return Blob(bytes: sqlite3_value_blob(value), length: Int(sqlite3_value_bytes(value)))
687+
case SQLITE_FLOAT:
688+
return sqlite3_value_double(value)
689+
case SQLITE_INTEGER:
690+
return sqlite3_value_int64(value)
691+
case SQLITE_NULL:
692+
return nil
693+
case SQLITE_TEXT:
694+
return String(cString: UnsafePointer(sqlite3_value_text(value)))
695+
case let type:
696+
fatalError("unsupported value type: \(type)")
697+
}
698+
}
699+
700+
if ptr.assumingMemoryBound(to: Int64.self).pointee == 0 {
701+
let v = state()
702+
p.pointee = UnsafeMutableRawPointer(mutating: v)
703+
}
704+
step(arguments, p.pointee.assumingMemoryBound(to: T.self))
705+
} else {
706+
let result = final(p.pointee.assumingMemoryBound(to: T.self))
707+
if let result = result as? Blob {
708+
sqlite3_result_blob(context, result.bytes, Int32(result.bytes.count), nil)
709+
} else if let result = result as? Double {
710+
sqlite3_result_double(context, result)
711+
} else if let result = result as? Int64 {
712+
sqlite3_result_int64(context, result)
713+
} else if let result = result as? String {
714+
sqlite3_result_text(context, result, Int32(result.count), SQLITE_TRANSIENT)
715+
} else if result == nil {
716+
sqlite3_result_null(context)
717+
} else {
718+
fatalError("unsupported result type: \(String(describing: result))")
719+
}
720+
}
721+
}
722+
723+
var flags = SQLITE_UTF8
724+
#if !os(Linux)
725+
if deterministic {
726+
flags |= SQLITE_DETERMINISTIC
727+
}
728+
#endif
729+
730+
sqlite3_create_function_v2(
731+
handle,
732+
aggregate,
733+
Int32(argc),
734+
flags,
735+
unsafeBitCast(box, to: UnsafeMutableRawPointer.self),
736+
nil,
737+
{ context, argc, value in
738+
let function = unsafeBitCast(sqlite3_user_data(context), to: Aggregate.self)
739+
function(1, context, argc, value)
740+
},
741+
{ context in
742+
let function = unsafeBitCast(sqlite3_user_data(context), to: Aggregate.self)
743+
function(0, context, 0, nil)
744+
},
745+
nil
746+
)
747+
if aggregations[aggregate] == nil { self.aggregations[aggregate] = [:] }
748+
aggregations[aggregate]?[argc] = box
749+
}
750+
751+
fileprivate typealias Aggregate = @convention(block) (Int, OpaquePointer?, Int32, UnsafeMutablePointer<OpaquePointer?>?) -> Void
639752
fileprivate typealias Function = @convention(block) (OpaquePointer?, Int32, UnsafeMutablePointer<OpaquePointer?>?) -> Void
640753
fileprivate var functions = [String: [Int: Function]]()
754+
fileprivate var aggregations = [String: [Int: Aggregate]]()
641755

642756
/// Defines a new collating sequence.
643757
///

Sources/SQLite/Typed/CustomFunctions.swift

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,4 +133,69 @@ public extension Connection {
133133
}
134134
}
135135

136+
// MARK: -
137+
138+
public func createAggregation<T: AnyObject>(
139+
_ aggregate: String,
140+
argumentCount: UInt? = nil,
141+
deterministic: Bool = false,
142+
initialValue: T,
143+
reduce: @escaping (T, [Binding?]) -> T,
144+
result: @escaping (T) -> Binding?
145+
) {
146+
147+
let step: ([Binding?], UnsafeMutablePointer<UnsafeMutableRawPointer>) -> () = { (bindings, ptr) in
148+
let p = ptr.pointee.assumingMemoryBound(to: T.self)
149+
let current = Unmanaged<T>.fromOpaque(p).takeRetainedValue()
150+
let next = reduce(current, bindings)
151+
ptr.pointee = Unmanaged.passRetained(next).toOpaque()
152+
}
153+
154+
let final: (UnsafeMutablePointer<UnsafeMutableRawPointer>) -> Binding? = { (ptr) in
155+
let p = ptr.pointee.assumingMemoryBound(to: T.self)
156+
let obj = Unmanaged<T>.fromOpaque(p).takeRetainedValue()
157+
let value = result(obj)
158+
ptr.deallocate()
159+
return value
160+
}
161+
162+
let state: () -> UnsafeMutablePointer<UnsafeMutableRawPointer> = {
163+
let p = UnsafeMutablePointer<UnsafeMutableRawPointer>.allocate(capacity: 1)
164+
p.pointee = Unmanaged.passRetained(initialValue).toOpaque()
165+
return p
166+
}
167+
168+
createAggregation(aggregate, step: step, final: final, state: state)
169+
}
170+
171+
public func createAggregation<T>(
172+
_ aggregate: String,
173+
argumentCount: UInt? = nil,
174+
deterministic: Bool = false,
175+
initialValue: T,
176+
reduce: @escaping (T, [Binding?]) -> T,
177+
result: @escaping (T) -> Binding?
178+
) {
179+
180+
let step: ([Binding?], UnsafeMutablePointer<T>) -> () = { (bindings, p) in
181+
let current = p.pointee
182+
let next = reduce(current, bindings)
183+
p.pointee = next
184+
}
185+
186+
let final: (UnsafeMutablePointer<T>) -> Binding? = { (p) in
187+
let v = result(p.pointee)
188+
p.deallocate()
189+
return v
190+
}
191+
192+
let state: () -> UnsafeMutablePointer<T> = {
193+
let p = UnsafeMutablePointer<T>.allocate(capacity: 1)
194+
p.pointee = initialValue
195+
return p
196+
}
197+
198+
createAggregation(aggregate, step: step, final: final, state: state)
199+
}
200+
136201
}
Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
import XCTest
2+
import Foundation
3+
import Dispatch
4+
@testable import SQLite
5+
6+
#if SQLITE_SWIFT_STANDALONE
7+
import sqlite3
8+
#elseif SQLITE_SWIFT_SQLCIPHER
9+
import SQLCipher
10+
#elseif os(Linux)
11+
import CSQLite
12+
#else
13+
import SQLite3
14+
#endif
15+
16+
class CustomAggregationTests : SQLiteTestCase {
17+
override func setUp() {
18+
super.setUp()
19+
CreateUsersTable()
20+
try! InsertUser("Alice", age: 30, admin: true)
21+
try! InsertUser("Bob", age: 25, admin: true)
22+
try! InsertUser("Eve", age: 28, admin: false)
23+
}
24+
25+
func testUnsafeCustomSum() {
26+
let step = { (bindings: [Binding?], state: UnsafeMutablePointer<Int64>) in
27+
if let v = bindings[0] as? Int64 {
28+
state.pointee += v
29+
}
30+
}
31+
32+
let final = { (state: UnsafeMutablePointer<Int64>) -> Binding? in
33+
let v = state.pointee
34+
let p = UnsafeMutableBufferPointer(start: state, count: 1)
35+
p.deallocate()
36+
return v
37+
}
38+
let _ = db.createAggregation("mySUM1", step: step, final: final) {
39+
let v = UnsafeMutableBufferPointer<Int64>.allocate(capacity: 1)
40+
v[0] = 0
41+
return v.baseAddress!
42+
}
43+
let result = try! db.prepare("SELECT mySUM1(age) AS s FROM users")
44+
let i = result.columnNames.index(of: "s")!
45+
for row in result {
46+
let value = row[i] as? Int64
47+
XCTAssertEqual(83, value)
48+
}
49+
}
50+
51+
func testUnsafeCustomSumGrouping() {
52+
let step = { (bindings: [Binding?], state: UnsafeMutablePointer<Int64>) in
53+
if let v = bindings[0] as? Int64 {
54+
state.pointee += v
55+
}
56+
}
57+
let final = { (state: UnsafeMutablePointer<Int64>) -> Binding? in
58+
let v = state.pointee
59+
let p = UnsafeMutableBufferPointer(start: state, count: 1)
60+
p.deallocate()
61+
return v
62+
}
63+
let _ = db.createAggregation("mySUM2", step: step, final: final) {
64+
let v = UnsafeMutableBufferPointer<Int64>.allocate(capacity: 1)
65+
v[0] = 0
66+
return v.baseAddress!
67+
}
68+
let result = try! db.prepare("SELECT mySUM2(age) AS s FROM users GROUP BY admin ORDER BY s")
69+
let i = result.columnNames.index(of: "s")!
70+
let values = result.compactMap { $0[i] as? Int64 }
71+
XCTAssertTrue(values.elementsEqual([28, 55]))
72+
}
73+
74+
func testCustomSum() {
75+
let reduce : (Int64, [Binding?]) -> Int64 = { (last, bindings) in
76+
let v = (bindings[0] as? Int64) ?? 0
77+
return last + v
78+
}
79+
let _ = db.createAggregation("myReduceSUM1", initialValue: Int64(2000), reduce: reduce, result: { $0 })
80+
let result = try! db.prepare("SELECT myReduceSUM1(age) AS s FROM users")
81+
let i = result.columnNames.index(of: "s")!
82+
for row in result {
83+
let value = row[i] as? Int64
84+
XCTAssertEqual(2083, value)
85+
}
86+
}
87+
88+
func testCustomSumGrouping() {
89+
let reduce : (Int64, [Binding?]) -> Int64 = { (last, bindings) in
90+
let v = (bindings[0] as? Int64) ?? 0
91+
return last + v
92+
}
93+
let _ = db.createAggregation("myReduceSUM2", initialValue: Int64(3000), reduce: reduce, result: { $0 })
94+
let result = try! db.prepare("SELECT myReduceSUM2(age) AS s FROM users GROUP BY admin ORDER BY s")
95+
let i = result.columnNames.index(of: "s")!
96+
let values = result.compactMap { $0[i] as? Int64 }
97+
XCTAssertTrue(values.elementsEqual([3028, 3055]))
98+
}
99+
100+
func testCustomStringAgg() {
101+
let initial = String(repeating: " ", count: 64)
102+
let reduce : (String, [Binding?]) -> String = { (last, bindings) in
103+
let v = (bindings[0] as? String) ?? ""
104+
return last + v
105+
}
106+
let _ = db.createAggregation("myReduceSUM3", initialValue: initial, reduce: reduce, result: { $0 })
107+
let result = try! db.prepare("SELECT myReduceSUM3(email) AS s FROM users")
108+
let i = result.columnNames.index(of: "s")!
109+
for row in result {
110+
let value = row[i] as? String
111+
XCTAssertEqual("\(initial)Alice@example.comBob@example.comEve@example.com", value)
112+
}
113+
}
114+
115+
func testCustomObjectSum() {
116+
{
117+
let initial = TestObject(value: 1000)
118+
let reduce : (TestObject, [Binding?]) -> TestObject = { (last, bindings) in
119+
let v = (bindings[0] as? Int64) ?? 0
120+
return TestObject(value: last.value + v)
121+
}
122+
let _ = db.createAggregation("myReduceSUMX", initialValue: initial, reduce: reduce, result: { $0.value })
123+
// end this scope to ensure that the initial value is retained
124+
// by the createAggregation call.
125+
}();
126+
{
127+
XCTAssertEqual(TestObject.inits, 1)
128+
let result = try! db.prepare("SELECT myReduceSUMX(age) AS s FROM users")
129+
let i = result.columnNames.index(of: "s")!
130+
for row in result {
131+
let value = row[i] as? Int64
132+
XCTAssertEqual(1083, value)
133+
}
134+
}()
135+
XCTAssertEqual(TestObject.inits, 4)
136+
XCTAssertEqual(TestObject.deinits, 3) // the initial value is still retained by the aggregate's state block, so deinits is one less than inits
137+
}
138+
}
139+
140+
/// This class is used to test that aggregation state variables
141+
/// can be reference types and are properly memory managed when
142+
/// crossing the Swift<->C boundary multiple times.
143+
class TestObject {
144+
static var inits = 0
145+
static var deinits = 0
146+
147+
var value: Int64
148+
init(value: Int64) {
149+
self.value = value
150+
TestObject.inits += 1
151+
}
152+
deinit {
153+
TestObject.deinits += 1
154+
}
155+
}

0 commit comments

Comments
 (0)