Skip to content

Add custom JSON encoder / decoder support #285

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 27, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -79,70 +79,3 @@ public struct MySQLConfiguration {
self._hostname = hostname
}
}

public struct MySQLConnectionSource: ConnectionPoolSource {
public let configuration: MySQLConfiguration

public init(configuration: MySQLConfiguration) {
self.configuration = configuration
}

public func makeConnection(logger: Logger, on eventLoop: EventLoop) -> EventLoopFuture<MySQLConnection> {
let address: SocketAddress
do {
address = try self.configuration.address()
} catch {
return eventLoop.makeFailedFuture(error)
}
return MySQLConnection.connect(
to: address,
username: self.configuration.username,
database: self.configuration.database ?? self.configuration.username,
password: self.configuration.password,
tlsConfiguration: self.configuration.tlsConfiguration,
logger: logger,
on: eventLoop
)
}
}

extension MySQLConnection: ConnectionPoolItem { }

struct MissingColumn: Error {
let column: String
}

extension MySQLRow: SQLRow {
public var allColumns: [String] {
self.columnDefinitions.map { $0.name }
}

public func contains(column: String) -> Bool {
self.columnDefinitions.contains { $0.name == column }
}

public func decodeNil(column: String) throws -> Bool {
guard let data = self.column(column) else {
return true
}
return data.buffer == nil
}

public func decode<D>(column: String, as type: D.Type) throws -> D where D : Decodable {
guard let data = self.column(column) else {
throw MissingColumn(column: column)
}
return try MySQLDataDecoder().decode(D.self, from: data)
}
}

public struct SQLRaw: SQLExpression {
public var string: String
public init(_ string: String) {
self.string = string
}

public func serialize(to serializer: inout SQLSerializer) {
serializer.write(self.string)
}
}
27 changes: 27 additions & 0 deletions Sources/MySQLKit/MySQLConnectionSource.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
public struct MySQLConnectionSource: ConnectionPoolSource {
public let configuration: MySQLConfiguration

public init(configuration: MySQLConfiguration) {
self.configuration = configuration
}

public func makeConnection(logger: Logger, on eventLoop: EventLoop) -> EventLoopFuture<MySQLConnection> {
let address: SocketAddress
do {
address = try self.configuration.address()
} catch {
return eventLoop.makeFailedFuture(error)
}
return MySQLConnection.connect(
to: address,
username: self.configuration.username,
database: self.configuration.database ?? self.configuration.username,
password: self.configuration.password,
tlsConfiguration: self.configuration.tlsConfiguration,
logger: logger,
on: eventLoop
)
}
}

extension MySQLConnection: ConnectionPoolItem { }
19 changes: 13 additions & 6 deletions Sources/MySQLKit/MySQLDataDecoder.swift
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,11 @@ extension MySQLData {
}

public struct MySQLDataDecoder {
public init() {}
let json: JSONDecoder

public init(json: JSONDecoder = .init()) {
self.json = json
}

public func decode<T>(_ type: T.Type, from data: MySQLData) throws -> T
where T: Decodable
Expand All @@ -40,7 +44,7 @@ public struct MySQLDataDecoder {
}
return value as! T
} else {
return try T.init(from: _Decoder(data: data))
return try T.init(from: _Decoder(data: data, json: self.json))
}
}

