Skip to content

Fix thread leak in FileDownloadDelegate #614

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Aug 18, 2022
65 changes: 57 additions & 8 deletions Sources/AsyncHTTPClient/FileDownloadDelegate.swift
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ public final class FileDownloadDelegate: HTTPClientResponseDelegate {
public typealias Response = Progress

private let filePath: String
private let io: NonBlockingFileIO
private var io: NonBlockingFileIO?
private let reportHead: ((HTTPResponseHead) -> Void)?
private let reportProgress: ((Progress) -> Void)?

Expand All @@ -47,20 +47,60 @@ public final class FileDownloadDelegate: HTTPClientResponseDelegate {
/// the total byte count and download byte count passed to it as arguments. The callbacks
/// will be invoked in the same threading context that the delegate itself is invoked,
/// as controlled by `EventLoopPreference`.
public init(
public convenience init(
path: String,
pool: NIOThreadPool = NIOThreadPool(numberOfThreads: 1),
pool: NIOThreadPool,
reportHead: ((HTTPResponseHead) -> Void)? = nil,
reportProgress: ((Progress) -> Void)? = nil
) throws {
pool.start()
self.io = NonBlockingFileIO(threadPool: pool)
try self.init(path: path, pool: .some(pool), reportHead: reportHead, reportProgress: reportProgress)
}

/// Initializes a new file download delegate and uses the shared thread pool of the ``HTTPClient`` for file I/O.
///
/// - parameters:
/// - path: Path to a file you'd like to write the download to.
/// - reportHead: A closure called when the response head is available.
/// - reportProgress: A closure called when a body chunk has been downloaded, with
/// the total byte count and download byte count passed to it as arguments. The callbacks
/// will be invoked in the same threading context that the delegate itself is invoked,
/// as controlled by `EventLoopPreference`.
public convenience init(
path: String,
reportHead: ((HTTPResponseHead) -> Void)? = nil,
reportProgress: ((Progress) -> Void)? = nil
) throws {
try self.init(path: path, pool: nil, reportHead: reportHead, reportProgress: reportProgress)
}

private init(
path: String,
pool: NIOThreadPool?,
reportHead: ((HTTPResponseHead) -> Void)? = nil,
reportProgress: ((Progress) -> Void)? = nil
) throws {
if let pool = pool {
self.io = NonBlockingFileIO(threadPool: pool)
} else {
// we should use the shared thread pool from the HTTPClient which
// we will get from the `HTTPClient.Task`
self.io = nil
}

self.filePath = path

self.reportHead = reportHead
self.reportProgress = reportProgress
}

public func provideSharedThreadPool(fileIOPool: NIOThreadPool) {
guard self.io == nil else {
// user has provided their own thread pool
return
}
self.io = NonBlockingFileIO(threadPool: fileIOPool)
}

public func didReceiveHead(
task: HTTPClient.Task<Response>,
_ head: HTTPResponseHead
Expand All @@ -79,24 +119,33 @@ public final class FileDownloadDelegate: HTTPClientResponseDelegate {
task: HTTPClient.Task<Response>,
_ buffer: ByteBuffer
) -> EventLoopFuture<Void> {
let io: NonBlockingFileIO = {
guard let io = self.io else {
let pool = task.fileIOThreadPool
let io = NonBlockingFileIO(threadPool: pool)
self.io = io
return io
}
return io
}()
self.progress.receivedBytes += buffer.readableBytes
self.reportProgress?(self.progress)

let writeFuture: EventLoopFuture<Void>
if let fileHandleFuture = self.fileHandleFuture {
writeFuture = fileHandleFuture.flatMap {
self.io.write(fileHandle: $0, buffer: buffer, eventLoop: task.eventLoop)
io.write(fileHandle: $0, buffer: buffer, eventLoop: task.eventLoop)
}
} else {
let fileHandleFuture = self.io.openFile(
let fileHandleFuture = io.openFile(
path: self.filePath,
mode: .write,
flags: .allowFileCreation(),
eventLoop: task.eventLoop
)
self.fileHandleFuture = fileHandleFuture
writeFuture = fileHandleFuture.flatMap {
self.io.write(fileHandle: $0, buffer: buffer, eventLoop: task.eventLoop)
io.write(fileHandle: $0, buffer: buffer, eventLoop: task.eventLoop)
}
}

Expand Down
42 changes: 36 additions & 6 deletions Sources/AsyncHTTPClient/HTTPClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,11 @@ public class HTTPClient {
let eventLoopGroupProvider: EventLoopGroupProvider
let configuration: Configuration
let poolManager: HTTPConnectionPool.Manager

/// Shared thread pool used for file IO. It is given to the user through ``HTTPClientResponseDelegate/provideSharedThreadPool(fileIOPool:)-6phmu``
private var fileIOThreadPool: NIOThreadPool?
private var fileIOThreadPoolLock = Lock()

private var state: State
private let stateLock = Lock()

Expand Down Expand Up @@ -213,6 +218,16 @@ public class HTTPClient {
}
}

private func shutdownFileIOThreadPool(queue: DispatchQueue, _ callback: @escaping (Error?) -> Void) {
self.fileIOThreadPoolLock.withLockVoid {
guard let fileIOThreadPool = fileIOThreadPool else {
callback(nil)
return
}
fileIOThreadPool.shutdownGracefully(queue: queue, callback)
}
}

private func shutdown(requiresCleanClose: Bool, queue: DispatchQueue, _ callback: @escaping (Error?) -> Void) {
do {
try self.stateLock.withLock {
Expand Down Expand Up @@ -241,15 +256,28 @@ public class HTTPClient {
let error: Error? = (requiresClean && unclean) ? HTTPClientError.uncleanShutdown : nil
return (callback, error)
}

self.shutdownEventLoop(queue: queue) { error in
let reportedError = error ?? uncleanError
callback(reportedError)
self.shutdownFileIOThreadPool(queue: queue) { ioThreadPoolError in
self.shutdownEventLoop(queue: queue) { error in
let reportedError = error ?? ioThreadPoolError ?? uncleanError
callback(reportedError)
}
}
}
}
}

private func makeOrGetFileIOThreadPool() -> NIOThreadPool {
self.fileIOThreadPoolLock.withLock {
guard let fileIOThreadPool = fileIOThreadPool else {
let fileIOThreadPool = NIOThreadPool(numberOfThreads: ProcessInfo.processInfo.processorCount)
fileIOThreadPool.start()
self.fileIOThreadPool = fileIOThreadPool
return fileIOThreadPool
}
return fileIOThreadPool
}
}

/// Execute `GET` request using specified URL.
///
/// - parameters:
Expand Down Expand Up @@ -562,6 +590,7 @@ public class HTTPClient {
case .testOnly_exact(_, delegateOn: let delegateEL):
taskEL = delegateEL
}

logger.trace("selected EventLoop for task given the preference",
metadata: ["ahc-eventloop": "\(taskEL)",
"ahc-el-preference": "\(eventLoopPreference)"])
Expand All @@ -574,7 +603,8 @@ public class HTTPClient {
logger.debug("client is shutting down, failing request")
return Task<Delegate.Response>.failedTask(eventLoop: taskEL,
error: HTTPClientError.alreadyShutdown,
logger: logger)
logger: logger,
makeOrGetFileIOThreadPool: self.makeOrGetFileIOThreadPool)
}
}

Expand All @@ -597,7 +627,7 @@ public class HTTPClient {
}
}()

let task = Task<Delegate.Response>(eventLoop: taskEL, logger: logger)
let task = Task<Delegate.Response>(eventLoop: taskEL, logger: logger, makeOrGetFileIOThreadPool: self.makeOrGetFileIOThreadPool)
do {
let requestBag = try RequestBag(
request: request,
Expand Down
20 changes: 16 additions & 4 deletions Sources/AsyncHTTPClient/HTTPHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import Logging
import NIOConcurrencyHelpers
import NIOCore
import NIOHTTP1
import NIOPosix
import NIOSSL

extension HTTPClient {
Expand Down Expand Up @@ -502,7 +503,7 @@ public protocol HTTPClientResponseDelegate: AnyObject {
}

extension HTTPClientResponseDelegate {
/// Default implementation of ``HTTPClientResponseDelegate/didSendRequestHead(task:_:)-6khai``.
/// Default implementation of ``HTTPClientResponseDelegate/didSendRequest(task:)-9od5p``.
///
/// By default, this does nothing.
public func didSendRequestHead(task: HTTPClient.Task<Response>, _ head: HTTPRequestHead) {}
Expand Down Expand Up @@ -622,15 +623,26 @@ extension HTTPClient {
private var _isCancelled: Bool = false
private var _taskDelegate: HTTPClientTaskDelegate?
private let lock = Lock()
private let makeOrGetFileIOThreadPool: () -> NIOThreadPool

init(eventLoop: EventLoop, logger: Logger) {
public var fileIOThreadPool: NIOThreadPool {
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just realized that we actually don't need to make this public. @Lukasa any objections to me making it internal?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

None at all.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in c3c68b2

self.makeOrGetFileIOThreadPool()
}

init(eventLoop: EventLoop, logger: Logger, makeOrGetFileIOThreadPool: @escaping () -> NIOThreadPool) {
self.eventLoop = eventLoop
self.promise = eventLoop.makePromise()
self.logger = logger
self.makeOrGetFileIOThreadPool = makeOrGetFileIOThreadPool
}

static func failedTask(eventLoop: EventLoop, error: Error, logger: Logger) -> Task<Response> {
let task = self.init(eventLoop: eventLoop, logger: logger)
static func failedTask(
eventLoop: EventLoop,
error: Error,
logger: Logger,
makeOrGetFileIOThreadPool: @escaping () -> NIOThreadPool
) -> Task<Response> {
let task = self.init(eventLoop: eventLoop, logger: logger, makeOrGetFileIOThreadPool: makeOrGetFileIOThreadPool)
task.promise.fail(error)
return task
}
Expand Down
11 changes: 11 additions & 0 deletions Tests/AsyncHTTPClientTests/RequestBagTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -771,6 +771,17 @@ final class RequestBagTests: XCTestCase {
}
}

extension HTTPClient.Task {
convenience init(
eventLoop: EventLoop,
logger: Logger
) {
self.init(eventLoop: eventLoop, logger: logger) {
preconditionFailure("thread pool not needed in tests")
}
}
}

class UploadCountingDelegate: HTTPClientResponseDelegate {
typealias Response = Void

Expand Down