Skip to content

shutdown() should cancel the signal handlers installed by start() #120

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 48 additions & 13 deletions Sources/Lifecycle/Lifecycle.swift
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,22 @@ public protocol LifecycleTask {
var shutdownIfNotStarted: Bool { get }
func start(_ callback: @escaping (Error?) -> Void)
func shutdown(_ callback: @escaping (Error?) -> Void)
var logStart: Bool { get }
var logShutdown: Bool { get }
}

extension LifecycleTask {
public var shutdownIfNotStarted: Bool {
return false
}

public var logStart: Bool {
return true
}

public var logShutdown: Bool {
return true
}
}

// MARK: - LifecycleHandler
Expand Down Expand Up @@ -317,9 +327,14 @@ public struct ServiceLifecycle {
self.log("intercepted signal: \(signal)")
self.shutdown()
}, cancelAfterTrap: true)
self.underlying.shutdownGroup.notify(queue: .global()) {
signalSource.cancel()
}
// register cleanup as the last task
self.registerShutdown(label: "\(signal) shutdown hook cleanup", .sync {
// cancel if not already canceled by the trap
if !signalSource.isCancelled {
signalSource.cancel()
ServiceLifecycle.removeTrap(signal: signal)
}
})
}
}

Expand All @@ -343,22 +358,34 @@ extension ServiceLifecycle {
public static func trap(signal sig: Signal, handler: @escaping (Signal) -> Void, on queue: DispatchQueue = .global(), cancelAfterTrap: Bool = false) -> DispatchSourceSignal {
// on linux, we can call singal() once per process
self.trappedLock.withLockVoid {
if !trapped.contains(sig.rawValue) {
if !self.trapped.contains(sig.rawValue) {
signal(sig.rawValue, SIG_IGN)
trapped.insert(sig.rawValue)
self.trapped.insert(sig.rawValue)
}
}
let signalSource = DispatchSource.makeSignalSource(signal: sig.rawValue, queue: queue)
signalSource.setEventHandler {
// run handler first
handler(sig)
// then cancel trap if so requested
if cancelAfterTrap {
signalSource.cancel()
self.removeTrap(signal: sig)
}
handler(sig)
}
signalSource.resume()
return signalSource
}

public static func removeTrap(signal sig: Signal) {
self.trappedLock.withLockVoid {
if self.trapped.contains(sig.rawValue) {
signal(sig.rawValue, SIG_DFL)
self.trapped.remove(sig.rawValue)
}
}
}

/// A system signal
public struct Signal: Equatable, CustomStringConvertible {
internal var rawValue: CInt
Expand Down Expand Up @@ -433,7 +460,7 @@ struct ShutdownError: Error {
public class ComponentLifecycle: LifecycleTask {
public let label: String
fileprivate let logger: Logger
internal let shutdownGroup = DispatchGroup()
fileprivate let shutdownGroup = DispatchGroup()

private var state = State.idle([])
private let stateLock = Lock()
Expand Down Expand Up @@ -596,13 +623,15 @@ public class ComponentLifecycle: LifecycleTask {

private func startTask(on queue: DispatchQueue, tasks: [LifecycleTask], index: Int, callback: @escaping ([LifecycleTask], Error?) -> Void) {
// async barrier
let start = { (callback) -> Void in queue.async { tasks[index].start(callback) } }
let callback = { (index, error) -> Void in queue.async { callback(index, error) } }
let start = { callback in queue.async { tasks[index].start(callback) } }
let callback = { index, error in queue.async { callback(index, error) } }

if index >= tasks.count {
return callback(tasks, nil)
}
self.logger.info("starting tasks [\(tasks[index].label)]")
if tasks[index].logStart {
self.logger.info("starting tasks [\(tasks[index].label)]")
}
let startTime = DispatchTime.now()
start { error in
Timer(label: "\(self.label).\(tasks[index].label).lifecycle.start").recordNanoseconds(DispatchTime.now().uptimeNanoseconds - startTime.uptimeNanoseconds)
Expand Down Expand Up @@ -642,14 +671,16 @@ public class ComponentLifecycle: LifecycleTask {

private func shutdownTask(on queue: DispatchQueue, tasks: [LifecycleTask], index: Int, errors: [String: Error]?, callback: @escaping ([String: Error]?) -> Void) {
// async barrier
let shutdown = { (callback) -> Void in queue.async { tasks[index].shutdown(callback) } }
let callback = { (errors) -> Void in queue.async { callback(errors) } }
let shutdown = { callback in queue.async { tasks[index].shutdown(callback) } }
let callback = { errors in queue.async { callback(errors) } }

if index >= tasks.count {
return callback(errors)
}

self.logger.info("stopping tasks [\(tasks[index].label)]")
if tasks[index].logShutdown {
self.logger.info("stopping tasks [\(tasks[index].label)]")
}
let startTime = DispatchTime.now()
shutdown { error in
Timer(label: "\(self.label).\(tasks[index].label).lifecycle.shutdown").recordNanoseconds(DispatchTime.now().uptimeNanoseconds - startTime.uptimeNanoseconds)
Expand Down Expand Up @@ -739,12 +770,16 @@ internal struct _LifecycleTask: LifecycleTask {
let shutdownIfNotStarted: Bool
let start: LifecycleHandler
let shutdown: LifecycleHandler
let logStart: Bool
let logShutdown: Bool

init(label: String, shutdownIfNotStarted: Bool? = nil, start: LifecycleHandler, shutdown: LifecycleHandler) {
self.label = label
self.shutdownIfNotStarted = shutdownIfNotStarted ?? start.noop
self.start = start
self.shutdown = shutdown
self.logStart = !start.noop
self.logShutdown = !shutdown.noop
}

func start(_ callback: @escaping (Error?) -> Void) {
Expand Down
4 changes: 2 additions & 2 deletions Sources/Lifecycle/Locks.swift
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ extension Lock {
/// - Parameter body: The block to execute while holding the lock.
/// - Returns: The value returned by the block.
@inlinable
internal func withLock<T>(_ body: () throws -> T) rethrows -> T {
func withLock<T>(_ body: () throws -> T) rethrows -> T {
self.lock()
defer {
self.unlock()
Expand All @@ -91,7 +91,7 @@ extension Lock {

// specialise Void return (for performance)
@inlinable
internal func withLockVoid(_ body: () throws -> Void) rethrows {
func withLockVoid(_ body: () throws -> Void) rethrows {
try self.withLock(body)
}
}
4 changes: 2 additions & 2 deletions Tests/LifecycleTests/ComponentLifecycleTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ final class ComponentLifecycleTests: XCTestCase {
dispatchPrecondition(condition: .onQueue(.global()))
XCTAssertTrue(startCalls.contains(id))
stopCalls.append(id)
})
})
}
lifecycle.register(items)

Expand Down Expand Up @@ -92,7 +92,7 @@ final class ComponentLifecycleTests: XCTestCase {
dispatchPrecondition(condition: .onQueue(testQueue))
XCTAssertTrue(startCalls.contains(id))
stopCalls.append(id)
})
})
}
lifecycle.register(items)

Expand Down
1 change: 1 addition & 0 deletions Tests/LifecycleTests/ServiceLifecycleTests+XCTest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ extension ServiceLifecycleTests {
("testSignalDescription", testSignalDescription),
("testBacktracesInstalledOnce", testBacktracesInstalledOnce),
("testRepeatShutdown", testRepeatShutdown),
("testShutdownCancelSignal", testShutdownCancelSignal),
]
}
}
43 changes: 43 additions & 0 deletions Tests/LifecycleTests/ServiceLifecycleTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -271,4 +271,47 @@ final class ServiceLifecycleTests: XCTestCase {

XCTAssertEqual(attempts, count)
}

func testShutdownCancelSignal() {
if ProcessInfo.processInfo.environment["SKIP_SIGNAL_TEST"].flatMap(Bool.init) ?? false {
print("skipping testShutdownCancelSignal")
return
}

struct Service {
static let signal = ServiceLifecycle.Signal.ALRM

let lifecycle: ServiceLifecycle

init() {
self.lifecycle = ServiceLifecycle(configuration: .init(shutdownSignal: [Service.signal]))
self.lifecycle.register(GoodItem())
}
}

let service = Service()
service.lifecycle.start { error in
XCTAssertNil(error, "not expecting error")
kill(getpid(), Service.signal.rawValue)
}
service.lifecycle.wait()

var count = 0
let sync = DispatchGroup()
sync.enter()
let signalSource = ServiceLifecycle.trap(signal: Service.signal, handler: { _ in
count = count + 1 // not thread safe but fine for this purpose
sync.leave()
}, cancelAfterTrap: false)

// since we are removing the hook added by lifecycle on shutdown,
// this will fail unless a new hook is set up as done above
kill(getpid(), Service.signal.rawValue)

XCTAssertEqual(.success, sync.wait(timeout: .now() + 2))
XCTAssertEqual(count, 1)

signalSource.cancel()
ServiceLifecycle.removeTrap(signal: Service.signal)
}
}