-
Notifications
You must be signed in to change notification settings - Fork 112
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[test] Update mock server for Swift 6 compliance (#463)
Rewrite to MockServer (used only for performance testing and part of a separate target) to comply with Swift 6 language mode and concurrency. The new MockServer now uses `NIOAsyncChannel` and structured concurrency. Instead of adding support to [MAX_REQUEST](https://github.com/swift-server/swift-aws-lambda-runtime/blob/11756b4e00ca75894826b41666bdae506b6eb496/Sources/AWSLambdaRuntimeCore/LambdaConfiguration.swift#L53) environment variable like v1 did, we implemented support for `MAX_REQUEST` environment variable in the MockServer itself. It closes the connection and shutdown the server after servicing MAX_INVOCATIONS Lambda requests). This allow to add the MAX_REQUEST penalty on the MockServer and not on the LambdaRuntimeClient. However, currently, the LambdaRuntimeClient does not shutdown when the MockServer ends. I created #465 to track this issue. See #377
- Loading branch information
Showing
3 changed files
with
301 additions
and
180 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,298 @@ | ||
//===----------------------------------------------------------------------===// | ||
// | ||
// This source file is part of the SwiftAWSLambdaRuntime open source project | ||
// | ||
// Copyright (c) 2017-2025 Apple Inc. and the SwiftAWSLambdaRuntime project authors | ||
// Licensed under Apache License v2.0 | ||
// | ||
// See LICENSE.txt for license information | ||
// See CONTRIBUTORS.txt for the list of SwiftAWSLambdaRuntime project authors | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
import Logging | ||
import NIOCore | ||
import NIOHTTP1 | ||
import NIOPosix | ||
import Synchronization | ||
|
||
// for UUID and Date | ||
#if canImport(FoundationEssentials) | ||
import FoundationEssentials | ||
#else | ||
import Foundation | ||
#endif | ||
|
||
@main | ||
struct HttpServer { | ||
/// The server's host. (default: 127.0.0.1) | ||
private let host: String | ||
/// The server's port. (default: 7000) | ||
private let port: Int | ||
/// The server's event loop group. (default: MultiThreadedEventLoopGroup.singleton) | ||
private let eventLoopGroup: MultiThreadedEventLoopGroup | ||
/// the mode. Are we mocking a server for a Lambda function that expects a String or a JSON document? (default: string) | ||
private let mode: Mode | ||
/// the number of connections this server must accept before shutting down (default: 1) | ||
private let maxInvocations: Int | ||
/// the logger (control verbosity with LOG_LEVEL environment variable) | ||
private let logger: Logger | ||
|
||
static func main() async throws { | ||
var log = Logger(label: "MockServer") | ||
log.logLevel = env("LOG_LEVEL").flatMap(Logger.Level.init) ?? .info | ||
|
||
let server = HttpServer( | ||
host: env("HOST") ?? "127.0.0.1", | ||
port: env("PORT").flatMap(Int.init) ?? 7000, | ||
eventLoopGroup: .singleton, | ||
mode: env("MODE").flatMap(Mode.init) ?? .string, | ||
maxInvocations: env("MAX_INVOCATIONS").flatMap(Int.init) ?? 1, | ||
logger: log | ||
) | ||
try await server.run() | ||
} | ||
|
||
/// This method starts the server and handles one unique incoming connections | ||
/// The Lambda function will send two HTTP requests over this connection: one for the next invocation and one for the response. | ||
private func run() async throws { | ||
let channel = try await ServerBootstrap(group: self.eventLoopGroup) | ||
.serverChannelOption(.backlog, value: 256) | ||
.serverChannelOption(.socketOption(.so_reuseaddr), value: 1) | ||
.childChannelOption(.maxMessagesPerRead, value: 1) | ||
.bind( | ||
host: self.host, | ||
port: self.port | ||
) { channel in | ||
channel.eventLoop.makeCompletedFuture { | ||
|
||
try channel.pipeline.syncOperations.configureHTTPServerPipeline( | ||
withErrorHandling: true | ||
) | ||
|
||
return try NIOAsyncChannel( | ||
wrappingChannelSynchronously: channel, | ||
configuration: NIOAsyncChannel.Configuration( | ||
inboundType: HTTPServerRequestPart.self, | ||
outboundType: HTTPServerResponsePart.self | ||
) | ||
) | ||
} | ||
} | ||
|
||
logger.info( | ||
"Server started and listening", | ||
metadata: [ | ||
"host": "\(channel.channel.localAddress?.ipAddress?.debugDescription ?? "")", | ||
"port": "\(channel.channel.localAddress?.port ?? 0)", | ||
"maxInvocations": "\(self.maxInvocations)", | ||
] | ||
) | ||
|
||
// This counter is used to track the number of incoming connections. | ||
// This mock servers accepts n TCP connection then shutdowns | ||
let connectionCounter = SharedCounter(maxValue: self.maxInvocations) | ||
|
||
// We are handling each incoming connection in a separate child task. It is important | ||
// to use a discarding task group here which automatically discards finished child tasks. | ||
// A normal task group retains all child tasks and their outputs in memory until they are | ||
// consumed by iterating the group or by exiting the group. Since, we are never consuming | ||
// the results of the group we need the group to automatically discard them; otherwise, this | ||
// would result in a memory leak over time. | ||
try await withThrowingDiscardingTaskGroup { group in | ||
try await channel.executeThenClose { inbound in | ||
for try await connectionChannel in inbound { | ||
|
||
let counter = connectionCounter.current() | ||
logger.trace("Handling new connection", metadata: ["connectionNumber": "\(counter)"]) | ||
|
||
group.addTask { | ||
await self.handleConnection(channel: connectionChannel) | ||
logger.trace("Done handling connection", metadata: ["connectionNumber": "\(counter)"]) | ||
} | ||
|
||
if connectionCounter.increment() { | ||
logger.info( | ||
"Maximum number of connections reached, shutting down after current connection", | ||
metadata: ["maxConnections": "\(self.maxInvocations)"] | ||
) | ||
break // this causes the server to shutdown after handling the connection | ||
} | ||
} | ||
} | ||
} | ||
logger.info("Server shutting down") | ||
} | ||
|
||
/// This method handles a single connection by responsing hard coded value to a Lambda function request. | ||
/// It handles two requests: one for the next invocation and one for the response. | ||
/// when the maximum number of requests is reached, it closes the connection. | ||
private func handleConnection( | ||
channel: NIOAsyncChannel<HTTPServerRequestPart, HTTPServerResponsePart> | ||
) async { | ||
|
||
var requestHead: HTTPRequestHead! | ||
var requestBody: ByteBuffer? | ||
|
||
// each Lambda invocation results in TWO HTTP requests (next and response) | ||
let requestCount = SharedCounter(maxValue: 2) | ||
|
||
// Note that this method is non-throwing and we are catching any error. | ||
// We do this since we don't want to tear down the whole server when a single connection | ||
// encounters an error. | ||
do { | ||
try await channel.executeThenClose { inbound, outbound in | ||
for try await inboundData in inbound { | ||
let requestNumber = requestCount.current() | ||
logger.trace("Handling request", metadata: ["requestNumber": "\(requestNumber)"]) | ||
|
||
if case .head(let head) = inboundData { | ||
logger.trace("Received request head", metadata: ["head": "\(head)"]) | ||
requestHead = head | ||
} | ||
if case .body(let body) = inboundData { | ||
logger.trace("Received request body", metadata: ["body": "\(body)"]) | ||
requestBody = body | ||
} | ||
if case .end(let end) = inboundData { | ||
logger.trace("Received request end", metadata: ["end": "\(String(describing: end))"]) | ||
|
||
precondition(requestHead != nil, "Received .end without .head") | ||
let (responseStatus, responseHeaders, responseBody) = self.processRequest( | ||
requestHead: requestHead, | ||
requestBody: requestBody | ||
) | ||
|
||
try await self.sendResponse( | ||
responseStatus: responseStatus, | ||
responseHeaders: responseHeaders, | ||
responseBody: responseBody, | ||
outbound: outbound | ||
) | ||
|
||
requestHead = nil | ||
|
||
if requestCount.increment() { | ||
logger.info( | ||
"Maximum number of requests reached, closing this connection", | ||
metadata: ["maxRequest": "2"] | ||
) | ||
break // this finishes handiling request on this connection | ||
} | ||
} | ||
} | ||
} | ||
} catch { | ||
logger.error("Hit error: \(error)") | ||
} | ||
} | ||
/// This function process the requests and return an hard-coded response (string or JSON depending on the mode). | ||
/// We ignore the requestBody. | ||
private func processRequest( | ||
requestHead: HTTPRequestHead, | ||
requestBody: ByteBuffer? | ||
) -> (HTTPResponseStatus, [(String, String)], String) { | ||
var responseStatus: HTTPResponseStatus = .ok | ||
var responseBody: String = "" | ||
var responseHeaders: [(String, String)] = [] | ||
|
||
logger.trace( | ||
"Processing request", | ||
metadata: ["VERB": "\(requestHead.method)", "URI": "\(requestHead.uri)"] | ||
) | ||
|
||
if requestHead.uri.hasSuffix("/next") { | ||
responseStatus = .accepted | ||
|
||
let requestId = UUID().uuidString | ||
switch self.mode { | ||
case .string: | ||
responseBody = "\"Seb\"" // must be a valid JSON document | ||
case .json: | ||
responseBody = "{ \"name\": \"Seb\", \"age\" : 52 }" | ||
} | ||
let deadline = Int64(Date(timeIntervalSinceNow: 60).timeIntervalSince1970 * 1000) | ||
responseHeaders = [ | ||
(AmazonHeaders.requestID, requestId), | ||
(AmazonHeaders.invokedFunctionARN, "arn:aws:lambda:us-east-1:123456789012:function:custom-runtime"), | ||
(AmazonHeaders.traceID, "Root=1-5bef4de7-ad49b0e87f6ef6c87fc2e700;Parent=9a9197af755a6419;Sampled=1"), | ||
(AmazonHeaders.deadline, String(deadline)), | ||
] | ||
} else if requestHead.uri.hasSuffix("/response") { | ||
responseStatus = .accepted | ||
} else if requestHead.uri.hasSuffix("/error") { | ||
responseStatus = .ok | ||
} else { | ||
responseStatus = .notFound | ||
} | ||
logger.trace("Returning response: \(responseStatus), \(responseHeaders), \(responseBody)") | ||
return (responseStatus, responseHeaders, responseBody) | ||
} | ||
|
||
private func sendResponse( | ||
responseStatus: HTTPResponseStatus, | ||
responseHeaders: [(String, String)], | ||
responseBody: String, | ||
outbound: NIOAsyncChannelOutboundWriter<HTTPServerResponsePart> | ||
) async throws { | ||
var headers = HTTPHeaders(responseHeaders) | ||
headers.add(name: "Content-Length", value: "\(responseBody.utf8.count)") | ||
headers.add(name: "KeepAlive", value: "timeout=1, max=2") | ||
|
||
logger.trace("Writing response head") | ||
try await outbound.write( | ||
HTTPServerResponsePart.head( | ||
HTTPResponseHead( | ||
version: .init(major: 1, minor: 1), // use HTTP 1.1 it keeps connection alive between requests | ||
status: responseStatus, | ||
headers: headers | ||
) | ||
) | ||
) | ||
logger.trace("Writing response body") | ||
try await outbound.write(HTTPServerResponsePart.body(.byteBuffer(ByteBuffer(string: responseBody)))) | ||
logger.trace("Writing response end") | ||
try await outbound.write(HTTPServerResponsePart.end(nil)) | ||
} | ||
|
||
private enum Mode: String { | ||
case string | ||
case json | ||
} | ||
|
||
private static func env(_ name: String) -> String? { | ||
guard let value = getenv(name) else { | ||
return nil | ||
} | ||
return String(cString: value) | ||
} | ||
|
||
private enum AmazonHeaders { | ||
static let requestID = "Lambda-Runtime-Aws-Request-Id" | ||
static let traceID = "Lambda-Runtime-Trace-Id" | ||
static let clientContext = "X-Amz-Client-Context" | ||
static let cognitoIdentity = "X-Amz-Cognito-Identity" | ||
static let deadline = "Lambda-Runtime-Deadline-Ms" | ||
static let invokedFunctionARN = "Lambda-Runtime-Invoked-Function-Arn" | ||
} | ||
|
||
private final class SharedCounter: Sendable { | ||
private let counterMutex = Mutex<Int>(0) | ||
private let maxValue: Int | ||
|
||
init(maxValue: Int) { | ||
self.maxValue = maxValue | ||
} | ||
func current() -> Int { | ||
counterMutex.withLock { $0 } | ||
} | ||
func increment() -> Bool { | ||
counterMutex.withLock { | ||
$0 += 1 | ||
return $0 >= maxValue | ||
} | ||
} | ||
} | ||
} |
Oops, something went wrong.