Skip to content
This repository was archived by the owner on Mar 30, 2024. It is now read-only.

Added support for Notifications that use DispatchSource #59

Open
wants to merge 15 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@ matrix:
env: SCHEME="PostgreSQL-Package"

before_install:
- brew update
- gem install xcpretty
- brew tap vapor/tap
- brew update
- brew install vapor

install:
Expand Down
80 changes: 62 additions & 18 deletions Sources/PostgreSQL/Connection.swift
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,11 @@ public final class Connection: ConnInfoInitializable {
// MARK: - CConnection

public typealias CConnection = OpaquePointer

public let cConnection: CConnection

@available(*, deprecated: 2.2, message: "needs to be optional or could cause runtime crash passing invalid reference to C")
public var cConnection: CConnection { return pgConnection! }

public private(set) var pgConnection: CConnection?

// MARK: - Init

Expand All @@ -27,14 +30,14 @@ public final class Connection: ConnInfoInitializable {
string = "host='\(hostname)' port='\(port)' dbname='\(database)' user='\(user)' password='\(password)' client_encoding='UTF8'"
}

cConnection = PQconnectdb(string)
pgConnection = PQconnectdb(string)
try validateConnection()
}

// MARK: - Deinit

deinit {
try? close()
close()
}

// MARK: - Execute
Expand Down Expand Up @@ -68,7 +71,7 @@ public final class Connection: ConnInfoInitializable {
}

