Skip to content

Clean up aggregation #1072

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 25, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions .swiftlint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,5 +38,6 @@ line_length:
ignores_comments: true

file_length:
warning: 900
error: 900
warning: 500
error: 500
ignore_comment_only_lines: true
8 changes: 8 additions & 0 deletions SQLite.xcodeproj/project.pbxproj
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
03A65E971C6BB3210062603F /* libsqlite3.tbd in Frameworks */ = {isa = PBXBuildFile; fileRef = 03A65E961C6BB3210062603F /* libsqlite3.tbd */; };
19A17073552293CA063BEA66 /* Result.swift in Sources */ = {isa = PBXBuildFile; fileRef = 19A17E723300E5ED3771DCB5 /* Result.swift */; };
19A1709C3E7A406E62293B2A /* Fixtures.swift in Sources */ = {isa = PBXBuildFile; fileRef = 19A17B93B48B5560E6E51791 /* Fixtures.swift */; };
19A170ACC97B19730FB7BA4D /* Connection+Aggregation.swift in Sources */ = {isa = PBXBuildFile; fileRef = 19A175A9CB446640AE6F2200 /* Connection+Aggregation.swift */; };
19A17152E32A9585831E3FE0 /* DateAndTimeFunctions.swift in Sources */ = {isa = PBXBuildFile; fileRef = 19A17BA55DABB480F9020C8A /* DateAndTimeFunctions.swift */; };
19A1717B10CC941ACB5533D6 /* FTS5.swift in Sources */ = {isa = PBXBuildFile; fileRef = 19A1730E4390C775C25677D1 /* FTS5.swift */; };
19A171967CC511C4F6F773C9 /* RowTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 19A175C1F9CB3BBAB8FCEC7B /* RowTests.swift */; };
Expand All @@ -67,6 +68,7 @@
19A174D78559CD30679BCCCB /* FTS5Tests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 19A1721B8984686B9963B45D /* FTS5Tests.swift */; };
19A1750CEE9B05267995CF3D /* FTS5.swift in Sources */ = {isa = PBXBuildFile; fileRef = 19A1730E4390C775C25677D1 /* FTS5.swift */; };
19A175DFF47B84757E547C62 /* fixtures in Resources */ = {isa = PBXBuildFile; fileRef = 19A17E2695737FAB5D6086E3 /* fixtures */; };
19A176376CB6A94759F7980A /* Connection+Aggregation.swift in Sources */ = {isa = PBXBuildFile; fileRef = 19A175A9CB446640AE6F2200 /* Connection+Aggregation.swift */; };
19A176406BDE9D9C80CC9FA3 /* QueryIntegrationTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 19A17BA6B4E282C1315A115C /* QueryIntegrationTests.swift */; };
19A1769C1F3A7542BECF50FF /* DateAndTimeFunctionTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 19A1729B75C33F9A0B9A89C1 /* DateAndTimeFunctionTests.swift */; };
19A177CC33F2E6A24AF90B02 /* CipherTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 19A17399EA9E61235D5D77BF /* CipherTests.swift */; };
Expand All @@ -75,6 +77,7 @@
19A1785195182AF8731A8BDA /* RowTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 19A175C1F9CB3BBAB8FCEC7B /* RowTests.swift */; };
19A1792C0520D4E83C2EB075 /* Errors.swift in Sources */ = {isa = PBXBuildFile; fileRef = 19A1710E73A46D5AC721CDA9 /* Errors.swift */; };
19A179A0C45377CB09BB358C /* CipherTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 19A17399EA9E61235D5D77BF /* CipherTests.swift */; };
19A179B59450FE7C4811AB8A /* Connection+Aggregation.swift in Sources */ = {isa = PBXBuildFile; fileRef = 19A175A9CB446640AE6F2200 /* Connection+Aggregation.swift */; };
19A179CCF9671E345E5A9811 /* Cipher.swift in Sources */ = {isa = PBXBuildFile; fileRef = 19A178A39ACA9667A62663CC /* Cipher.swift */; };
19A179E76EA6207669B60C1B /* Cipher.swift in Sources */ = {isa = PBXBuildFile; fileRef = 19A178A39ACA9667A62663CC /* Cipher.swift */; };
19A17C4B951CB054EE48AB1C /* CipherTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 19A17399EA9E61235D5D77BF /* CipherTests.swift */; };
Expand Down Expand Up @@ -237,6 +240,7 @@
19A1729B75C33F9A0B9A89C1 /* DateAndTimeFunctionTests.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = DateAndTimeFunctionTests.swift; sourceTree = "<group>"; };
19A1730E4390C775C25677D1 /* FTS5.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = FTS5.swift; sourceTree = "<group>"; };
19A17399EA9E61235D5D77BF /* CipherTests.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = CipherTests.swift; sourceTree = "<group>"; };
19A175A9CB446640AE6F2200 /* Connection+Aggregation.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = "Connection+Aggregation.swift"; sourceTree = "<group>"; };
19A175C1F9CB3BBAB8FCEC7B /* RowTests.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = RowTests.swift; sourceTree = "<group>"; };
19A178A39ACA9667A62663CC /* Cipher.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = Cipher.swift; sourceTree = "<group>"; };
19A1794B7972D14330A65BBD /* Linux.md */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = net.daringfireball.markdown; path = Linux.md; sourceTree = "<group>"; };
Expand Down Expand Up @@ -462,6 +466,7 @@
19A1710E73A46D5AC721CDA9 /* Errors.swift */,
02A43A9722738CF100FEC494 /* Backup.swift */,
19A17E723300E5ED3771DCB5 /* Result.swift */,
19A175A9CB446640AE6F2200 /* Connection+Aggregation.swift */,
);
path = Core;
sourceTree = "<group>";
Expand Down Expand Up @@ -850,6 +855,7 @@
19A17FF4A10B44D3937C8CAC /* Errors.swift in Sources */,
19A1737286A74F3CF7412906 /* DateAndTimeFunctions.swift in Sources */,
19A17073552293CA063BEA66 /* Result.swift in Sources */,
19A179B59450FE7C4811AB8A /* Connection+Aggregation.swift in Sources */,
);
runOnlyForDeploymentPostprocessing = 0;
};
Expand Down Expand Up @@ -945,6 +951,7 @@
19A1792C0520D4E83C2EB075 /* Errors.swift in Sources */,
19A17E29278A12BC4F542506 /* DateAndTimeFunctions.swift in Sources */,
19A173EFEF0B3BD0B3ED406C /* Result.swift in Sources */,
19A176376CB6A94759F7980A /* Connection+Aggregation.swift in Sources */,
);
runOnlyForDeploymentPostprocessing = 0;
};
Expand Down Expand Up @@ -1008,6 +1015,7 @@
19A17490543609FCED53CACC /* Errors.swift in Sources */,
19A17152E32A9585831E3FE0 /* DateAndTimeFunctions.swift in Sources */,
19A17F1B3F0A3C96B5ED6D64 /* Result.swift in Sources */,
19A170ACC97B19730FB7BA4D /* Connection+Aggregation.swift in Sources */,
);
runOnlyForDeploymentPostprocessing = 0;
};
Expand Down
155 changes: 155 additions & 0 deletions Sources/SQLite/Core/Connection+Aggregation.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
import Foundation
#if SQLITE_SWIFT_STANDALONE
import sqlite3
#elseif SQLITE_SWIFT_SQLCIPHER
import SQLCipher
#elseif os(Linux)
import CSQLite
#else
import SQLite3
#endif

