diff --git a/apps/webapp/app/presenters/v3/QueueListPresenter.server.ts b/apps/webapp/app/presenters/v3/QueueListPresenter.server.ts index 0fe9e3f3652..a24b0904637 100644 --- a/apps/webapp/app/presenters/v3/QueueListPresenter.server.ts +++ b/apps/webapp/app/presenters/v3/QueueListPresenter.server.ts @@ -117,6 +117,7 @@ export class QueueListPresenter extends BasePresenter { concurrencyLimitBase: true, concurrencyLimitOverriddenAt: true, concurrencyLimitOverriddenBy: true, + rateLimit: true, type: true, paused: true, }, @@ -163,6 +164,7 @@ export class QueueListPresenter extends BasePresenter { concurrencyLimitOverriddenBy: queue.concurrencyLimitOverriddenBy ? overriddenByMap.get(queue.concurrencyLimitOverriddenBy) ?? null : null, + rateLimit: queue.rateLimit, paused: queue.paused, }) ); diff --git a/apps/webapp/app/presenters/v3/QueueRetrievePresenter.server.ts b/apps/webapp/app/presenters/v3/QueueRetrievePresenter.server.ts index bd885ea738b..9f6264f3798 100644 --- a/apps/webapp/app/presenters/v3/QueueRetrievePresenter.server.ts +++ b/apps/webapp/app/presenters/v3/QueueRetrievePresenter.server.ts @@ -112,6 +112,7 @@ export class QueueRetrievePresenter extends BasePresenter { concurrencyLimitBase: queue.concurrencyLimitBase ?? null, concurrencyLimitOverriddenAt: queue.concurrencyLimitOverriddenAt ?? null, concurrencyLimitOverriddenBy: queue.concurrencyLimitOverriddenBy ?? null, + rateLimit: queue.rateLimit, paused: queue.paused, }), }; @@ -144,6 +145,7 @@ export function toQueueItem(data: { concurrencyLimitBase: number | null; concurrencyLimitOverriddenAt: Date | null; concurrencyLimitOverriddenBy: User | null; + rateLimit: any; paused: boolean; }): QueueItem & { releaseConcurrencyOnWaitpoint: boolean } { return { @@ -162,6 +164,7 @@ export function toQueueItem(data: { overriddenBy: toQueueConcurrencyOverriddenBy(data.concurrencyLimitOverriddenBy), overriddenAt: data.concurrencyLimitOverriddenAt, }, + rateLimit: data.rateLimit as any, // TODO: This needs to be removed but keeping this here for now to avoid breaking existing clients releaseConcurrencyOnWaitpoint: true, }; diff --git a/apps/webapp/app/presenters/v3/TaskPresenter.server.ts b/apps/webapp/app/presenters/v3/TaskPresenter.server.ts index 671e92a445d..12f69c27cd5 100644 --- a/apps/webapp/app/presenters/v3/TaskPresenter.server.ts +++ b/apps/webapp/app/presenters/v3/TaskPresenter.server.ts @@ -55,6 +55,7 @@ export class TaskPresenter { }, }, }, + queueConfig: true, }, where: { friendlyId: taskFriendlyId, diff --git a/apps/webapp/app/routes/_app.orgs.$organizationSlug.projects.$projectParam.env.$envParam.queues/route.tsx b/apps/webapp/app/routes/_app.orgs.$organizationSlug.projects.$projectParam.env.$envParam.queues/route.tsx index debab4683ce..4cbc260a232 100644 --- a/apps/webapp/app/routes/_app.orgs.$organizationSlug.projects.$projectParam.env.$envParam.queues/route.tsx +++ b/apps/webapp/app/routes/_app.orgs.$organizationSlug.projects.$projectParam.env.$envParam.queues/route.tsx @@ -5,7 +5,9 @@ import { ChatBubbleLeftEllipsisIcon, PauseIcon, PlayIcon, + PlusIcon, RectangleStackIcon, + XMarkIcon, } from "@heroicons/react/20/solid"; import { DialogClose } from "@radix-ui/react-dialog"; import { Form, useNavigation, useSearchParams, type MetaFunction } from "@remix-run/react"; @@ -28,9 +30,12 @@ import { Badge } from "~/components/primitives/Badge"; import { Button, LinkButton, type ButtonVariant } from "~/components/primitives/Buttons"; import { Callout } from "~/components/primitives/Callout"; import { Dialog, DialogContent, DialogHeader, DialogTrigger } from "~/components/primitives/Dialog"; +import { Fieldset } from "~/components/primitives/Fieldset"; import { FormButtons } from "~/components/primitives/FormButtons"; import { Header3 } from "~/components/primitives/Headers"; import { Input } from "~/components/primitives/Input"; +import { InputGroup } from "~/components/primitives/InputGroup"; +import { Label } from "~/components/primitives/Label"; import { SearchInput } from "~/components/primitives/SearchInput"; import { NavBar, PageAccessories, PageTitle } from "~/components/primitives/PageHeader"; import { PaginationControls } from "~/components/primitives/Pagination"; @@ -63,6 +68,7 @@ import { redirectWithErrorMessage, redirectWithSuccessMessage } from "~/models/m import { findProjectBySlug } from "~/models/project.server"; import { findEnvironmentBySlug } from "~/models/runtimeEnvironment.server"; import { getUserById } from "~/models/user.server"; +import { prisma } from "~/db.server"; import { EnvironmentQueuePresenter } from "~/presenters/v3/EnvironmentQueuePresenter.server"; import { QueueListPresenter } from "~/presenters/v3/QueueListPresenter.server"; import { requireUserId } from "~/services/session.server"; @@ -75,6 +81,7 @@ import { v3RunsPath, } from "~/utils/pathBuilder"; import { concurrencySystem } from "~/v3/services/concurrencySystemInstance.server"; +import { rateLimitSystem } from "~/v3/services/rateLimitSystemInstance.server"; import { PauseEnvironmentService } from "~/v3/services/pauseEnvironment.server"; import { PauseQueueService } from "~/v3/services/pauseQueue.server"; import { useCurrentPlan } from "../_app.orgs.$organizationSlug/route"; @@ -221,6 +228,7 @@ export const action = async ({ request, params }: ActionFunctionArgs) => { case "queue-override": { const friendlyId = formData.get("friendlyId"); const concurrencyLimit = formData.get("concurrencyLimit"); + const rateLimitsJson = formData.get("rateLimits"); if (!friendlyId) { return redirectWithErrorMessage(redirectPath, request, "Queue ID is required"); @@ -259,10 +267,31 @@ export const action = async ({ request, params }: ActionFunctionArgs) => { ); } + if (rateLimitsJson) { + try { + const rateLimits = JSON.parse(rateLimitsJson.toString()) as Array<{ limit: number; window: number }>; + const queue = await prisma.taskQueue.findFirst({ + where: { friendlyId: friendlyId.toString(), runtimeEnvironmentId: environment.id }, + }); + if (queue) { + await rateLimitSystem.overrideQueueRateLimit(environment, queue.name, rateLimits); + } + } catch (e) { + return redirectWithErrorMessage(redirectPath, request, "Invalid rate limits format"); + } + } else { + const queue = await prisma.taskQueue.findFirst({ + where: { friendlyId: friendlyId.toString(), runtimeEnvironmentId: environment.id }, + }); + if (queue) { + await rateLimitSystem.resetQueueRateLimit(environment, queue.name); + } + } + return redirectWithSuccessMessage( redirectPath, request, - "Queue concurrency limit overridden" + "Queue limits overridden" ); } case "queue-remove-override": { @@ -285,7 +314,14 @@ export const action = async ({ request, params }: ActionFunctionArgs) => { ); } - return redirectWithSuccessMessage(redirectPath, request, "Queue concurrency limit reset"); + const queue = await prisma.taskQueue.findFirst({ + where: { friendlyId: friendlyId.toString(), runtimeEnvironmentId: environment.id }, + }); + if (queue) { + await rateLimitSystem.resetQueueRateLimit(environment, queue.name); + } + + return redirectWithSuccessMessage(redirectPath, request, "Queue limits reset"); } default: return redirectWithErrorMessage(redirectPath, request, "Something went wrong"); @@ -451,6 +487,7 @@ export default function Page() { Queued Running Limit + Rate Limit {limit} + + {queue.rateLimit && queue.rateLimit.length > 0 ? ( +
+ {queue.rateLimit.map((rl: any, i: number) => ( + + {rl.limit} / {rl.window}s + + ))} +
+ ) : ( + - + )} +
- @@ -880,7 +936,7 @@ function QueuePauseResumeButton({ ); } -function QueueOverrideConcurrencyButton({ +function QueueOverrideLimitsButton({ queue, environmentConcurrencyLimit, }: { @@ -892,8 +948,11 @@ function QueueOverrideConcurrencyButton({ const [concurrencyLimit, setConcurrencyLimit] = useState( queue.concurrencyLimit?.toString() ?? environmentConcurrencyLimit.toString() ); + const [rateLimits, setRateLimits] = useState>( + queue.rateLimit && queue.rateLimit.length > 0 ? queue.rateLimit : [{ limit: 0, window: 0 }] + ); - const isOverridden = !!queue.concurrency?.overriddenAt; + const isOverridden = !!queue.concurrency?.overriddenAt || (queue.rateLimit && queue.rateLimit.length > 0); const currentLimit = queue.concurrencyLimit ?? environmentConcurrencyLimit; useEffect(() => { @@ -915,34 +974,35 @@ function QueueOverrideConcurrencyButton({ title={isOverridden ? "Edit override…" : "Override limit…"} /> - - - {isOverridden ? "Edit concurrency override" : "Override concurrency limit"} + + + {isOverridden ? "Edit limits override" : "Override limits"} -
- {isOverridden ? ( - - This queue's concurrency limit is currently overridden to {currentLimit}. - {typeof queue.concurrency?.base === "number" && - ` The original limit set in code was ${queue.concurrency.base}.`}{" "} - You can update the override or remove it to restore the{" "} - {typeof queue.concurrency?.base === "number" - ? "limit set in code" - : "environment concurrency limit"} - . - - ) : ( - - Override this queue's concurrency limit. The current limit is {currentLimit}, which is - set {queue.concurrencyLimit !== null ? "in code" : "by the environment"}. - - )} -
setIsOpen(false)} className="space-y-3"> - -
- + setIsOpen(false)}> + + rl.limit > 0 && rl.window > 0))} /> + +
+ {isOverridden ? ( + + This queue's limits are currently overridden. + {typeof queue.concurrency?.base === "number" && + ` The original concurrency limit set in code was ${queue.concurrency.base}.`}{" "} + You can update the override or remove it to restore the{" "} + {typeof queue.concurrency?.base === "number" + ? "limit set in code" + : "environment concurrency limit"} + . + + ) : ( + + Override this queue's limits. The current concurrency limit is {currentLimit}, which is + set {queue.concurrencyLimit !== null ? "in code" : "by the environment"}. + + )} + + + -
+ - +
+ + +
+ {rateLimits.map((rl, index) => ( +
+ { + const newLimits = [...rateLimits]; + newLimits[index].limit = parseInt(e.target.value, 10) || 0; + setRateLimits(newLimits); + }} + placeholder="e.g. 10" + /> +
+
+ { + const newLimits = [...rateLimits]; + newLimits[index].window = parseInt(e.target.value, 10) || 0; + setRateLimits(newLimits); + }} + placeholder="e.g. 60" + /> +
+ {rateLimits.length > 1 && ( +
+
+ ))} +
+ + Tip: You can also set dynamic rate limits in your code. + - } - cancelButton={ -
- {isOverridden && ( - - )} - - - -
- } - /> - -
+
+ + + + } + shortcut={{ modifiers: ["mod"], key: "enter" }} + > + {isOverridden ? "Update override" : "Override limit"} + + } + cancelButton={ +
+ {isOverridden && ( + + )} + + + +
+ } + /> +
); diff --git a/apps/webapp/app/v3/marqs/index.server.ts b/apps/webapp/app/v3/marqs/index.server.ts index 5348f228ae1..6111fc03ac1 100644 --- a/apps/webapp/app/v3/marqs/index.server.ts +++ b/apps/webapp/app/v3/marqs/index.server.ts @@ -186,6 +186,20 @@ export class MarQS { return this.redis.del(this.keys.queueConcurrencyLimitKey(env, queue)); } + public async updateQueueRateLimits( + env: AuthenticatedEnvironment, + queue: string, + rateLimits: Array<{ limit: number; window: number }> + ) { + // For now, we just store it in redis as JSON. The engine will need to read it. + // We need a key for rate limits. Let's assume `queueRateLimitKey` exists or we create it. + return this.redis.set(this.keys.queueRateLimitKey(env, queue), JSON.stringify(rateLimits)); + } + + public async removeQueueRateLimits(env: AuthenticatedEnvironment, queue: string) { + return this.redis.del(this.keys.queueRateLimitKey(env, queue)); + } + public async updateEnvConcurrencyLimits(env: AuthenticatedEnvironment) { const envConcurrencyLimitKey = this.keys.envConcurrencyLimitKey(env); diff --git a/apps/webapp/app/v3/marqs/marqsKeyProducer.ts b/apps/webapp/app/v3/marqs/marqsKeyProducer.ts index 5c9c7238adc..2372e82e476 100644 --- a/apps/webapp/app/v3/marqs/marqsKeyProducer.ts +++ b/apps/webapp/app/v3/marqs/marqsKeyProducer.ts @@ -42,6 +42,10 @@ export class MarQSShortKeyProducer implements MarQSKeyProducer { return [this.queueKey(env, queue), constants.CONCURRENCY_LIMIT_PART].join(":"); } + queueRateLimitKey(env: MarQSKeyProducerEnv, queue: string) { + return [this.queueKey(env, queue), "rateLimit"].join(":"); + } + envConcurrencyLimitKey(envId: string): string; envConcurrencyLimitKey(env: MarQSKeyProducerEnv): string; envConcurrencyLimitKey(envOrId: MarQSKeyProducerEnv | string): string { diff --git a/apps/webapp/app/v3/marqs/types.ts b/apps/webapp/app/v3/marqs/types.ts index 69e75ac44a5..7c5329a8a7b 100644 --- a/apps/webapp/app/v3/marqs/types.ts +++ b/apps/webapp/app/v3/marqs/types.ts @@ -18,6 +18,7 @@ export type MarQSKeyProducerEnv = { export interface MarQSKeyProducer { queueConcurrencyLimitKey(env: MarQSKeyProducerEnv, queue: string): string; + queueRateLimitKey(env: MarQSKeyProducerEnv, queue: string): string; envConcurrencyLimitKey(envId: string): string; envConcurrencyLimitKey(env: MarQSKeyProducerEnv): string; diff --git a/apps/webapp/app/v3/runQueue.server.ts b/apps/webapp/app/v3/runQueue.server.ts index e7aa13c5c54..0aff64e88d4 100644 --- a/apps/webapp/app/v3/runQueue.server.ts +++ b/apps/webapp/app/v3/runQueue.server.ts @@ -42,3 +42,26 @@ export async function removeQueueConcurrencyLimits( engine.runQueue.removeQueueConcurrencyLimits(environment, queueName), ]); } + +/** Updates MARQS and the RunQueue rate limits for a queue */ +export async function updateQueueRateLimits( + environment: AuthenticatedEnvironment, + queueName: string, + rateLimits: Array<{ limit: number; window: number }> +) { + await Promise.allSettled([ + marqs?.updateQueueRateLimits?.(environment, queueName, rateLimits), + engine.runQueue.updateQueueRateLimits?.(environment, queueName, rateLimits), + ]); +} + +/** Removes MARQS and the RunQueue rate limits for a queue */ +export async function removeQueueRateLimits( + environment: AuthenticatedEnvironment, + queueName: string +) { + await Promise.allSettled([ + marqs?.removeQueueRateLimits?.(environment, queueName), + engine.runQueue.removeQueueRateLimits?.(environment, queueName), + ]); +} diff --git a/apps/webapp/app/v3/services/rateLimitSystem.server.ts b/apps/webapp/app/v3/services/rateLimitSystem.server.ts new file mode 100644 index 00000000000..9b949768347 --- /dev/null +++ b/apps/webapp/app/v3/services/rateLimitSystem.server.ts @@ -0,0 +1,41 @@ +import { PrismaClient, Prisma } from "@trigger.dev/database"; +import { AuthenticatedEnvironment } from "~/services/apiAuth.server"; +import { removeQueueRateLimits, updateQueueRateLimits } from "../runQueue.server"; + +export class RateLimitSystem { + constructor( + private prisma: PrismaClient + ) {} + + async overrideQueueRateLimit( + environment: AuthenticatedEnvironment, + queueName: string, + rateLimits: Array<{ limit: number; window: number }> + ) { + const queue = await this.prisma.taskQueue.updateMany({ + where: { + runtimeEnvironmentId: environment.id, + name: queueName, + }, + data: { + rateLimit: rateLimits, + }, + }); + + await updateQueueRateLimits(environment, queueName, rateLimits); + } + + async resetQueueRateLimit(environment: AuthenticatedEnvironment, queueName: string) { + await this.prisma.taskQueue.updateMany({ + where: { + runtimeEnvironmentId: environment.id, + name: queueName, + }, + data: { + rateLimit: Prisma.DbNull, + }, + }); + + await removeQueueRateLimits(environment, queueName); + } +} diff --git a/apps/webapp/app/v3/services/rateLimitSystemInstance.server.ts b/apps/webapp/app/v3/services/rateLimitSystemInstance.server.ts new file mode 100644 index 00000000000..74708e39a1b --- /dev/null +++ b/apps/webapp/app/v3/services/rateLimitSystemInstance.server.ts @@ -0,0 +1,8 @@ +import { prisma } from "~/db.server"; +import { RateLimitSystem } from "./rateLimitSystem.server"; +import { singleton } from "~/utils/singleton"; + +export const rateLimitSystem = singleton( + "rateLimitSystem", + () => new RateLimitSystem(prisma) +); diff --git a/apps/webapp/test/rateLimitSystem.server.test.ts b/apps/webapp/test/rateLimitSystem.server.test.ts new file mode 100644 index 00000000000..292fefa0d10 --- /dev/null +++ b/apps/webapp/test/rateLimitSystem.server.test.ts @@ -0,0 +1,82 @@ +import { describe, it, expect, vi, beforeEach } from "vitest"; +import { RateLimitSystem } from "../app/v3/services/rateLimitSystem.server"; +import { PrismaClient, Prisma } from "@trigger.dev/database"; +import { Redis } from "ioredis"; +import { AuthenticatedEnvironment } from "../app/services/apiAuth.server"; +import * as runQueueServer from "../app/v3/runQueue.server"; + +vi.mock("../app/v3/runQueue.server", () => ({ + updateQueueRateLimits: vi.fn(), + removeQueueRateLimits: vi.fn(), +})); + +describe("RateLimitSystem", () => { + let prismaMock: any; + let redisMock: any; + let rateLimitSystem: RateLimitSystem; + let mockEnvironment: AuthenticatedEnvironment; + + beforeEach(() => { + prismaMock = { + taskQueue: { + updateMany: vi.fn().mockResolvedValue({ count: 1 }), + }, + }; + + rateLimitSystem = new RateLimitSystem(prismaMock as unknown as PrismaClient); + + mockEnvironment = { + id: "env-123", + } as AuthenticatedEnvironment; + + vi.clearAllMocks(); + }); + + describe("overrideQueueRateLimit", () => { + it("should update the rateLimit field in the database and call the Redis sync method", async () => { + const queueName = "test-queue"; + const rateLimits = [{ limit: 10, window: 60 }]; + + await rateLimitSystem.overrideQueueRateLimit(mockEnvironment, queueName, rateLimits); + + expect(prismaMock.taskQueue.updateMany).toHaveBeenCalledWith({ + where: { + runtimeEnvironmentId: mockEnvironment.id, + name: queueName, + }, + data: { + rateLimit: rateLimits, + }, + }); + + expect(runQueueServer.updateQueueRateLimits).toHaveBeenCalledWith( + mockEnvironment, + queueName, + rateLimits + ); + }); + }); + + describe("resetQueueRateLimit", () => { + it("should clear the rateLimit field in the database and call the Redis sync method", async () => { + const queueName = "test-queue"; + + await rateLimitSystem.resetQueueRateLimit(mockEnvironment, queueName); + + expect(prismaMock.taskQueue.updateMany).toHaveBeenCalledWith({ + where: { + runtimeEnvironmentId: mockEnvironment.id, + name: queueName, + }, + data: { + rateLimit: Prisma.DbNull, + }, + }); + + expect(runQueueServer.removeQueueRateLimits).toHaveBeenCalledWith( + mockEnvironment, + queueName + ); + }); + }); +}); diff --git a/apps/webapp/test/rateLimitUI.e2e.full.test.ts b/apps/webapp/test/rateLimitUI.e2e.full.test.ts new file mode 100644 index 00000000000..98b1b1cd0a0 --- /dev/null +++ b/apps/webapp/test/rateLimitUI.e2e.full.test.ts @@ -0,0 +1,101 @@ +import { describe, expect, it } from "vitest"; +import { getTestServer } from "./helpers/sharedTestServer"; +import { seedTestSession } from "./helpers/seedTestSession"; +import { seedTestUserProject } from "./helpers/seedTestUserProject"; + +describe("Rate Limiting UI", () => { + it("should override and remove queue limits via the UI action", async () => { + const server = getTestServer(); + const { user, organization, project, environment } = await seedTestUserProject(server.prisma); + await server.prisma.user.update({ + where: { id: user.id }, + data: { confirmedBasicDetails: true }, + }); + const cookie = await seedTestSession({ userId: user.id }); + + // Get the org member + const orgMember = await server.prisma.orgMember.findFirst({ + where: { userId: user.id, organizationId: organization.id }, + }); + + // Update environment to have a high maximumConcurrencyLimit and link to orgMember + await server.prisma.runtimeEnvironment.update({ + where: { id: environment.id }, + data: { + maximumConcurrencyLimit: 100, + orgMemberId: orgMember?.id, + }, + }); + + // Create a queue + const queue = await server.prisma.taskQueue.create({ + data: { + name: "test-queue", + friendlyId: "queue_12345", + type: "NAMED", + runtimeEnvironmentId: environment.id, + projectId: project.id, + concurrencyLimit: 5, + }, + }); + + const path = `/orgs/${organization.slug}/projects/${project.slug}/env/${environment.slug}/queues`; + + // 1. Override limits + const overrideFormData = new URLSearchParams(); + overrideFormData.append("action", "queue-override"); + overrideFormData.append("friendlyId", queue.friendlyId); + overrideFormData.append("concurrencyLimit", "5"); + overrideFormData.append("rateLimits", JSON.stringify([{ limit: 10, window: 60 }])); + + const overrideRes = await server.webapp.fetch(path, { + method: "POST", + body: overrideFormData.toString(), + headers: { + "Content-Type": "application/x-www-form-urlencoded", + Cookie: cookie, + }, + redirect: "manual", + }); + + expect(overrideRes.status).toBe(302); + const location = overrideRes.headers.get("location"); + if (location?.includes("error")) { + throw new Error(`Redirected with error: ${location}`); + } + + // Verify database + const updatedQueue = await server.prisma.taskQueue.findUnique({ + where: { id: queue.id }, + }); + + expect(updatedQueue?.concurrencyLimit).toBe(5); + expect(updatedQueue?.rateLimit).toEqual([{ limit: 10, window: 60 }]); + + // 2. Remove override + const removeFormData = new URLSearchParams(); + removeFormData.append("action", "queue-remove-override"); + removeFormData.append("friendlyId", queue.friendlyId); + + const removeRes = await server.webapp.fetch(path, { + method: "POST", + body: removeFormData.toString(), + headers: { + "Content-Type": "application/x-www-form-urlencoded", + Cookie: cookie, + }, + redirect: "manual", + }); + + expect(removeRes.status).toBe(302); + + // Verify database + const resetQueue = await server.prisma.taskQueue.findUnique({ + where: { id: queue.id }, + }); + + // Concurrency limit is reset to base (which was 5) + expect(resetQueue?.concurrencyLimit).toBe(5); + expect(resetQueue?.rateLimit).toBe(null); + }); +}); diff --git a/docs/docs.json b/docs/docs.json index 24c0339f3ed..e9eb09ceead 100644 --- a/docs/docs.json +++ b/docs/docs.json @@ -82,6 +82,7 @@ ] }, "queue-concurrency", + "rate-limiting", "versioning", "machines", "idempotency", diff --git a/docs/introduction.mdx b/docs/introduction.mdx index 383040de145..1d2bf1de1d4 100644 --- a/docs/introduction.mdx +++ b/docs/introduction.mdx @@ -74,6 +74,9 @@ We provide everything you need to build and manage background tasks: a CLI and S Configure what you want to happen when there is more than one run at a time. + + Control how many runs can execute within a specific time window. + { + //... + }, +}); +``` + +## Dynamic rate limits + +A dynamic rate limit applies per-tenant or per-user, based on the payload. You define it using a `dynamicKey` which is a JSON path to a value in your payload. + +```ts /trigger/dynamic-rate-limited.ts +export const dynamicRateLimitedTask = task({ + id: "dynamic-rate-limited-task", + rateLimits: [ + { + dynamicKey: "payload.userId", + limit: 10, + window: "1m", // 10 runs per minute per user + } + ], + run: async (payload: { userId: string }, { ctx }) => { + //... + }, +}); +``` + +## Custom queues + +You can also apply rate limits to custom queues, which allows multiple tasks to share the same rate limit: + +```ts /trigger/rate-limited-queue.ts +export const apiQueue = queue({ + name: "api-queue", +}); + +export const task1 = task({ + id: "task-1", + queue: apiQueue, + rateLimits: [ + { + staticKey: "shared-api", + limit: 50, + window: 10, // 50 runs per 10 seconds + }, + ], + run: async (payload) => { + // ... + }, +}); +``` + +## Overriding rate limits + +You can override queue rate limits dynamically from the Trigger.dev dashboard. Navigate to the **Queues** page in your project, select a queue, and use the UI to add, modify, or remove rate limits. diff --git a/docs/tasks/overview.mdx b/docs/tasks/overview.mdx index d7cb4d8f4d9..83458b2a75c 100644 --- a/docs/tasks/overview.mdx +++ b/docs/tasks/overview.mdx @@ -100,7 +100,7 @@ It's also worth mentioning that you can [retry a block of code](/errors-retrying ### `queue` options -Queues allow you to control the concurrency of your tasks. This allows you to have one-at-a-time execution and parallel executions. There are also more advanced techniques like having different concurrencies for different sets of your users. For more information read [the concurrency & queues guide](/queue-concurrency). +Queues allow you to control the concurrency of your tasks. This allows you to have one-at-a-time execution and parallel executions. There are also more advanced techniques like having different concurrencies for different sets of your users. For more information read [the concurrency, queues & rate limiting guide](/queue-concurrency). ```ts /trigger/one-at-a-time.ts export const oneAtATime = task({ @@ -114,6 +114,31 @@ export const oneAtATime = task({ }); ``` +### `rateLimits` options + +Rate limits allow you to control how many runs can execute within a specific time window. You can define both static and dynamic rate limits. For more information read [the rate limiting guide](/rate-limiting). + +```ts /trigger/rate-limited.ts +export const rateLimitedTask = task({ + id: "rate-limited-task", + rateLimits: [ + { + staticKey: "my-api", + limit: 100, + window: 60, // 100 runs per 60 seconds + }, + { + dynamicKey: "payload.userId", + limit: 10, + window: 60, // 10 runs per 60 seconds per user + } + ], + run: async (payload: any, { ctx }) => { + //... + }, +}); +``` + ### `machine` options Some tasks require more vCPUs or GBs of RAM. You can specify these requirements in the `machine` field. For more information read [the machines guide](/machines). diff --git a/docs/writing-tasks-introduction.mdx b/docs/writing-tasks-introduction.mdx index 5f0bc330912..e7d62c7e417 100644 --- a/docs/writing-tasks-introduction.mdx +++ b/docs/writing-tasks-introduction.mdx @@ -16,6 +16,7 @@ Before digging deeper into the details of writing tasks, you should read the [fu | [Errors & retrying](/errors-retrying) | How to deal with errors and write reliable tasks. | | [Wait](/wait) | Wait for periods of time or for external events to occur before continuing. | | [Concurrency & Queues](/queue-concurrency) | Configure what you want to happen when there is more than one run at a time. | +| [Rate Limiting](/rate-limiting) | Control how many runs can execute within a specific time window. | | [Realtime notifications](/realtime/overview) | Send realtime notifications from your task that you can subscribe to from your backend or frontend. | | [Versioning](/versioning) | How versioning works. | | [Machines](/machines) | Configure the CPU and RAM of the machine your task runs on | diff --git a/packages/core/src/v3/schemas/queues.ts b/packages/core/src/v3/schemas/queues.ts index 34a47b34e3e..9b8848f7c5f 100644 --- a/packages/core/src/v3/schemas/queues.ts +++ b/packages/core/src/v3/schemas/queues.ts @@ -47,6 +47,16 @@ export const QueueItem = z.object({ overriddenBy: z.string().nullable(), }) .optional(), + /** The rate limits of the queue */ + rateLimit: z + .array( + z.object({ + limit: z.number(), + window: z.number(), + }) + ) + .nullable() + .optional(), }); export type QueueItem = z.infer; diff --git a/packages/core/src/v3/schemas/schemas.ts b/packages/core/src/v3/schemas/schemas.ts index 95564cb1efc..9081411baff 100644 --- a/packages/core/src/v3/schemas/schemas.ts +++ b/packages/core/src/v3/schemas/schemas.ts @@ -184,10 +184,21 @@ const AgentConfig = z.object({ type: z.string(), }); +export const RateLimitConfig = z.object({ + staticKey: z.string().optional(), + dynamicKey: z.string().optional(), + limit: z.number().int().positive().optional(), + window: z.union([z.string(), z.number().int().positive()]).optional(), + units: z.number().int().positive().optional(), +}); + +export type RateLimitConfig = z.infer; + const taskMetadata = { id: z.string(), description: z.string().optional(), queue: QueueManifest.extend({ name: z.string().optional() }).optional(), + rateLimits: z.array(RateLimitConfig).optional(), retry: RetryOptions.optional(), machine: MachineConfig.optional(), triggerSource: z.string().optional(), diff --git a/packages/core/src/v3/types/tasks.ts b/packages/core/src/v3/types/tasks.ts index 978a6e5bd0a..8d82c4323fe 100644 --- a/packages/core/src/v3/types/tasks.ts +++ b/packages/core/src/v3/types/tasks.ts @@ -226,6 +226,39 @@ type CommonTaskOptions< name?: string; concurrencyLimit?: number; }; + + /** + * Rate limits for the task. + * + * @example + * ```ts + * export const myTask = task({ + * id: "my-task", + * rateLimits: [ + * { + * staticKey: "my-api", + * units: 1, + * }, + * { + * dynamicKey: "payload.userId", + * limit: 10, + * window: "1m", + * } + * ], + * run: async ({ payload, ctx }) => { + * //... + * }, + * }); + * ``` + */ + rateLimits?: Array<{ + staticKey?: string; + dynamicKey?: string; + limit?: number; + window?: string | number; + units?: number; + }>; + /** Configure the spec of the [machine](https://trigger.dev/docs/machines) you want your task to run on. * * @example diff --git a/packages/redis-worker/src/fair-queue/index.ts b/packages/redis-worker/src/fair-queue/index.ts index 0c0b7921a6b..3d915f965e8 100644 --- a/packages/redis-worker/src/fair-queue/index.ts +++ b/packages/redis-worker/src/fair-queue/index.ts @@ -5,6 +5,7 @@ import { nanoid } from "nanoid"; import { setInterval } from "node:timers/promises"; import { type z } from "zod"; import { ConcurrencyManager } from "./concurrency.js"; +import { RateLimitManager } from "./rateLimit.js"; import { MasterQueue } from "./masterQueue.js"; import { TenantDispatch } from "./tenantDispatch.js"; import { type RetryStrategy, ExponentialBackoffRetry } from "./retry.js"; @@ -38,6 +39,7 @@ export * from "./types.js"; export * from "./keyProducer.js"; export * from "./masterQueue.js"; export * from "./concurrency.js"; +export * from "./rateLimit.js"; export * from "./visibility.js"; export * from "./workerQueue.js"; export * from "./scheduler.js"; @@ -70,6 +72,7 @@ export class FairQueue { private scheduler: FairScheduler; private masterQueue: MasterQueue; private concurrencyManager?: ConcurrencyManager; + private rateLimitManager: RateLimitManager; private visibilityManager: VisibilityManager; private workerQueueManager: WorkerQueueManager; private telemetry: FairQueueTelemetry; @@ -201,6 +204,11 @@ export class FairQueue { }); } + this.rateLimitManager = new RateLimitManager({ + redis: options.redis, + keys: options.keys, + }); + this.visibilityManager = new VisibilityManager({ redis: options.redis, keys: options.keys, @@ -311,9 +319,11 @@ export class FairQueue { timestamp, attempt: 1, metadata: options.metadata, + rateLimits: options.rateLimits, }) : undefined, metadata: options.metadata, + rateLimits: options.rateLimits, }; // Use atomic Lua script to enqueue and update tenant dispatch indexes @@ -410,9 +420,11 @@ export class FairQueue { timestamp, attempt: 1, metadata: options.metadata, + rateLimits: message.rateLimits, }) : undefined, metadata: options.metadata, + rateLimits: message.rateLimits, }; messageIds.push(messageId); @@ -698,6 +710,7 @@ export class FairQueue { this.masterQueue.close(), this.tenantDispatch.close(), this.concurrencyManager?.close(), + this.rateLimitManager.close(), this.visibilityManager.close(), this.workerQueueManager.close(), this.scheduler.close?.(), @@ -1143,6 +1156,36 @@ export class FairQueue { for (let i = 0; i < claimedMessages.length; i++) { const message = claimedMessages[i]!; + // Check rate limits + if (message.payload.rateLimits && message.payload.rateLimits.length > 0) { + const rateLimitResult = await this.rateLimitManager.checkAndConsume( + message.payload.rateLimits + ); + + if (!rateLimitResult.allowed) { + // Rate limit exceeded, delay the message + const resetAt = rateLimitResult.resetAt ?? Date.now() + 1000; // Fallback to 1s delay + + // Release this message back to the queue with a delayed timestamp + const tenantQueueIndexKey = this.keys.tenantQueueIndexKey(tenantId); + const dispatchKey = this.keys.dispatchKey(dispatchShardId); + + await this.visibilityManager.releaseBatch( + [message], + queueId, + queueKey, + queueItemsKey, + tenantQueueIndexKey, + dispatchKey, + tenantId, + resetAt + ); + + // Continue processing other messages in the batch, as they might have different dynamic keys + continue; + } + } + // Reserve concurrency slot if (this.concurrencyManager) { const reserved = await this.concurrencyManager.reserve(descriptor, message.messageId); diff --git a/packages/redis-worker/src/fair-queue/rateLimit.ts b/packages/redis-worker/src/fair-queue/rateLimit.ts new file mode 100644 index 00000000000..701d0f565b6 --- /dev/null +++ b/packages/redis-worker/src/fair-queue/rateLimit.ts @@ -0,0 +1,190 @@ +import { createRedisClient, type Redis, type RedisOptions } from "@internal/redis"; +import type { FairQueueKeyProducer, RateLimitRequest, RateLimitCheckResult } from "./types.js"; + +export interface RateLimitManagerOptions { + redis: RedisOptions; + keys: FairQueueKeyProducer; +} + +export class RateLimitManager { + private redis: Redis; + private keys: FairQueueKeyProducer; + + constructor(options: RateLimitManagerOptions) { + this.redis = createRedisClient(options.redis); + this.keys = options.keys; + + this.#registerCommands(); + } + + /** + * Upsert a static rate limit configuration. + */ + async upsertStaticConfig(key: string, limit: number, windowMs: number): Promise { + const configKey = `rate_limit_config:${key}`; + await this.redis.hset(configKey, { + limit: limit.toString(), + windowMs: windowMs.toString(), + }); + } + + /** + * Get multiple static rate limit configurations. + */ + async getStaticConfigs(keys: string[]): Promise> { + if (keys.length === 0) { + return new Map(); + } + + const pipeline = this.redis.pipeline(); + for (const key of keys) { + pipeline.hgetall(`rate_limit_config:${key}`); + } + + const results = await pipeline.exec(); + const map = new Map(); + + if (!results) { + return map; + } + + for (let i = 0; i < keys.length; i++) { + const key = keys[i]!; + const [err, result] = results[i]!; + + if (err || !result || Object.keys(result).length === 0) { + map.set(key, null); + } else { + const res = result as Record; + map.set(key, { + limit: parseInt(res.limit!, 10), + windowMs: parseInt(res.windowMs!, 10), + }); + } + } + + return map; + } + + /** + * Check and consume rate limits atomically. + */ + async checkAndConsume(requests: RateLimitRequest[]): Promise { + if (requests.length === 0) { + return { allowed: true }; + } + + const now = Date.now(); + const keys: string[] = []; + const args: string[] = [now.toString()]; + + // Fetch all static configs in parallel + const staticRequests = requests.filter((r) => r.isStatic); + let staticConfigs = new Map(); + + if (staticRequests.length > 0) { + staticConfigs = await this.getStaticConfigs(staticRequests.map((r) => r.key)); + } + + for (const req of requests) { + let limit = req.limit; + let windowMs = req.windowMs; + + if (req.isStatic) { + const config = staticConfigs.get(req.key); + if (!config) { + // If static config is missing, we reject safely + return { allowed: false, resetAt: now + 60000 }; // Fallback delay + } + limit = config.limit; + windowMs = config.windowMs; + } + + if (limit === undefined || windowMs === undefined) { + throw new Error(`Rate limit configuration missing for key: ${req.key}`); + } + + if (limit === 0) { + return { allowed: false, resetAt: now + windowMs }; + } + + // Calculate the current window start time + const windowStart = Math.floor(now / windowMs) * windowMs; + const redisKey = `rate_limit:${req.key}:${windowStart}`; + + keys.push(redisKey); + args.push(limit.toString(), req.units.toString(), windowMs.toString()); + } + + // Execute the Lua script + // The script returns [allowed (1 or 0), resetAt (if not allowed)] + const result = await this.redis.consumeRateLimit(keys.length, keys, ...args); + + if (result[0] === 1) { + return { allowed: true }; + } else { + return { allowed: false, resetAt: result[1] }; + } + } + + async close(): Promise { + await this.redis.quit(); + } + + #registerCommands(): void { + // Lua script for atomic multi-key rate limiting + // KEYS: array of rate limit keys for the current window + // ARGV: [now, limit1, units1, windowMs1, limit2, units2, windowMs2, ...] + this.redis.defineCommand("consumeRateLimit", { + lua: ` +local numRequests = #KEYS +local now = tonumber(ARGV[1]) + +-- Step 1: Check all limits +for i = 1, numRequests do + local key = KEYS[i] + local limit = tonumber(ARGV[(i - 1) * 3 + 2]) + local units = tonumber(ARGV[(i - 1) * 3 + 3]) + local windowMs = tonumber(ARGV[(i - 1) * 3 + 4]) + + local current = tonumber(redis.call('GET', key) or "0") + + if current + units > limit then + local ttl = redis.call('PTTL', key) + local resetAt + if ttl > 0 then + resetAt = now + ttl + else + resetAt = now + windowMs + end + return {0, resetAt} + end +end + +-- Step 2: Consume units for all keys +for i = 1, numRequests do + local key = KEYS[i] + local units = tonumber(ARGV[(i - 1) * 3 + 3]) + local windowMs = tonumber(ARGV[(i - 1) * 3 + 4]) + + local current = redis.call('INCRBY', key, units) + if current == units then + redis.call('PEXPIRE', key, windowMs) + end +end + +return {1, 0} + `, + }); + } +} + +declare module "@internal/redis" { + interface RedisCommander { + consumeRateLimit( + numKeys: number, + keys: string[], + ...args: string[] + ): Promise<[number, number]>; + } +} diff --git a/packages/redis-worker/src/fair-queue/tests/fairQueue.test.ts b/packages/redis-worker/src/fair-queue/tests/fairQueue.test.ts index f1634ad2fe0..aec9557d5b9 100644 --- a/packages/redis-worker/src/fair-queue/tests/fairQueue.test.ts +++ b/packages/redis-worker/src/fair-queue/tests/fairQueue.test.ts @@ -1372,4 +1372,160 @@ describe("FairQueue", () => { ); }); + describe("rate limiting", () => { + redisTest( + "should delay task processing until resetAt when a rate limit is hit", + { timeout: 15000 }, + async ({ redisOptions }) => { + const processed: string[] = []; + keys = new DefaultFairQueueKeyProducer({ prefix: "test" }); + + const scheduler = new DRRScheduler({ + redis: redisOptions, + keys, + quantum: 10, + maxDeficit: 100, + }); + + const queue = new TestFairQueueHelper(redisOptions, keys, { + scheduler, + payloadSchema: TestPayloadSchema, + shardCount: 1, + consumerCount: 1, + consumerIntervalMs: 50, + visibilityTimeoutMs: 5000, + startConsumers: false, + }); + + queue.onMessage(async (ctx) => { + processed.push(ctx.message.payload.value); + await ctx.complete(); + }); + + // Enqueue message with a rate limit of 1 + await queue.enqueue({ + queueId: "tenant:t1:queue:q1", + tenantId: "t1", + payload: { value: "first-allowed" }, + rateLimits: [ + { + key: "test-limit-delay", + limit: 1, + windowMs: 2000, + units: 1, + }, + ], + }); + + // Enqueue second message which should hit the limit + await queue.enqueue({ + queueId: "tenant:t1:queue:q1", + tenantId: "t1", + payload: { value: "second-delayed" }, + rateLimits: [ + { + key: "test-limit-delay", + limit: 1, + windowMs: 2000, + units: 1, + }, + ], + }); + + // Start processing + queue.start(); + + // Wait for first message to be processed + await vi.waitFor( + () => { + expect(processed).toContain("first-allowed"); + }, + { timeout: 5000 } + ); + + // Second message should NOT be processed yet + expect(processed).not.toContain("second-delayed"); + + // Wait for the window to expire (2000ms) + await vi.waitFor( + () => { + expect(processed).toContain("second-delayed"); + }, + { timeout: 5000 } + ); + + await queue.close(); + } + ); + + redisTest( + "should not block other tenants when one tenant hits a rate limit", + { timeout: 15000 }, + async ({ redisOptions }) => { + const processed: string[] = []; + keys = new DefaultFairQueueKeyProducer({ prefix: "test" }); + + const scheduler = new DRRScheduler({ + redis: redisOptions, + keys, + quantum: 10, + maxDeficit: 100, + }); + + const queue = new TestFairQueueHelper(redisOptions, keys, { + scheduler, + payloadSchema: TestPayloadSchema, + shardCount: 1, + consumerCount: 1, + consumerIntervalMs: 50, + visibilityTimeoutMs: 5000, + startConsumers: false, + }); + + queue.onMessage(async (ctx) => { + processed.push(ctx.message.payload.value); + await ctx.complete(); + }); + + // Tenant 1 hits rate limit + await queue.enqueue({ + queueId: "tenant:t1:queue:q1", + tenantId: "t1", + payload: { value: "t1-limited" }, + rateLimits: [ + { + key: "t1-limit", + limit: 0, + windowMs: 5000, + units: 1, + }, + ], + }); + + // Tenant 2 has no rate limit + await queue.enqueue({ + queueId: "tenant:t2:queue:q1", + tenantId: "t2", + payload: { value: "t2-allowed" }, + }); + + // Start processing + queue.start(); + + // Wait for t2 to be processed + await vi.waitFor( + () => { + expect(processed).toContain("t2-allowed"); + }, + { timeout: 5000 } + ); + + // t1 should still not be processed + expect(processed).not.toContain("t1-limited"); + + await queue.close(); + } + ); + }); + }); diff --git a/packages/redis-worker/src/fair-queue/tests/rateLimit.test.ts b/packages/redis-worker/src/fair-queue/tests/rateLimit.test.ts new file mode 100644 index 00000000000..323c32b53ea --- /dev/null +++ b/packages/redis-worker/src/fair-queue/tests/rateLimit.test.ts @@ -0,0 +1,247 @@ +import { describe, expect } from "vitest"; +import { redisTest } from "@internal/testcontainers"; +import { RateLimitManager } from "../rateLimit.js"; +import { DefaultFairQueueKeyProducer } from "../keyProducer.js"; +import type { FairQueueKeyProducer } from "../types.js"; + +describe("RateLimitManager", () => { + let keys: FairQueueKeyProducer; + + describe("unit tests", () => { + redisTest( + "should allow consumption when requesting units within the defined limit", + { timeout: 10000 }, + async ({ redisOptions }) => { + keys = new DefaultFairQueueKeyProducer({ prefix: "test" }); + const manager = new RateLimitManager({ redis: redisOptions, keys }); + + const result = await manager.checkAndConsume([ + { key: "test-key-1", limit: 10, windowMs: 1000, units: 1 }, + ]); + + expect(result.allowed).toBe(true); + expect(result.resetAt).toBeUndefined(); + + await manager.close(); + } + ); + + redisTest( + "should reject consumption and return resetAt when requesting units exceeding the limit", + { timeout: 10000 }, + async ({ redisOptions }) => { + keys = new DefaultFairQueueKeyProducer({ prefix: "test" }); + const manager = new RateLimitManager({ redis: redisOptions, keys }); + + const result = await manager.checkAndConsume([ + { key: "test-key-2", limit: 10, windowMs: 1000, units: 11 }, + ]); + + expect(result.allowed).toBe(false); + expect(result.resetAt).toBeDefined(); + expect(result.resetAt).toBeGreaterThan(Date.now()); + + await manager.close(); + } + ); + + redisTest( + "should atomically evaluate multiple keys and allow if all have capacity", + { timeout: 10000 }, + async ({ redisOptions }) => { + keys = new DefaultFairQueueKeyProducer({ prefix: "test" }); + const manager = new RateLimitManager({ redis: redisOptions, keys }); + + const result = await manager.checkAndConsume([ + { key: "test-key-3a", limit: 10, windowMs: 1000, units: 1 }, + { key: "test-key-3b", limit: 5, windowMs: 1000, units: 1 }, + ]); + + expect(result.allowed).toBe(true); + + await manager.close(); + } + ); + + redisTest( + "should atomically reject and consume zero units if any key exceeds its limit", + { timeout: 10000 }, + async ({ redisOptions }) => { + keys = new DefaultFairQueueKeyProducer({ prefix: "test" }); + const manager = new RateLimitManager({ redis: redisOptions, keys }); + + const result = await manager.checkAndConsume([ + { key: "test-key-4a", limit: 10, windowMs: 1000, units: 1 }, + { key: "test-key-4b", limit: 5, windowMs: 1000, units: 6 }, + ]); + + expect(result.allowed).toBe(false); + + // Verify that test-key-4a was NOT consumed + const checkResult = await manager.checkAndConsume([ + { key: "test-key-4a", limit: 10, windowMs: 1000, units: 10 }, + ]); + expect(checkResult.allowed).toBe(true); + + await manager.close(); + } + ); + + redisTest( + "should reset the window and allow consumption after the window duration has passed", + { timeout: 10000 }, + async ({ redisOptions }) => { + keys = new DefaultFairQueueKeyProducer({ prefix: "test" }); + const manager = new RateLimitManager({ redis: redisOptions, keys }); + + // Consume full quota + const result1 = await manager.checkAndConsume([ + { key: "test-key-5", limit: 1, windowMs: 100, units: 1 }, + ]); + expect(result1.allowed).toBe(true); + + // Should be rejected immediately after + const result2 = await manager.checkAndConsume([ + { key: "test-key-5", limit: 1, windowMs: 100, units: 1 }, + ]); + expect(result2.allowed).toBe(false); + + // Wait for window to expire + await new Promise((resolve) => setTimeout(resolve, 150)); + + // Should be allowed again + const result3 = await manager.checkAndConsume([ + { key: "test-key-5", limit: 1, windowMs: 100, units: 1 }, + ]); + expect(result3.allowed).toBe(true); + + await manager.close(); + } + ); + + redisTest( + "should handle high concurrency without race conditions or exceeding limits", + { timeout: 10000 }, + async ({ redisOptions }) => { + keys = new DefaultFairQueueKeyProducer({ prefix: "test" }); + const manager = new RateLimitManager({ redis: redisOptions, keys }); + + const promises = Array.from({ length: 100 }, () => + manager.checkAndConsume([ + { key: "test-key-6", limit: 50, windowMs: 5000, units: 1 }, + ]) + ); + + const results = await Promise.all(promises); + + const allowedCount = results.filter((r) => r.allowed).length; + const rejectedCount = results.filter((r) => !r.allowed).length; + + expect(allowedCount).toBe(50); + expect(rejectedCount).toBe(50); + + await manager.close(); + } + ); + + redisTest( + "should enforce the limit based on the current request definition for dynamic limits", + { timeout: 10000 }, + async ({ redisOptions }) => { + keys = new DefaultFairQueueKeyProducer({ prefix: "test" }); + const manager = new RateLimitManager({ redis: redisOptions, keys }); + + // First request with limit 10 + const result1 = await manager.checkAndConsume([ + { key: "test-key-7", limit: 10, windowMs: 1000, units: 5 }, + ]); + expect(result1.allowed).toBe(true); + + // Second request with limit 5 (should fail because 5 units already consumed) + const result2 = await manager.checkAndConsume([ + { key: "test-key-7", limit: 5, windowMs: 1000, units: 1 }, + ]); + expect(result2.allowed).toBe(false); + + await manager.close(); + } + ); + + redisTest( + "should correctly store and retrieve static rate limit configurations", + { timeout: 10000 }, + async ({ redisOptions }) => { + keys = new DefaultFairQueueKeyProducer({ prefix: "test" }); + const manager = new RateLimitManager({ redis: redisOptions, keys }); + + await manager.upsertStaticConfig("static-key-1", 100, 60000); + + const configs = await manager.getStaticConfigs(["static-key-1"]); + expect(configs.get("static-key-1")).toEqual({ limit: 100, windowMs: 60000 }); + + await manager.close(); + } + ); + + redisTest( + "should safely reject consumption when a static key has not been configured", + { timeout: 10000 }, + async ({ redisOptions }) => { + keys = new DefaultFairQueueKeyProducer({ prefix: "test" }); + const manager = new RateLimitManager({ redis: redisOptions, keys }); + + const configs = await manager.getStaticConfigs(["non-existent-key"]); + expect(configs.get("non-existent-key")).toBeNull(); + + await manager.close(); + } + ); + + redisTest( + "should always reject consumption when the limit is zero", + { timeout: 10000 }, + async ({ redisOptions }) => { + keys = new DefaultFairQueueKeyProducer({ prefix: "test" }); + const manager = new RateLimitManager({ redis: redisOptions, keys }); + + const result = await manager.checkAndConsume([ + { key: "test-key-10", limit: 0, windowMs: 1000, units: 1 }, + ]); + + expect(result.allowed).toBe(false); + expect(result.resetAt).toBeDefined(); + + await manager.close(); + } + ); + + redisTest( + "should set a TTL on Redis keys to prevent memory leaks", + { timeout: 10000 }, + async ({ redisOptions }) => { + keys = new DefaultFairQueueKeyProducer({ prefix: "test" }); + const manager = new RateLimitManager({ redis: redisOptions, keys }); + + const now = Date.now(); + const windowMs = 5000; + const windowStart = Math.floor(now / windowMs) * windowMs; + const redisKey = `rate_limit:test-key-11:${windowStart}`; + + await manager.checkAndConsume([ + { key: "test-key-11", limit: 10, windowMs, units: 1 }, + ]); + + // We need to create a separate redis client to check PTTL + const { createRedisClient } = await import("@internal/redis"); + const redis = createRedisClient(redisOptions); + + const pttl = await redis.pttl(redisKey); + expect(pttl).toBeGreaterThan(0); + expect(pttl).toBeLessThanOrEqual(windowMs); + + await redis.quit(); + await manager.close(); + } + ); + }); +}); diff --git a/packages/redis-worker/src/fair-queue/types.ts b/packages/redis-worker/src/fair-queue/types.ts index 3bc56b599fa..f36693f6074 100644 --- a/packages/redis-worker/src/fair-queue/types.ts +++ b/packages/redis-worker/src/fair-queue/types.ts @@ -76,6 +76,8 @@ export interface StoredMessage { workerQueue?: string; /** Additional metadata */ metadata?: Record; + /** Rate limits to enforce before processing */ + rateLimits?: RateLimitRequest[]; } /** @@ -133,6 +135,36 @@ export interface ConcurrencyCheckResult { blockedBy?: ConcurrencyState; } +// ============================================================================ +// Rate Limiting Types +// ============================================================================ + +/** + * Request to consume units from a rate limit. + */ +export interface RateLimitRequest { + /** The unique key for this rate limit (e.g., "api-service" or "user-123") */ + key: string; + /** The maximum number of units allowed in the window. Optional for static limits. */ + limit?: number; + /** The duration of the window in milliseconds. Optional for static limits. */ + windowMs?: number; + /** The number of units to consume (default: 1) */ + units: number; + /** Whether this is a static rate limit that requires fetching config from Redis */ + isStatic?: boolean; +} + +/** + * Result of a rate limit check. + */ +export interface RateLimitCheckResult { + /** Whether the request is allowed */ + allowed: boolean; + /** If not allowed, the timestamp (ms since epoch) when the limit resets */ + resetAt?: number; +} + // ============================================================================ // Visibility Types // ============================================================================ @@ -553,6 +585,8 @@ export interface EnqueueOptions { timestamp?: number; /** Optional metadata for concurrency group extraction */ metadata?: Record; + /** Rate limits to enforce before processing */ + rateLimits?: RateLimitRequest[]; } /** @@ -568,6 +602,7 @@ export interface EnqueueBatchOptions { payload: TPayload; messageId?: string; timestamp?: number; + rateLimits?: RateLimitRequest[]; }>; /** Optional metadata for concurrency group extraction */ metadata?: Record; diff --git a/packages/trigger-sdk/src/v3/shared.ts b/packages/trigger-sdk/src/v3/shared.ts index fba990949b3..79a4c7d6c77 100644 --- a/packages/trigger-sdk/src/v3/shared.ts +++ b/packages/trigger-sdk/src/v3/shared.ts @@ -264,6 +264,7 @@ export function createTask< id: params.id, description: params.description, queue: params.queue, + rateLimits: params.rateLimits, retry: params.retry ? { ...defaultRetryOptions, ...params.retry } : undefined, machine: typeof params.machine === "string" ? { preset: params.machine } : params.machine, triggerSource: params.triggerSource, @@ -418,6 +419,7 @@ export function createSchemaTask< id: params.id, description: params.description, queue: params.queue, + rateLimits: params.rateLimits, retry: params.retry ? { ...defaultRetryOptions, ...params.retry } : undefined, machine: typeof params.machine === "string" ? { preset: params.machine } : params.machine, triggerSource: params.triggerSource,