 import { openai } from '@ai-sdk/openai'
+import { Ratelimit } from '@upstash/ratelimit'
+import { kv } from '@vercel/kv'
 import { ToolInvocation, convertToCoreMessages, streamText } from 'ai'
 import { codeBlock } from 'common-tags'
 import { convertToCoreTools, maxMessageContext, maxRowLimit, tools } from '~/lib/tools'
+import { createClient } from '~/utils/supabase/server'

 // Allow streaming responses up to 30 seconds
 export const maxDuration = 30

+// Allow up to 1,000,000 input (prompt) tokens per user per 30-minute fixed window
+const inputTokenRateLimit = new Ratelimit({
+  redis: kv,
+  limiter: Ratelimit.fixedWindow(1000000, '30m'),
+  prefix: 'ratelimit:tokens:input',
+})
+
+// Allow up to 10,000 output (completion) tokens per user per 30-minute fixed window
+const outputTokenRateLimit = new Ratelimit({
+  redis: kv,
+  limiter: Ratelimit.fixedWindow(10000, '30m'),
+  prefix: 'ratelimit:tokens:output',
+})
+
 type Message = {
   role: 'user' | 'assistant'
   content: string
   toolInvocations?: (ToolInvocation & { result: any })[]
 }

 export async function POST(req: Request) {
+  const supabase = createClient()
+
+  const { data, error } = await supabase.auth.getUser()
+
+  // Auth middleware should make this unreachable; the check is kept for type narrowing
+  if (error) {
+    return new Response('Unauthorized', { status: 401 })
+  }
+
+  const { user } = data
+
+  // Check both token budgets up front and reject before doing any model work
+  const { remaining: inputRemaining } = await inputTokenRateLimit.getRemaining(user.id)
+  const { remaining: outputRemaining } = await outputTokenRateLimit.getRemaining(user.id)
+
+  if (inputRemaining <= 0 || outputRemaining <= 0) {
+    return new Response('Rate limited', { status: 429 })
+  }
+
   const { messages }: { messages: Message[] } = await req.json()

   // Trim the message context sent to the LLM to mitigate token abuse
@@ -64,6 +97,15 @@ export async function POST(req: Request) {
     model: openai('gpt-4o-2024-08-06'),
     messages: convertToCoreMessages(trimmedMessageContext),
     tools: convertToCoreTools(tools),
+    // Deduct the tokens actually consumed from each budget once the response completes
+    async onFinish({ usage }) {
+      await inputTokenRateLimit.limit(user.id, {
+        rate: usage.promptTokens,
+      })
+      await outputTokenRateLimit.limit(user.id, {
+        rate: usage.completionTokens,
+      })
+    },
   })

   return result.toAIStreamResponse()
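
Taken together, the diff implements post-hoc token accounting: the cheap `getRemaining` pre-check refuses work when a budget is exhausted, and the real deduction happens in `onFinish` via the `rate` option, which tells `@upstash/ratelimit` to consume that many units in a single call. Below is a minimal standalone sketch of the same check-then-deduct flow, assuming `@upstash/ratelimit` v1 (where `getRemaining` returns an object) and using a hypothetical `runModel` helper in place of `streamText`:

```ts
import { Ratelimit } from '@upstash/ratelimit'
import { kv } from '@vercel/kv'

// Same shape as the route's limiters: a fixed 30-minute window whose
// capacity is measured in tokens rather than requests.
const tokenBudget = new Ratelimit({
  redis: kv,
  limiter: Ratelimit.fixedWindow(1000000, '30m'),
  prefix: 'ratelimit:tokens:example',
})

// Hypothetical stand-in for the actual model call.
async function runModel(prompt: string): Promise<{ text: string; tokensUsed: number }> {
  return { text: `echo: ${prompt}`, tokensUsed: prompt.length }
}

export async function guardedCompletion(userId: string, prompt: string) {
  // 1. Cheap pre-check: refuse before spending money if the budget is gone.
  const { remaining } = await tokenBudget.getRemaining(userId)
  if (remaining <= 0) {
    throw new Error('Rate limited')
  }

  // 2. Do the expensive work.
  const { text, tokensUsed } = await runModel(prompt)

  // 3. Deduct actual usage after the fact; the `rate` option consumes
  //    that many units from the window in one limit() call.
  await tokenBudget.limit(userId, { rate: tokensUsed })

  return text
}
```

Note that the pre-check is advisory rather than atomic: concurrent requests can all pass `getRemaining` before any deduction lands, overdrawing the window by roughly one response's worth of tokens. For abuse mitigation that is usually acceptable, since the deduction blocks the next request.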
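
The `@@` hunk header elides the step that produces `trimmedMessageContext`. Given the `maxMessageContext` import, a plausible reading is that it keeps only the most recent messages; the sketch below is an assumption about the elided code, not a copy of it:

```ts
// Assumed shape of the elided step: keep only the newest `limit` messages so
// a client cannot inflate prompt tokens by replaying an arbitrarily long
// history. Hypothetical reconstruction, not the repo's actual implementation.
function trimMessageContext(messages: Message[], limit: number): Message[] {
  return messages.slice(-limit)
}

// e.g. const trimmedMessageContext = trimMessageContext(messages, maxMessageContext)
```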
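
On the client, the 429 surfaces through the chat transport as an error. Assuming the app uses the `useChat` hook from `ai/react` (the component below and its copy are illustrative, not from this repo), one way to tell users they have been rate limited:

```tsx
'use client'

import { useChat } from 'ai/react'
import { useState } from 'react'

export function Chat() {
  const [rateLimited, setRateLimited] = useState(false)

  const { messages, input, handleInputChange, handleSubmit } = useChat({
    onError(error) {
      // The route responds with the body 'Rate limited' and status 429;
      // useChat generally exposes the failure via this callback.
      if (error.message.includes('Rate limited')) {
        setRateLimited(true)
      }
    },
  })

  return (
    <form onSubmit={handleSubmit}>
      {messages.map((m) => (
        <p key={m.id}>{m.content}</p>
      ))}
      {rateLimited && <p>You&apos;ve hit the usage limit. Please try again later.</p>}
      <input value={input} onChange={handleInputChange} />
    </form>
  )
}
```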