Skip to content

Rate limit alerts by channel for task run alerts using generic cell rate algo #1679

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions apps/webapp/app/env.server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,39 @@ const EnvironmentSchema = z.object({
ALERT_SMTP_SECURE: z.coerce.boolean().optional(),
ALERT_SMTP_USER: z.string().optional(),
ALERT_SMTP_PASSWORD: z.string().optional(),
ALERT_RATE_LIMITER_EMISSION_INTERVAL: z.coerce.number().int().default(2_500),
ALERT_RATE_LIMITER_BURST_TOLERANCE: z.coerce.number().int().default(10_000),
ALERT_RATE_LIMITER_REDIS_HOST: z
.string()
.optional()
.transform((v) => v ?? process.env.REDIS_HOST),
ALERT_RATE_LIMITER_REDIS_READER_HOST: z
.string()
.optional()
.transform((v) => v ?? process.env.REDIS_READER_HOST),
ALERT_RATE_LIMITER_REDIS_READER_PORT: z.coerce
.number()
.optional()
.transform(
(v) =>
v ?? (process.env.REDIS_READER_PORT ? parseInt(process.env.REDIS_READER_PORT) : undefined)
),
ALERT_RATE_LIMITER_REDIS_PORT: z.coerce
.number()
.optional()
.transform((v) => v ?? (process.env.REDIS_PORT ? parseInt(process.env.REDIS_PORT) : undefined)),
ALERT_RATE_LIMITER_REDIS_USERNAME: z
.string()
.optional()
.transform((v) => v ?? process.env.REDIS_USERNAME),
ALERT_RATE_LIMITER_REDIS_PASSWORD: z
.string()
.optional()
.transform((v) => v ?? process.env.REDIS_PASSWORD),
ALERT_RATE_LIMITER_REDIS_TLS_DISABLED: z
.string()
.default(process.env.REDIS_TLS_DISABLED ?? "false"),
ALERT_RATE_LIMITER_REDIS_CLUSTER_MODE_ENABLED: z.string().default("0"),

MAX_SEQUENTIAL_INDEX_FAILURE_COUNT: z.coerce.number().default(96),

Expand Down
171 changes: 171 additions & 0 deletions apps/webapp/app/v3/GCRARateLimiter.server.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
import Redis, { Cluster } from "ioredis";

/**
* Options for configuring the RateLimiter.
*/
export interface GCRARateLimiterOptions {
/** An instance of ioredis. */
redis: Redis | Cluster;
/**
* A string prefix to namespace keys in Redis.
* Defaults to "ratelimit:".
*/
keyPrefix?: string;
/**
* The minimum interval between requests (the emission interval) in milliseconds.
* For example, 1000 ms for one request per second.
*/
emissionInterval: number;
/**
* The burst tolerance in milliseconds. This represents how much “credit” can be
* accumulated to allow short bursts beyond the average rate.
* For example, if you want to allow 3 requests in a burst with an emission interval of 1000 ms,
* you might set this to 3000.
*/
burstTolerance: number;
/**
* Expiration for the Redis key in milliseconds.
* Defaults to the larger of 60 seconds or (emissionInterval + burstTolerance).
*/
keyExpiration?: number;
}

/**
* The result of a rate limit check.
*/
export interface RateLimitResult {
/** Whether the request is allowed. */
allowed: boolean;
/**
* If not allowed, this is the number of milliseconds the caller should wait
* before retrying.
*/
retryAfter?: number;
}

/**
* A rate limiter using Redis and the Generic Cell Rate Algorithm (GCRA).
*
* The GCRA is implemented using a Lua script that runs atomically in Redis.
*
* When a request comes in, the algorithm:
* - Retrieves the current "Theoretical Arrival Time" (TAT) from Redis (or initializes it if missing).
* - If the current time is greater than or equal to the TAT, the request is allowed and the TAT is updated to now + emissionInterval.
* - Otherwise, if the current time plus the burst tolerance is at least the TAT, the request is allowed and the TAT is incremented.
* - If neither condition is met, the request is rejected and a Retry-After value is returned.
*/
export class GCRARateLimiter {
private redis: Redis | Cluster;
private keyPrefix: string;
private emissionInterval: number;
private burstTolerance: number;
private keyExpiration: number;

constructor(options: GCRARateLimiterOptions) {
this.redis = options.redis;
this.keyPrefix = options.keyPrefix || "gcra:ratelimit:";
this.emissionInterval = options.emissionInterval;
this.burstTolerance = options.burstTolerance;
// Default expiration: at least 60 seconds or the sum of emissionInterval and burstTolerance
this.keyExpiration =
options.keyExpiration || Math.max(60_000, this.emissionInterval + this.burstTolerance);

// Define a custom Redis command 'gcra' that implements the GCRA algorithm.
// Using defineCommand ensures the Lua script is loaded once and run atomically.
this.redis.defineCommand("gcra", {
numberOfKeys: 1,
lua: `
--[[
GCRA Lua script
KEYS[1] - The rate limit key (e.g. "ratelimit:<identifier>")
ARGV[1] - Current time in ms (number)
ARGV[2] - Emission interval in ms (number)
ARGV[3] - Burst tolerance in ms (number)
ARGV[4] - Key expiration in ms (number)

Returns: { allowedFlag, value }
allowedFlag: 1 if allowed, 0 if rate-limited.
value: 0 when allowed; if not allowed, the number of ms to wait.
]]--

local key = KEYS[1]
local now = tonumber(ARGV[1])
local emission_interval = tonumber(ARGV[2])
local burst_tolerance = tonumber(ARGV[3])
local expire = tonumber(ARGV[4])

-- Get the stored Theoretical Arrival Time (TAT) or default to 0.
local tat = tonumber(redis.call("GET", key) or 0)
if tat == 0 then
tat = now
end

local allowed, new_tat, retry_after

if now >= tat then
-- No delay: request is on schedule.
new_tat = now + emission_interval
allowed = true
elseif (now + burst_tolerance) >= tat then
-- Within burst capacity: allow request.
new_tat = tat + emission_interval
allowed = true
else
-- Request exceeds the allowed burst; calculate wait time.
allowed = false
retry_after = tat - (now + burst_tolerance)
end

if allowed then
redis.call("SET", key, new_tat, "PX", expire)
return {1, 0}
else
return {0, retry_after}
end
`,
});
}

/**
* Checks whether a request associated with the given identifier is allowed.
*
* @param identifier A unique string identifying the subject of rate limiting (e.g. user ID, IP address, or domain).
* @returns A promise that resolves to a RateLimitResult.
*
* @example
* const result = await rateLimiter.check('user:12345');
* if (!result.allowed) {
* // Tell the client to retry after result.retryAfter milliseconds.
* }
*/
async check(identifier: string): Promise<RateLimitResult> {
const key = `${this.keyPrefix}${identifier}`;
const now = Date.now();

try {
// Call the custom 'gcra' command.
// The script returns an array: [allowedFlag, value]
// - allowedFlag: 1 if allowed; 0 if rejected.
// - value: 0 when allowed; if rejected, the number of ms to wait before retrying.
// @ts-expect-error: The custom command is defined via defineCommand.
const result: [number, number] = await this.redis.gcra(
key,
now,
this.emissionInterval,
this.burstTolerance,
this.keyExpiration
);
const allowed = result[0] === 1;
if (allowed) {
return { allowed: true };
} else {
return { allowed: false, retryAfter: result[1] };
}
} catch (error) {
// In a production system you might log the error and either
// allow the request (fail open) or deny it (fail closed).
// Here we choose to propagate the error.
throw error;
}
}
}
30 changes: 30 additions & 0 deletions apps/webapp/app/v3/alertsRateLimiter.server.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import { env } from "~/env.server";
import { createRedisClient } from "~/redis.server";
import { GCRARateLimiter } from "./GCRARateLimiter.server";
import { singleton } from "~/utils/singleton";
import { logger } from "~/services/logger.server";

export const alertsRateLimiter = singleton("alertsRateLimiter", initializeAlertsRateLimiter);

function initializeAlertsRateLimiter() {
const redis = createRedisClient("alerts:ratelimiter", {
keyPrefix: "alerts:ratelimiter:",
host: env.ALERT_RATE_LIMITER_REDIS_HOST,
port: env.ALERT_RATE_LIMITER_REDIS_PORT,
username: env.ALERT_RATE_LIMITER_REDIS_USERNAME,
password: env.ALERT_RATE_LIMITER_REDIS_PASSWORD,
tlsDisabled: env.ALERT_RATE_LIMITER_REDIS_TLS_DISABLED === "true",
clusterMode: env.ALERT_RATE_LIMITER_REDIS_CLUSTER_MODE_ENABLED === "1",
});

logger.debug(`🚦 Initializing alerts rate limiter at host ${env.ALERT_RATE_LIMITER_REDIS_HOST}`, {
emissionInterval: env.ALERT_RATE_LIMITER_EMISSION_INTERVAL,
burstTolerance: env.ALERT_RATE_LIMITER_BURST_TOLERANCE,
});

return new GCRARateLimiter({
redis,
emissionInterval: env.ALERT_RATE_LIMITER_EMISSION_INTERVAL,
burstTolerance: env.ALERT_RATE_LIMITER_BURST_TOLERANCE,
});
}
70 changes: 66 additions & 4 deletions apps/webapp/app/v3/services/alerts/deliverAlert.server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import {
import { TaskRunError, createJsonErrorObject } from "@trigger.dev/core/v3";
import assertNever from "assert-never";
import { subtle } from "crypto";
import { Prisma, PrismaClientOrTransaction, prisma } from "~/db.server";
import { Prisma, prisma, PrismaClientOrTransaction } from "~/db.server";
import { env } from "~/env.server";
import {
OrgIntegrationRepository,
Expand All @@ -25,10 +25,12 @@ import { DeploymentPresenter } from "~/presenters/v3/DeploymentPresenter.server"
import { sendAlertEmail } from "~/services/email.server";
import { logger } from "~/services/logger.server";
import { decryptSecret } from "~/services/secrets/secretStore.server";
import { workerQueue } from "~/services/worker.server";
import { BaseService } from "../baseService.server";
import { FINAL_ATTEMPT_STATUSES } from "~/v3/taskStatus";
import { commonWorker } from "~/v3/commonWorker.server";
import { FINAL_ATTEMPT_STATUSES } from "~/v3/taskStatus";
import { BaseService } from "../baseService.server";
import { generateFriendlyId } from "~/v3/friendlyIdentifiers";
import { ProjectAlertType } from "@trigger.dev/database";
import { alertsRateLimiter } from "~/v3/alertsRateLimiter.server";

type FoundAlert = Prisma.Result<
typeof prisma.projectAlert,
Expand Down Expand Up @@ -1101,6 +1103,66 @@ export class DeliverAlertService extends BaseService {
availableAt: runAt,
});
}

static async createAndSendAlert(
{
channelId,
projectId,
environmentId,
alertType,
deploymentId,
taskRunId,
}: {
channelId: string;
projectId: string;
environmentId: string;
alertType: ProjectAlertType;
deploymentId?: string;
taskRunId?: string;
},
db: PrismaClientOrTransaction
) {
if (taskRunId) {
try {
const result = await alertsRateLimiter.check(channelId);

if (!result.allowed) {
logger.warn("[DeliverAlert] Rate limited", {
taskRunId,
environmentId,
alertType,
channelId,
result,
});

return;
}
} catch (error) {
logger.error("[DeliverAlert] Rate limiter error", {
taskRunId,
environmentId,
alertType,
channelId,
error,
});
}
}

const alert = await db.projectAlert.create({
data: {
friendlyId: generateFriendlyId("alert"),
channelId,
projectId,
environmentId,
status: "PENDING",
type: alertType,
workerDeploymentId: deploymentId,
taskRunId,
},
});

await DeliverAlertService.enqueue(alert.id);
}
}

function isWebAPIPlatformError(error: unknown): error is WebAPIPlatformError {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,19 +46,16 @@ export class PerformDeploymentAlertsService extends BaseService {
deployment: WorkerDeployment,
alertType: ProjectAlertType
) {
const alert = await this._prisma.projectAlert.create({
data: {
friendlyId: generateFriendlyId("alert"),
await DeliverAlertService.createAndSendAlert(
{
channelId: alertChannel.id,
projectId: deployment.projectId,
environmentId: deployment.environmentId,
status: "PENDING",
type: alertType,
workerDeploymentId: deployment.id,
alertType,
deploymentId: deployment.id,
},
});

await DeliverAlertService.enqueue(alert.id);
this._prisma
);
}

static async enqueue(deploymentId: string, runAt?: Date) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,19 +46,16 @@ export class PerformTaskRunAlertsService extends BaseService {
}

async #createAndSendAlert(alertChannel: ProjectAlertChannel, run: FoundRun) {
const alert = await this._prisma.projectAlert.create({
data: {
friendlyId: generateFriendlyId("alert"),
await DeliverAlertService.createAndSendAlert(
{
channelId: alertChannel.id,
projectId: run.projectId,
environmentId: run.runtimeEnvironmentId,
status: "PENDING",
type: "TASK_RUN",
alertType: "TASK_RUN",
taskRunId: run.id,
},
});

await DeliverAlertService.enqueue(alert.id);
this._prisma
);
}

static async enqueue(runId: string, runAt?: Date) {
Expand Down
Loading