Expand All @@ -52,22 +56,25 @@ public struct MySQLDataDecoder {
var userInfo: [CodingUserInfoKey : Any] {
return [:]
}

let data: MySQLData
init(data: MySQLData) {
let json: JSONDecoder

init(data: MySQLData, json: JSONDecoder) {
self.data = data
self.json = json
}

func unkeyedContainer() throws -> UnkeyedDecodingContainer {
try JSONDecoder()
try self.json
.decode(DecoderUnwrapper.self, from: self.data.data!)
.decoder.unkeyedContainer()
}

func container<Key>(keyedBy type: Key.Type) throws -> KeyedDecodingContainer<Key>
where Key : CodingKey
{
try JSONDecoder()
try self.json
.decode(DecoderUnwrapper.self, from: self.data.data!)
.decoder.container(keyedBy: Key.self)
}
Expand Down
8 changes: 6 additions & 2 deletions Sources/MySQLKit/MySQLDataEncoder.swift
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import Foundation

public struct MySQLDataEncoder {
public init() { }
let json: JSONEncoder

public init(json: JSONEncoder = .init()) {
self.json = json
}

public func encode(_ value: Encodable) throws -> MySQLData {
if let custom = value as? MySQLDataConvertible, let data = custom.mysqlData {
Expand All @@ -13,7 +17,7 @@ public struct MySQLDataEncoder {
return data
} else {
var buffer = ByteBufferAllocator().buffer(capacity: 0)
try buffer.writeBytes(JSONEncoder().encode(_Wrapper(value)))
try buffer.writeBytes(self.json.encode(_Wrapper(value)))
return MySQLData(
type: .string,
format: .text,
Expand Down
13 changes: 9 additions & 4 deletions Sources/MySQLKit/MySQLDatabase+SQL.swift
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
extension MySQLDatabase {
public func sql() -> SQLDatabase {
_MySQLSQLDatabase(database: self)
public func sql(
encoder: MySQLDataEncoder = .init(),
decoder: MySQLDataDecoder = .init()
) -> SQLDatabase {
_MySQLSQLDatabase(database: self, encoder: encoder, decoder: decoder)
}
}


private struct _MySQLSQLDatabase {
let database: MySQLDatabase
let encoder: MySQLDataEncoder
let decoder: MySQLDataDecoder
}

extension _MySQLSQLDatabase: SQLDatabase {
Expand All @@ -26,9 +31,9 @@ extension _MySQLSQLDatabase: SQLDatabase {
let (sql, binds) = self.serialize(query)
do {
return try self.database.query(sql, binds.map { encodable in
return try MySQLDataEncoder().encode(encodable)
return try self.encoder.encode(encodable)
}, onRow: { row in
onRow(row)
onRow(row.sql(decoder: self.decoder))
})
} catch {
return self.eventLoop.makeFailedFuture(error)
Expand Down
36 changes: 36 additions & 0 deletions Sources/MySQLKit/MySQLRow+SQL.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
extension MySQLRow {
public func sql(decoder: MySQLDataDecoder = .init()) -> SQLRow {
_MySQLSQLRow(row: self, decoder: decoder)
}
}

struct MissingColumn: Error {
let column: String
}

private struct _MySQLSQLRow: SQLRow {
let row: MySQLRow
let decoder: MySQLDataDecoder

var allColumns: [String] {
self.row.columnDefinitions.map { $0.name }
}

func contains(column: String) -> Bool {
self.row.columnDefinitions.contains { $0.name == column }
}

func decodeNil(column: String) throws -> Bool {
guard let data = self.row.column(column) else {
return true
}
return data.buffer == nil
}

func decode<D>(column: String, as type: D.Type) throws -> D where D : Decodable {
guard let data = self.row.column(column) else {
throw MissingColumn(column: column)
}
return try self.decoder.decode(D.self, from: data)
}
}
1 change: 0 additions & 1 deletion Tests/LinuxMain.swift
Original file line number Diff line number Diff line change
@@ -1,2 +1 @@
#error("Please test with `swift test --enable-test-discovery`")

75 changes: 62 additions & 13 deletions Tests/MySQLKitTests/MySQLKitTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -18,37 +18,86 @@ class MySQLKitTests: XCTestCase {
let name: String?
}

let rows = try self.db.raw("SELECT 1 as `id`, null as `name`")
let rows = try self.sql.raw("SELECT 1 as `id`, null as `name`")
.all(decoding: Person.self).wait()
XCTAssertEqual(rows[0].id, 1)
XCTAssertEqual(rows[0].name, nil)
}

var db: SQLDatabase {
self.connection.sql()
func testCustomJSONCoder() throws {
let encoder = JSONEncoder()
encoder.dateEncodingStrategy = .secondsSince1970
let decoder = JSONDecoder()
decoder.dateDecodingStrategy = .secondsSince1970
let db = self.mysql.sql(encoder: .init(json: encoder), decoder: .init(json: decoder))

struct Foo: Codable, Equatable {
var bar: Bar
}
struct Bar: Codable, Equatable {
var baz: Date
}

try db.create(table: "foo")
.column("bar", type: .custom(SQLRaw("JSON")))
.run().wait()
defer { try! db.drop(table: "foo").ifExists().run().wait() }

let foo = Foo(bar: .init(baz: .init(timeIntervalSince1970: 1337)))
try db.insert(into: "foo").model(foo).run().wait()

let rows = try db.select().columns("*").from("foo").all(decoding: Foo.self).wait()
XCTAssertEqual(rows, [foo])
}

var sql: SQLDatabase {
self.mysql.sql()
}

var mysql: MySQLDatabase {
self.pool.pool(for: self.eventLoopGroup.next())
.database(logger: .init(label: "codes.vapor.mysql"))
}

var benchmark: SQLBenchmarker {
.init(on: self.db)
.init(on: self.sql)
}

var eventLoopGroup: EventLoopGroup!
var connection: MySQLConnection!
var pool: EventLoopGroupConnectionPool<MySQLConnectionSource>!

override func setUpWithError() throws {
try super.setUpWithError()
XCTAssertTrue(isLoggingConfigured)
self.eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 2)
self.connection = try MySQLConnection.test(
on: self.eventLoopGroup.next()
).wait()
_ = try self.connection.simpleQuery("DROP DATABASE vapor_database").wait()
_ = try self.connection.simpleQuery("CREATE DATABASE vapor_database").wait()
_ = try self.connection.simpleQuery("USE vapor_database").wait()
self.pool = .init(
source: .init(configuration: .init(
hostname: env("MYSQL_HOSTNAME") ?? "localhost",
port: 3306,
username: "vapor_username",
password: "vapor_password",
database: "vapor_database",
tlsConfiguration: .forClient(certificateVerification: .none)
)),
maxConnectionsPerEventLoop: 2,
requestTimeout: .seconds(30),
logger: .init(label: "codes.vapor.mysql"),
on: self.eventLoopGroup
)

// Reset database.
_ = try self.mysql.withConnection { conn in
return conn.simpleQuery("DROP DATABASE vapor_database").flatMap { _ in
conn.simpleQuery("CREATE DATABASE vapor_database")
}.flatMap { _ in
conn.simpleQuery("USE vapor_database")
}
}.wait()
}

override func tearDownWithError() throws {
try self.connection?.close().wait()
self.connection = nil
try self.pool.syncShutdownGracefully()
self.pool = nil
try self.eventLoopGroup.syncShutdownGracefully()
self.eventLoopGroup = nil
try super.tearDownWithError()
Expand Down