extension Connection {
private typealias Aggregate = @convention(block) (Int, Context, Int32, Argv) -> Void

/// Creates or redefines a custom SQL aggregate.
///
/// - Parameters:
///
/// - aggregate: The name of the aggregate to create or redefine.
///
/// - argumentCount: The number of arguments that the aggregate takes. If
/// `nil`, the aggregate may take any number of arguments.
///
/// Default: `nil`
///
/// - deterministic: Whether or not the aggregate is deterministic (_i.e._
/// the aggregate always returns the same result for a given input).
///
/// Default: `false`
///
/// - step: A block of code to run for each row of an aggregation group.
/// The block is called with an array of raw SQL values mapped to the
/// aggregate’s parameters, and an UnsafeMutablePointer to a state
/// variable.
///
/// - final: A block of code to run after each row of an aggregation group
/// is processed. The block is called with an UnsafeMutablePointer to a
/// state variable, and should return a raw SQL value (or nil).
///
/// - state: A block of code to run to produce a fresh state variable for
/// each aggregation group. The block should return an
/// UnsafeMutablePointer to the fresh state variable.
public func createAggregation<T>(
_ functionName: String,
argumentCount: UInt? = nil,
deterministic: Bool = false,
step: @escaping ([Binding?], UnsafeMutablePointer<T>) -> Void,
final: @escaping (UnsafeMutablePointer<T>) -> Binding?,
state: @escaping () -> UnsafeMutablePointer<T>) {

let argc = argumentCount.map { Int($0) } ?? -1
let box: Aggregate = { (stepFlag: Int, context: Context, argc: Int32, argv: Argv) in
let nBytes = Int32(MemoryLayout<UnsafeMutablePointer<Int64>>.size)
guard let aggregateContext = sqlite3_aggregate_context(context, nBytes) else {
fatalError("Could not get aggregate context")
}
let mutablePointer = aggregateContext.assumingMemoryBound(to: UnsafeMutableRawPointer.self)
if stepFlag > 0 {
let arguments = argv.getBindings(argc: argc)
if aggregateContext.assumingMemoryBound(to: Int64.self).pointee == 0 {
mutablePointer.pointee = UnsafeMutableRawPointer(mutating: state())
}
step(arguments, mutablePointer.pointee.assumingMemoryBound(to: T.self))
} else {
let result = final(mutablePointer.pointee.assumingMemoryBound(to: T.self))
context.set(result: result)
}
}

func xStep(context: Context, argc: Int32, value: Argv) {
unsafeBitCast(sqlite3_user_data(context), to: Aggregate.self)(1, context, argc, value)
}

func xFinal(context: Context) {
unsafeBitCast(sqlite3_user_data(context), to: Aggregate.self)(0, context, 0, nil)
}

let flags = SQLITE_UTF8 | (deterministic ? SQLITE_DETERMINISTIC : 0)
let resultCode = sqlite3_create_function_v2(
handle,
functionName,
Int32(argc),
flags,
/* pApp */ unsafeBitCast(box, to: UnsafeMutableRawPointer.self),
/* xFunc */ nil, xStep, xFinal, /* xDestroy */ nil
)
if let result = Result(errorCode: resultCode, connection: self) {
fatalError("Error creating function: \(result)")
}
register(functionName, argc: argc, value: box)
}

func createAggregation<T: AnyObject>(
_ aggregate: String,
argumentCount: UInt? = nil,
deterministic: Bool = false,
initialValue: T,
reduce: @escaping (T, [Binding?]) -> T,
result: @escaping (T) -> Binding?
) {
let step: ([Binding?], UnsafeMutablePointer<UnsafeMutableRawPointer>) -> Void = { (bindings, ptr) in
let pointer = ptr.pointee.assumingMemoryBound(to: T.self)
let current = Unmanaged<T>.fromOpaque(pointer).takeRetainedValue()
let next = reduce(current, bindings)
ptr.pointee = Unmanaged.passRetained(next).toOpaque()
}

let final: (UnsafeMutablePointer<UnsafeMutableRawPointer>) -> Binding? = { ptr in
let pointer = ptr.pointee.assumingMemoryBound(to: T.self)
let obj = Unmanaged<T>.fromOpaque(pointer).takeRetainedValue()
let value = result(obj)
ptr.deallocate()
return value
}

let state: () -> UnsafeMutablePointer<UnsafeMutableRawPointer> = {
let pointer = UnsafeMutablePointer<UnsafeMutableRawPointer>.allocate(capacity: 1)
pointer.pointee = Unmanaged.passRetained(initialValue).toOpaque()
return pointer
}

createAggregation(aggregate, step: step, final: final, state: state)
}

func createAggregation<T>(
_ aggregate: String,
argumentCount: UInt? = nil,
deterministic: Bool = false,
initialValue: T,
reduce: @escaping (T, [Binding?]) -> T,
result: @escaping (T) -> Binding?
) {

let step: ([Binding?], UnsafeMutablePointer<T>) -> Void = { (bindings, pointer) in
let current = pointer.pointee
let next = reduce(current, bindings)
pointer.pointee = next
}

let final: (UnsafeMutablePointer<T>) -> Binding? = { pointer in
let value = result(pointer.pointee)
pointer.deallocate()
return value
}

let state: () -> UnsafeMutablePointer<T> = {
let pointer = UnsafeMutablePointer<T>.allocate(capacity: 1)
pointer.initialize(to: initialValue)
return pointer
}

createAggregation(aggregate, step: step, final: final, state: state)
}

}
Loading