Skip to content

Commit

Permalink
[test] Update mock server for Swift 6 compliance (#463)
Browse files Browse the repository at this point in the history
Rewrite to MockServer (used only for performance testing and part of a
separate target) to comply with Swift 6 language mode and concurrency.

The new MockServer now uses `NIOAsyncChannel` and structured
concurrency.

Instead of adding support to
[MAX_REQUEST](https://github.com/swift-server/swift-aws-lambda-runtime/blob/11756b4e00ca75894826b41666bdae506b6eb496/Sources/AWSLambdaRuntimeCore/LambdaConfiguration.swift#L53)
environment variable like v1 did, we implemented support for
`MAX_REQUEST` environment variable in the MockServer itself. It closes
the connection and shutdown the server after servicing MAX_INVOCATIONS
Lambda requests). This allow to add the MAX_REQUEST penalty on the
MockServer and not on the LambdaRuntimeClient.

However, currently, the LambdaRuntimeClient does not shutdown when the
MockServer ends. I created
#465 to
track this issue.

See #377
  • Loading branch information
sebsto authored Jan 22, 2025
1 parent ed84609 commit 02821fe
Show file tree
Hide file tree
Showing 3 changed files with 301 additions and 180 deletions.
6 changes: 3 additions & 3 deletions Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ let package = Package(
.library(name: "AWSLambdaTesting", targets: ["AWSLambdaTesting"]),
],
dependencies: [
.package(url: "https://github.com/apple/swift-nio.git", from: "2.76.0"),
.package(url: "https://github.com/apple/swift-nio.git", from: "2.77.0"),
.package(url: "https://github.com/apple/swift-log.git", from: "1.5.4"),
],
targets: [
Expand Down Expand Up @@ -89,11 +89,11 @@ let package = Package(
.executableTarget(
name: "MockServer",
dependencies: [
.product(name: "Logging", package: "swift-log"),
.product(name: "NIOHTTP1", package: "swift-nio"),
.product(name: "NIOCore", package: "swift-nio"),
.product(name: "NIOPosix", package: "swift-nio"),
],
swiftSettings: [.swiftLanguageMode(.v5)]
]
),
]
)
298 changes: 298 additions & 0 deletions Sources/MockServer/MockHTTPServer.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,298 @@
//===----------------------------------------------------------------------===//
//
// This source file is part of the SwiftAWSLambdaRuntime open source project
//
// Copyright (c) 2017-2025 Apple Inc. and the SwiftAWSLambdaRuntime project authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See CONTRIBUTORS.txt for the list of SwiftAWSLambdaRuntime project authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//

import Logging
import NIOCore
import NIOHTTP1
import NIOPosix
import Synchronization

// for UUID and Date
#if canImport(FoundationEssentials)
import FoundationEssentials
#else
import Foundation
#endif