let resultPointer: Result.Pointer? = PQexecParams(
cConnection,
pgConnection,
query,
Int32(binds.count),
types,
Expand All @@ -85,27 +88,35 @@ public final class Connection: ConnInfoInitializable {
// MARK: - Connection Status

public var isConnected: Bool {
return PQstatus(cConnection) == CONNECTION_OK
return pgConnection != nil && PQstatus(pgConnection) == CONNECTION_OK
}

public var status: ConnStatusType {
return PQstatus(cConnection)
guard pgConnection != nil else { return CONNECTION_BAD }
return PQstatus(pgConnection)
}

private func validateConnection() throws {
func validateConnection() throws {
guard pgConnection != nil else {
throw PostgreSQLError(code: .connectionDoesNotExist, connection: self)
}
guard isConnected else {
throw PostgreSQLError(code: .connectionFailure, connection: self)
}
}

public func reset() throws {
try validateConnection()
PQreset(cConnection)
guard let connection = pgConnection else { return }
PQreset(connection)
guard status == CONNECTION_OK else {
throw PostgreSQLError(code: .connectionFailure, connection: self)
}
}

public func close() throws {
try validateConnection()
PQfinish(cConnection)
public func close() {
guard pgConnection != nil else { return }
PQfinish(pgConnection)
pgConnection = nil
}

// MARK: - Transaction
Expand Down Expand Up @@ -152,6 +163,7 @@ public final class Connection: ConnInfoInitializable {
public let channel: String
public let payload: String?

/// internal initializer
init(pgNotify: PGnotify) {
channel = String(cString: pgNotify.relname)
pid = Int(pgNotify.be_pid)
Expand All @@ -160,8 +172,7 @@ public final class Connection: ConnInfoInitializable {
let string = String(cString: pgNotify.extra)
if !string.isEmpty {
payload = string
}
else {
} else {
payload = nil
}
}
Expand All @@ -171,12 +182,45 @@ public final class Connection: ConnInfoInitializable {
}
}

/// Creates a dispatch read source for this connection that will call `callback` on `queue` when a notification is received
///
/// - Parameter channel: the channel to register for
/// - Parameter queue: the queue to create the DispatchSource on
/// - Parameter callback: the callback
/// - Parameter notification: The notification received from the database
/// - Parameter error: Any error while reading the notification. If not nil, the source will have been canceled
/// - Returns: the dispatch socket to activate
/// - Throws: if fails to get the socket for the connection
public func listen(toChannel channel: String, queue: DispatchQueue, callback: @escaping (_ notification: Notification?, _ error: Error?) -> Void) throws -> DispatchSourceRead {
let sock = PQsocket(self.pgConnection)
guard sock >= 0 else {
throw PostgreSQLError(code: .ioError, reason: "failed to get socket for connection")
}
let src = DispatchSource.makeReadSource(fileDescriptor: sock, queue: queue)
src.setEventHandler { [weak self] in
guard let strongSelf = self else { return }
guard strongSelf.pgConnection != nil else {
callback(nil, PostgreSQLError(code: .connectionDoesNotExist, reason: "connection does not exist"))
return
}
PQconsumeInput(strongSelf.pgConnection)
while let pgNotify = PQnotifies(strongSelf.pgConnection) {
let notification = Notification(pgNotify: pgNotify.pointee)
callback(notification, nil)
PQfreemem(pgNotify)
}
}
try self.execute("LISTEN \(channel)")
return src
}

/// Registers as a listener on a specific notification channel.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Considering the new method should be the way to go, I would mark the old function as deprecated.

///
/// - Parameters:
/// - channel: The channel to register for.
/// - queue: The queue to perform the listening on.
/// - callback: Callback containing any received notification or error and a boolean which can be set to true to stop listening.
@available(*, deprecated: 2.2, message: "replaced with version using DispatchSource")
public func listen(toChannel channel: String, on queue: DispatchQueue = DispatchQueue.global(), callback: @escaping (Notification?, Error?, inout Bool) -> Void) {
queue.async {
var stop: Bool = false
Expand All @@ -190,9 +234,9 @@ public final class Connection: ConnInfoInitializable {
// Sleep to avoid looping continuously on cpu
sleep(1)

PQconsumeInput(self.cConnection)
PQconsumeInput(self.pgConnection)

while !stop, let pgNotify = PQnotifies(self.cConnection) {
while !stop, let pgNotify = PQnotifies(self.pgConnection) {
let notification = Notification(pgNotify: pgNotify.pointee)

callback(notification, nil, &stop)
Expand Down Expand Up @@ -234,7 +278,7 @@ public final class Connection: ConnInfoInitializable {
}

private func getBooleanParameterStatus(key: String, `default` defaultValue: Bool = false) -> Bool {
guard let value = PQparameterStatus(cConnection, "integer_datetimes") else {
guard let value = PQparameterStatus(pgConnection, "integer_datetimes") else {
return defaultValue
}
return String(cString: value) == "on"
Expand Down
2 changes: 1 addition & 1 deletion Sources/PostgreSQL/Error.swift
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ extension PostgreSQLError {
extension PostgreSQLError {
public init(code: Code, connection: Connection) {
let reason: String
if let error = PQerrorMessage(connection.cConnection) {
if let error = PQerrorMessage(connection.pgConnection) {
reason = String(cString: error)
}
else {
Expand Down
19 changes: 17 additions & 2 deletions Tests/PostgreSQLTests/ConnectionTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ class ConnectionTests: XCTestCase {
("testConnInfoRaw", testConnInfoRaw),
("testConnectionFailure", testConnectionFailure),
("testConnectionSuccess", testConnectionSuccess),
("testInvalidConnection", testInvalidConnection),
]

var postgreSQL: PostgreSQL.Database!
Expand All @@ -18,11 +19,12 @@ class ConnectionTests: XCTestCase {
let conn = try postgreSQL.makeConnection()

let connection = try postgreSQL.makeConnection()
XCTAssert(conn.status == CONNECTION_OK)
let status = conn.status
XCTAssert(status == CONNECTION_OK)
XCTAssertTrue(connection.isConnected)

try connection.reset()
try connection.close()
connection.close()
XCTAssertFalse(connection.isConnected)
}

Expand Down Expand Up @@ -86,4 +88,17 @@ class ConnectionTests: XCTestCase {
XCTFail("Could not connect to database")
}
}

func testInvalidConnection() throws {
postgreSQL = PostgreSQL.Database.makeTest()
let connection = try postgreSQL.makeConnection()
try connection.validateConnection()
connection.close()
do {
try connection.validateConnection()
XCTFail("connection was valid after close")
} catch {
// connection was invalid
}
}
}
69 changes: 69 additions & 0 deletions Tests/PostgreSQLTests/PostgreSQLTests.swift
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import XCTest
@testable import PostgreSQL
import Foundation
import Dispatch

class PostgreSQLTests: XCTestCase {
static let allTests = [
Expand Down Expand Up @@ -30,6 +31,9 @@ class PostgreSQLTests: XCTestCase {
("testUnsupportedObject", testUnsupportedObject),
("testNotification", testNotification),
("testNotificationWithPayload", testNotificationWithPayload),
("testDispatchNotification", testDispatchNotification),
("testDispatchNotificationInvalidConnection", testDispatchNotificationInvalidConnection),
("testDispatchNotificationWithPayload", testDispatchNotificationWithPayload),
("testQueryToNode", testQueryToNode)
]

Expand Down Expand Up @@ -779,6 +783,45 @@ class PostgreSQLTests: XCTestCase {
waitForExpectations(timeout: 5)
}

func testDispatchNotification() throws {
let conn1 = try postgreSQL.makeConnection()
let conn2 = try postgreSQL.makeConnection()

let testExpectation = expectation(description: "Receive notification")

let queue = DispatchQueue.global()
var source: DispatchSourceRead?
source = try! conn1.listen(toChannel: "test_channel1", queue: queue) { (notification, error) in
XCTAssertEqual(notification?.channel, "test_channel1")
XCTAssertNil(notification?.payload)
XCTAssertNil(error)

testExpectation.fulfill()
source?.cancel()
}
source?.resume()

sleep(1)

try conn2.notify(channel: "test_channel1", payload: nil)

waitForExpectations(timeout: 5)
}

func testDispatchNotificationInvalidConnection() throws {
let conn1 = try postgreSQL.makeConnection()
conn1.close()
do {
_ = try conn1.listen(toChannel: "test_channel1", queue: .global()) { (notification, error) in
XCTFail("callback should never be called")
}
XCTFail("exception should have been thrown because connection was not open")
} catch {
guard let pgerror = error as? PostgreSQLError else { XCTFail("incorrect error type"); return }
XCTAssertEqual(pgerror.code, .ioError)
}
}

func testNotificationWithPayload() throws {
let conn1 = try postgreSQL.makeConnection()
let conn2 = try postgreSQL.makeConnection()
Expand All @@ -801,6 +844,32 @@ class PostgreSQLTests: XCTestCase {
waitForExpectations(timeout: 5)
}

func testDispatchNotificationWithPayload() throws {
let conn1 = try postgreSQL.makeConnection()
let conn2 = try postgreSQL.makeConnection()

let testExpectation = expectation(description: "Receive notification with payload")

let queue = DispatchQueue.global()
var source: DispatchSourceRead?
source = try! conn1.listen(toChannel: "test_channel2", queue: queue) { (notification, error) in
XCTAssertEqual(notification?.channel, "test_channel2")
XCTAssertEqual(notification?.payload, "test_payload")
XCTAssertNotNil(notification?.payload)
XCTAssertNil(error)

testExpectation.fulfill()
source?.cancel()
}
source?.resume()

sleep(1)

try conn2.notify(channel: "test_channel2", payload: "test_payload")

waitForExpectations(timeout: 5)
}

func testQueryToNode() throws {
let conn = try postgreSQL.makeConnection()

Expand Down
2 changes: 1 addition & 1 deletion Tests/PostgreSQLTests/Utilities.swift
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ extension PostgreSQL.Database {
let postgreSQL = try PostgreSQL.Database(
hostname: "127.0.0.1",
port: 5432,
database: "postgres",
database: "test",
user: "postgres",
password: ""
)
Expand Down