Skip to content

feat(control-plane): add support for handling multiple events in a single invocation #4603

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,8 @@ Join our discord community via [this invite link](https://discord.gg/bxgXW8jJGh)
| <a name="input_key_name"></a> [key\_name](#input\_key\_name) | Key pair name | `string` | `null` | no |
| <a name="input_kms_key_arn"></a> [kms\_key\_arn](#input\_kms\_key\_arn) | Optional CMK Key ARN to be used for Parameter Store. This key must be in the current account. | `string` | `null` | no |
| <a name="input_lambda_architecture"></a> [lambda\_architecture](#input\_lambda\_architecture) | AWS Lambda architecture. Lambda functions using Graviton processors ('arm64') tend to have better price/performance than 'x86\_64' functions. | `string` | `"arm64"` | no |
| <a name="input_lambda_event_source_mapping_batch_size"></a> [lambda\_event\_source\_mapping\_batch\_size](#input\_lambda\_event\_source\_mapping\_batch\_size) | Maximum number of records to pass to the lambda function in a single batch for the event source mapping. When not set, the AWS default of 10 events will be used. | `number` | `10` | no |
| <a name="input_lambda_event_source_mapping_maximum_batching_window_in_seconds"></a> [lambda\_event\_source\_mapping\_maximum\_batching\_window\_in\_seconds](#input\_lambda\_event\_source\_mapping\_maximum\_batching\_window\_in\_seconds) | Maximum amount of time to gather records before invoking the lambda function, in seconds. AWS requires this to be greater than 0 if batch\_size is greater than 10. Defaults to 0. | `number` | `0` | no |
| <a name="input_lambda_principals"></a> [lambda\_principals](#input\_lambda\_principals) | (Optional) add extra principals to the role created for execution of the lambda, e.g. for local testing. | <pre>list(object({<br/> type = string<br/> identifiers = list(string)<br/> }))</pre> | `[]` | no |
| <a name="input_lambda_runtime"></a> [lambda\_runtime](#input\_lambda\_runtime) | AWS Lambda runtime. | `string` | `"nodejs22.x"` | no |
| <a name="input_lambda_s3_bucket"></a> [lambda\_s3\_bucket](#input\_lambda\_s3\_bucket) | S3 bucket from which to specify lambda functions. This is an alternative to providing local files directly. | `string` | `null` | no |
Expand Down
171 changes: 147 additions & 24 deletions lambdas/functions/control-plane/src/lambda.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -70,19 +70,33 @@ vi.mock('@aws-github-runner/aws-powertools-util');
vi.mock('@aws-github-runner/aws-ssm-util');

describe('Test scale up lambda wrapper.', () => {
it('Do not handle multiple record sets.', async () => {
await testInvalidRecords([sqsRecord, sqsRecord]);
it('Do not handle empty record sets.', async () => {
const sqsEventMultipleRecords: SQSEvent = {
Records: [],
};

await expect(scaleUpHandler(sqsEventMultipleRecords, context)).resolves.not.toThrow();
});

it('Do not handle empty record sets.', async () => {
await testInvalidRecords([]);
it('Ignores non-sqs event sources.', async () => {
const record = {
...sqsRecord,
eventSource: 'aws:non-sqs',
};

const sqsEventMultipleRecordsNonSQS: SQSEvent = {
Records: [record],
};

await expect(scaleUpHandler(sqsEventMultipleRecordsNonSQS, context)).resolves.not.toThrow();
expect(scaleUp).toHaveBeenCalledWith([]);
});

it('Scale without error should resolve.', async () => {
const mock = vi.fn(scaleUp);
mock.mockImplementation(() => {
return new Promise((resolve) => {
resolve();
resolve([]);
});
});
await expect(scaleUpHandler(sqsEvent, context)).resolves.not.toThrow();
Expand All @@ -104,28 +118,137 @@ describe('Test scale up lambda wrapper.', () => {
vi.mocked(scaleUp).mockImplementation(mock);
await expect(scaleUpHandler(sqsEvent, context)).rejects.toThrow(error);
});
});

async function testInvalidRecords(sqsRecords: SQSRecord[]) {
const mock = vi.fn(scaleUp);
const logWarnSpy = vi.spyOn(logger, 'warn');
mock.mockImplementation(() => {
return new Promise((resolve) => {
resolve();
describe('Batch processing', () => {
  beforeEach(() => {
    vi.clearAllMocks();
  });

  // Builds `count` SQS records from the shared fixture, each with a unique
  // messageId (`message-<i>`) and a distinct payload id, so individual
  // records can be asserted on by identity.
  const createMultipleRecords = (count: number, eventSource = 'aws:sqs'): SQSRecord[] => {
    return Array.from({ length: count }, (_, i) => ({
      ...sqsRecord,
      eventSource,
      messageId: `message-${i}`,
      body: JSON.stringify({
        ...body,
        id: i + 1,
      }),
    }));
  };

  it('Should handle multiple SQS records in a single invocation', async () => {
    const records = createMultipleRecords(3);
    const multiRecordEvent: SQSEvent = { Records: records };

    const mock = vi.fn(scaleUp);
    mock.mockImplementation(() => Promise.resolve([]));
    vi.mocked(scaleUp).mockImplementation(mock);

    await expect(scaleUpHandler(multiRecordEvent, context)).resolves.not.toThrow();
    // All three records should be forwarded to scaleUp in one call.
    expect(scaleUp).toHaveBeenCalledWith(
      expect.arrayContaining([
        expect.objectContaining({ messageId: 'message-0' }),
        expect.objectContaining({ messageId: 'message-1' }),
        expect.objectContaining({ messageId: 'message-2' }),
      ]),
    );
  });

  it('Should return batch item failures for rejected messages', async () => {
    const records = createMultipleRecords(3);
    const multiRecordEvent: SQSEvent = { Records: records };

    // scaleUp reports the message ids it could not process; the handler must
    // surface them as partial batch failures so SQS redelivers only those.
    const mock = vi.fn(scaleUp);
    mock.mockImplementation(() => Promise.resolve(['message-1', 'message-2']));
    vi.mocked(scaleUp).mockImplementation(mock);

    const result = await scaleUpHandler(multiRecordEvent, context);
    expect(result).toEqual({
      batchItemFailures: [{ itemIdentifier: 'message-1' }, { itemIdentifier: 'message-2' }],
    });
  });

  it('Should filter out non-SQS event sources', async () => {
    const sqsRecords = createMultipleRecords(2, 'aws:sqs');
    const nonSqsRecords = createMultipleRecords(1, 'aws:sns');
    const mixedEvent: SQSEvent = {
      Records: [...sqsRecords, ...nonSqsRecords],
    };

    const mock = vi.fn(scaleUp);
    mock.mockImplementation(() => Promise.resolve([]));
    vi.mocked(scaleUp).mockImplementation(mock);

    await scaleUpHandler(mixedEvent, context);
    // Only the aws:sqs records reach scaleUp; the aws:sns record is dropped.
    expect(scaleUp).toHaveBeenCalledWith(
      expect.arrayContaining([
        expect.objectContaining({ messageId: 'message-0' }),
        expect.objectContaining({ messageId: 'message-1' }),
      ]),
    );
    expect(scaleUp).not.toHaveBeenCalledWith(
      expect.arrayContaining([expect.objectContaining({ messageId: 'message-2' })]),
    );
  });

  it('Should sort messages by retry count', async () => {
    const records = [
      {
        ...sqsRecord,
        messageId: 'high-retry',
        body: JSON.stringify({ ...body, retryCounter: 5 }),
      },
      {
        ...sqsRecord,
        messageId: 'low-retry',
        body: JSON.stringify({ ...body, retryCounter: 1 }),
      },
      {
        ...sqsRecord,
        messageId: 'no-retry',
        body: JSON.stringify({ ...body }),
      },
    ];
    const multiRecordEvent: SQSEvent = { Records: records };

    const mock = vi.fn(scaleUp);
    mock.mockImplementation((messages) => {
      // Verify messages are sorted by retry count (ascending); a missing
      // retryCounter sorts first.
      expect(messages[0].messageId).toBe('no-retry');
      expect(messages[1].messageId).toBe('low-retry');
      expect(messages[2].messageId).toBe('high-retry');
      return Promise.resolve([]);
    });
    vi.mocked(scaleUp).mockImplementation(mock);

    await scaleUpHandler(multiRecordEvent, context);
  });

  it('Should return no batch item failures when scaleUp throws non-ScaleError', async () => {
    const records = createMultipleRecords(2);
    const multiRecordEvent: SQSEvent = { Records: records };

    // A generic error is swallowed by the handler: no failures had been
    // recorded before the throw, so the batch response is empty.
    const mock = vi.fn(scaleUp);
    mock.mockImplementation(() => Promise.reject(new Error('Generic error')));
    vi.mocked(scaleUp).mockImplementation(mock);

    const result = await scaleUpHandler(multiRecordEvent, context);
    expect(result).toEqual({ batchItemFailures: [] });
  });

  it('Should throw when scaleUp throws ScaleError', async () => {
    const records = createMultipleRecords(2);
    const multiRecordEvent: SQSEvent = { Records: records };

    // ScaleError is fatal: the handler must rethrow so the whole invocation
    // fails and every message is retried.
    const error = new ScaleError('Critical scaling error');
    const mock = vi.fn(scaleUp);
    mock.mockImplementation(() => Promise.reject(error));
    vi.mocked(scaleUp).mockImplementation(mock);

    await expect(scaleUpHandler(multiRecordEvent, context)).rejects.toThrow(error);
  });
});
const sqsEventMultipleRecords: SQSEvent = {
Records: sqsRecords,
};

await expect(scaleUpHandler(sqsEventMultipleRecords, context)).resolves.not.toThrow();

expect(logWarnSpy).toHaveBeenCalledWith(
expect.stringContaining(
'Event ignored, only one record at the time can be handled, ensure the lambda batch size is set to 1.',
),
);
}
});

describe('Test scale down lambda wrapper.', () => {
it('Scaling down no error.', async () => {
Expand Down
62 changes: 50 additions & 12 deletions lambdas/functions/control-plane/src/lambda.ts
Original file line number Diff line number Diff line change
@@ -1,34 +1,72 @@
import middy from '@middy/core';
import { logger, setContext } from '@aws-github-runner/aws-powertools-util';
import { captureLambdaHandler, tracer } from '@aws-github-runner/aws-powertools-util';
import { Context, SQSEvent } from 'aws-lambda';
import { Context, type SQSBatchItemFailure, type SQSBatchResponse, SQSEvent } from 'aws-lambda';

import { PoolEvent, adjust } from './pool/pool';
import ScaleError from './scale-runners/ScaleError';
import { scaleDown } from './scale-runners/scale-down';
import { scaleUp } from './scale-runners/scale-up';
import { type ActionRequestMessage, type ActionRequestMessageSQS, scaleUp } from './scale-runners/scale-up';
import { SSMCleanupOptions, cleanSSMTokens } from './scale-runners/ssm-housekeeper';
import { checkAndRetryJob } from './scale-runners/job-retry';

export async function scaleUpHandler(event: SQSEvent, context: Context): Promise<void> {
export async function scaleUpHandler(event: SQSEvent, context: Context): Promise<SQSBatchResponse> {
setContext(context, 'lambda.ts');
logger.logEventIfEnabled(event);

if (event.Records.length !== 1) {
logger.warn('Event ignored, only one record at the time can be handled, ensure the lambda batch size is set to 1.');
return Promise.resolve();
// Group the messages by their event source. We're only interested in
// `aws:sqs`-originated messages.
const groupedEvents = new Map<string, ActionRequestMessageSQS[]>();
for (const { body, eventSource, messageId } of event.Records) {
const group = groupedEvents.get(eventSource) || [];
const payload = JSON.parse(body) as ActionRequestMessage;

if (group.length === 0) {
groupedEvents.set(eventSource, group);
}

groupedEvents.get(eventSource)?.push({
...payload,
messageId,
});
}

for (const [eventSource, messages] of groupedEvents.entries()) {
if (eventSource === 'aws:sqs') {
continue;
}

logger.warn('Ignoring non-sqs event source', { eventSource, messages });
}

const sqsMessages = groupedEvents.get('aws:sqs') ?? [];

// Sort messages by their retry count, so that we retry the same messages if
// there's a persistent failure. This should cause messages to be dropped
// quicker than if we retried in an arbitrary order.
sqsMessages.sort((l, r) => {
return (l.retryCounter ?? 0) - (r.retryCounter ?? 0);
});

const batchItemFailures: SQSBatchItemFailure[] = [];

try {
await scaleUp(event.Records[0].eventSource, JSON.parse(event.Records[0].body));
return Promise.resolve();
const rejectedMessageIds = await scaleUp(sqsMessages);

for (const messageId of rejectedMessageIds) {
batchItemFailures.push({
itemIdentifier: messageId,
});
}

return { batchItemFailures };
} catch (e) {
if (e instanceof ScaleError) {
return Promise.reject(e);
} else {
logger.warn(`Ignoring error: ${e}`);
return Promise.resolve();
throw e;
}

logger.warn(`Will retry error: ${e}`);
return { batchItemFailures };
}
}

Expand Down
42 changes: 32 additions & 10 deletions lambdas/functions/control-plane/src/local.ts
Original file line number Diff line number Diff line change
@@ -1,21 +1,21 @@
import { logger } from '@aws-github-runner/aws-powertools-util';

import { ActionRequestMessage, scaleUp } from './scale-runners/scale-up';
import { scaleUpHandler } from './lambda';
import { Context, SQSEvent } from 'aws-lambda';

const sqsEvent = {
const sqsEvent: SQSEvent = {
Records: [
{
messageId: 'e8d74d08-644e-42ca-bf82-a67daa6c4dad',
receiptHandle:
// eslint-disable-next-line max-len
'AQEBCpLYzDEKq4aKSJyFQCkJduSKZef8SJVOperbYyNhXqqnpFG5k74WygVAJ4O0+9nybRyeOFThvITOaS21/jeHiI5fgaM9YKuI0oGYeWCIzPQsluW5CMDmtvqv1aA8sXQ5n2x0L9MJkzgdIHTC3YWBFLQ2AxSveOyIHwW+cHLIFCAcZlOaaf0YtaLfGHGkAC4IfycmaijV8NSlzYgDuxrC9sIsWJ0bSvk5iT4ru/R4+0cjm7qZtGlc04k9xk5Fu6A+wRxMaIyiFRY+Ya19ykcevQldidmEjEWvN6CRToLgclk=',
body: {
body: JSON.stringify({
repositoryName: 'self-hosted',
repositoryOwner: 'test-runners',
eventType: 'workflow_job',
id: 987654,
installationId: 123456789,
},
}),
attributes: {
ApproximateReceiveCount: '1',
SentTimestamp: '1626450047230',
Expand All @@ -34,12 +34,34 @@ const sqsEvent = {
],
};

// Minimal no-op Lambda Context fixture for invoking the handler locally:
// every field is a stub and getRemainingTimeInMillis always reports 0.
const context: Context = {
  awsRequestId: '1',
  callbackWaitsForEmptyEventLoop: false,
  functionName: '',
  functionVersion: '',
  getRemainingTimeInMillis: () => 0,
  invokedFunctionArn: '',
  logGroupName: '',
  logStreamName: '',
  memoryLimitInMB: '',
  done: () => {
    return;
  },
  fail: () => {
    return;
  },
  succeed: () => {
    return;
  },
};

/**
 * Local entry point: invokes the scale-up handler with the canned SQS event.
 *
 * Note: scaleUpHandler is async, so a synchronous try/catch around an
 * un-awaited call would never observe a rejection — failures must be handled
 * on the returned promise instead.
 */
export function run(): void {
  scaleUpHandler(sqsEvent, context).catch((e: unknown) => {
    const message = e instanceof Error ? e.message : `${e}`;
    logger.error(message, e instanceof Error ? { error: e } : {});
  });
}

run();
Loading