@main
struct HttpServer {
/// The server's host. (default: 127.0.0.1)
private let host: String
/// The server's port. (default: 7000)
private let port: Int
/// The server's event loop group. (default: MultiThreadedEventLoopGroup.singleton)
private let eventLoopGroup: MultiThreadedEventLoopGroup
/// the mode. Are we mocking a server for a Lambda function that expects a String or a JSON document? (default: string)
private let mode: Mode
/// the number of connections this server must accept before shutting down (default: 1)
private let maxInvocations: Int
/// the logger (control verbosity with LOG_LEVEL environment variable)
private let logger: Logger

static func main() async throws {
var log = Logger(label: "MockServer")
log.logLevel = env("LOG_LEVEL").flatMap(Logger.Level.init) ?? .info

let server = HttpServer(
host: env("HOST") ?? "127.0.0.1",
port: env("PORT").flatMap(Int.init) ?? 7000,
eventLoopGroup: .singleton,
mode: env("MODE").flatMap(Mode.init) ?? .string,
maxInvocations: env("MAX_INVOCATIONS").flatMap(Int.init) ?? 1,
logger: log
)
try await server.run()
}

/// This method starts the server and handles one unique incoming connections
/// The Lambda function will send two HTTP requests over this connection: one for the next invocation and one for the response.
private func run() async throws {
let channel = try await ServerBootstrap(group: self.eventLoopGroup)
.serverChannelOption(.backlog, value: 256)
.serverChannelOption(.socketOption(.so_reuseaddr), value: 1)
.childChannelOption(.maxMessagesPerRead, value: 1)
.bind(
host: self.host,
port: self.port
) { channel in
channel.eventLoop.makeCompletedFuture {

try channel.pipeline.syncOperations.configureHTTPServerPipeline(
withErrorHandling: true
)

return try NIOAsyncChannel(
wrappingChannelSynchronously: channel,
configuration: NIOAsyncChannel.Configuration(
inboundType: HTTPServerRequestPart.self,
outboundType: HTTPServerResponsePart.self
)
)
}
}

logger.info(
"Server started and listening",
metadata: [
"host": "\(channel.channel.localAddress?.ipAddress?.debugDescription ?? "")",
"port": "\(channel.channel.localAddress?.port ?? 0)",
"maxInvocations": "\(self.maxInvocations)",
]
)

// This counter is used to track the number of incoming connections.
// This mock servers accepts n TCP connection then shutdowns
let connectionCounter = SharedCounter(maxValue: self.maxInvocations)

// We are handling each incoming connection in a separate child task. It is important
// to use a discarding task group here which automatically discards finished child tasks.
// A normal task group retains all child tasks and their outputs in memory until they are
// consumed by iterating the group or by exiting the group. Since, we are never consuming
// the results of the group we need the group to automatically discard them; otherwise, this
// would result in a memory leak over time.
try await withThrowingDiscardingTaskGroup { group in
try await channel.executeThenClose { inbound in
for try await connectionChannel in inbound {

let counter = connectionCounter.current()
logger.trace("Handling new connection", metadata: ["connectionNumber": "\(counter)"])

group.addTask {
await self.handleConnection(channel: connectionChannel)
logger.trace("Done handling connection", metadata: ["connectionNumber": "\(counter)"])
}

if connectionCounter.increment() {
logger.info(
"Maximum number of connections reached, shutting down after current connection",
metadata: ["maxConnections": "\(self.maxInvocations)"]
)
break // this causes the server to shutdown after handling the connection
}
}
}
}
logger.info("Server shutting down")
}

/// This method handles a single connection by responsing hard coded value to a Lambda function request.
/// It handles two requests: one for the next invocation and one for the response.
/// when the maximum number of requests is reached, it closes the connection.
private func handleConnection(
channel: NIOAsyncChannel<HTTPServerRequestPart, HTTPServerResponsePart>
) async {

var requestHead: HTTPRequestHead!
var requestBody: ByteBuffer?

// each Lambda invocation results in TWO HTTP requests (next and response)
let requestCount = SharedCounter(maxValue: 2)

// Note that this method is non-throwing and we are catching any error.
// We do this since we don't want to tear down the whole server when a single connection
// encounters an error.
do {
try await channel.executeThenClose { inbound, outbound in
for try await inboundData in inbound {
let requestNumber = requestCount.current()
logger.trace("Handling request", metadata: ["requestNumber": "\(requestNumber)"])

if case .head(let head) = inboundData {
logger.trace("Received request head", metadata: ["head": "\(head)"])
requestHead = head
}
if case .body(let body) = inboundData {
logger.trace("Received request body", metadata: ["body": "\(body)"])
requestBody = body
}
if case .end(let end) = inboundData {
logger.trace("Received request end", metadata: ["end": "\(String(describing: end))"])

precondition(requestHead != nil, "Received .end without .head")
let (responseStatus, responseHeaders, responseBody) = self.processRequest(
requestHead: requestHead,
requestBody: requestBody
)

try await self.sendResponse(
responseStatus: responseStatus,
responseHeaders: responseHeaders,
responseBody: responseBody,
outbound: outbound
)

requestHead = nil

if requestCount.increment() {
logger.info(
"Maximum number of requests reached, closing this connection",
metadata: ["maxRequest": "2"]
)
break // this finishes handiling request on this connection
}
}
}
}
} catch {
logger.error("Hit error: \(error)")
}
}
/// This function process the requests and return an hard-coded response (string or JSON depending on the mode).
/// We ignore the requestBody.
private func processRequest(
requestHead: HTTPRequestHead,
requestBody: ByteBuffer?
) -> (HTTPResponseStatus, [(String, String)], String) {
var responseStatus: HTTPResponseStatus = .ok
var responseBody: String = ""
var responseHeaders: [(String, String)] = []

logger.trace(
"Processing request",
metadata: ["VERB": "\(requestHead.method)", "URI": "\(requestHead.uri)"]
)

if requestHead.uri.hasSuffix("/next") {
responseStatus = .accepted

let requestId = UUID().uuidString
switch self.mode {
case .string:
responseBody = "\"Seb\"" // must be a valid JSON document
case .json:
responseBody = "{ \"name\": \"Seb\", \"age\" : 52 }"
}
let deadline = Int64(Date(timeIntervalSinceNow: 60).timeIntervalSince1970 * 1000)
responseHeaders = [
(AmazonHeaders.requestID, requestId),
(AmazonHeaders.invokedFunctionARN, "arn:aws:lambda:us-east-1:123456789012:function:custom-runtime"),
(AmazonHeaders.traceID, "Root=1-5bef4de7-ad49b0e87f6ef6c87fc2e700;Parent=9a9197af755a6419;Sampled=1"),
(AmazonHeaders.deadline, String(deadline)),
]
} else if requestHead.uri.hasSuffix("/response") {
responseStatus = .accepted
} else if requestHead.uri.hasSuffix("/error") {
responseStatus = .ok
} else {
responseStatus = .notFound
}
logger.trace("Returning response: \(responseStatus), \(responseHeaders), \(responseBody)")
return (responseStatus, responseHeaders, responseBody)
}

private func sendResponse(
responseStatus: HTTPResponseStatus,
responseHeaders: [(String, String)],
responseBody: String,
outbound: NIOAsyncChannelOutboundWriter<HTTPServerResponsePart>
) async throws {
var headers = HTTPHeaders(responseHeaders)
headers.add(name: "Content-Length", value: "\(responseBody.utf8.count)")
headers.add(name: "KeepAlive", value: "timeout=1, max=2")

logger.trace("Writing response head")
try await outbound.write(
HTTPServerResponsePart.head(
HTTPResponseHead(
version: .init(major: 1, minor: 1), // use HTTP 1.1 it keeps connection alive between requests
status: responseStatus,
headers: headers
)
)
)
logger.trace("Writing response body")
try await outbound.write(HTTPServerResponsePart.body(.byteBuffer(ByteBuffer(string: responseBody))))
logger.trace("Writing response end")
try await outbound.write(HTTPServerResponsePart.end(nil))
}

private enum Mode: String {
case string
case json
}

private static func env(_ name: String) -> String? {
guard let value = getenv(name) else {
return nil
}
return String(cString: value)
}

private enum AmazonHeaders {
static let requestID = "Lambda-Runtime-Aws-Request-Id"
static let traceID = "Lambda-Runtime-Trace-Id"
static let clientContext = "X-Amz-Client-Context"
static let cognitoIdentity = "X-Amz-Cognito-Identity"
static let deadline = "Lambda-Runtime-Deadline-Ms"
static let invokedFunctionARN = "Lambda-Runtime-Invoked-Function-Arn"
}

private final class SharedCounter: Sendable {
private let counterMutex = Mutex<Int>(0)
private let maxValue: Int

init(maxValue: Int) {
self.maxValue = maxValue
}
func current() -> Int {
counterMutex.withLock { $0 }
}
func increment() -> Bool {
counterMutex.withLock {
$0 += 1
return $0 >= maxValue
}
}
}
}
Loading

0 comments on commit 02821fe

Please sign in to comment.