1 change: 1 addition & 0 deletions playground/app/pages/ai.vue
@@ -13,6 +13,7 @@ const { messages, input, handleSubmit, isLoading, stop, error, reload } = useCha
class="p-2 mt-1 text-sm rounded-lg text-smp-2 whitespace-pre-line"
:class="message.role === 'assistant' ? 'text-white bg-blue-400' : 'text-gray-700 bg-gray-200'"
:value="message.content"
:cache-key="message.id"
/>
</div>
</div>
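The new `cache-key` prop presumably gives the rendering component a stable identifier per message, so already-rendered messages can be reused instead of being re-parsed on every chat update. A minimal TypeScript sketch of that caching idea, with `parseMarkdown` and `renderMessage` as illustrative names rather than the component's actual internals:

```ts
// Sketch of per-message caching keyed by a stable id (here: the message id).
// `parseMarkdown` stands in for whatever expensive rendering step is cached.
const cache = new Map<string, { content: string, html: string }>()

function renderMessage(
  id: string,
  content: string,
  parseMarkdown: (md: string) => string
): string {
  const hit = cache.get(id)
  // Reuse the previous parse only while the content is unchanged,
  // so a still-streaming message is re-rendered as new tokens arrive.
  if (hit && hit.content === content) return hit.html

  const html = parseMarkdown(content)
  cache.set(id, { content, html })
  return html
}
```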
31 changes: 20 additions & 11 deletions playground/server/api/chat.post.ts
@@ -11,20 +11,29 @@ defineRouteMeta({
export default defineEventHandler(async (event) => {
const { messages } = await readBody(event)

const workersAI = createWorkersAI({ binding: hubAI() })
const workersAI = createWorkersAI({
binding: hubAI(),
gateway: {
id: 'playground',
cacheTtl: 60 * 60 // 1 hour
}
})

// return hubAI().run('@cf/meta/llama-3.1-8b-instruct', {
// messages
// }, {
// gateway: {
// id: 'playground'
// }
// })

return streamText({
model: workersAI('@cf/meta/llama-3.1-8b-instruct'),
messages
}).toDataStreamResponse({
// headers: {
// // add these headers to ensure that the
// // response is chunked and streamed
// 'content-type': 'text/x-unknown',
// 'content-encoding': 'identity',
// 'transfer-encoding': 'chunked'
// }
})
messages,
onError(res) {
console.error(res.error)
}
}).toDataStreamResponse()

// For testing purposes, we'll randomly throw an error
// if (Math.round(Math.random()) === 1) {
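The handler now routes Workers AI calls through the `playground` AI Gateway with a one-hour cache TTL, so identical prompts inside that window can be answered from the gateway cache instead of re-running the model. The same options can also be passed per request straight to `hubAI().run()`, as hinted at by the commented-out block above; a hedged sketch of that direct, non-streaming variant (`skipCache` is assumed from Cloudflare's gateway options and is not part of this diff):

```ts
// Sketch of a direct (non-streaming) Workers AI call routed through the same
// gateway. The id and TTL mirror the values configured above; `skipCache` is
// an assumed Cloudflare gateway option for bypassing cached responses.
export default defineEventHandler(async (event) => {
  const { messages } = await readBody(event)

  return hubAI().run('@cf/meta/llama-3.1-8b-instruct', { messages }, {
    gateway: {
      id: 'playground',
      cacheTtl: 60 * 60, // cache identical requests for one hour
      skipCache: false // flip to true to force a fresh model run
    }
  })
})
```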
7 changes: 4 additions & 3 deletions src/runtime/ai/server/api/_hub/ai/[command].post.ts
@@ -6,7 +6,8 @@ import { requireNuxtHubFeature } from '../../../../../utils/features'

const statementValidation = z.object({
model: z.string().min(1).max(1e6).trim(),
params: z.record(z.string(), z.any()).optional()
params: z.record(z.string(), z.any()).optional(),
options: z.record(z.string(), z.any()).optional()
})

export default eventHandler(async (event) => {
@@ -20,9 +21,9 @@ export default eventHandler(async (event) => {
const ai = hubAI()

if (command === 'run') {
const { model, params } = await readValidatedBody(event, statementValidation.pick({ model: true, params: true }).parse)
const { model, params, options } = await readValidatedBody(event, statementValidation.pick({ model: true, params: true, options: true }).parse)
// @ts-expect-error Ai type defines all the compatible models, however Zod is only validating for string
const res = await ai.run(model, params)
const res = await ai.run(model, params, options)

// Image generation returns a ReadableStream
if (res instanceof ReadableStream) {
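With `options` added to the validation schema, the `/api/_hub/ai/run` endpoint accepts the same third argument that `hubAI().run()` takes and forwards it to the binding. A hedged sketch of the payload a remote client would POST to this route (the base URL and authorization header are placeholder assumptions standing in for whatever `proxyHubAI()` configures; see the next file):

```ts
import { ofetch } from 'ofetch'

// Sketch of the body this endpoint now validates: `options` rides alongside
// `model` and `params`. URL and auth header values are placeholder assumptions.
const res = await ofetch('/api/_hub/ai/run', {
  baseURL: 'https://my-project.example.com',
  method: 'POST',
  headers: { authorization: 'Bearer <project-secret-key>' },
  body: {
    model: '@cf/meta/llama-3.1-8b-instruct',
    params: { prompt: 'Hello!' },
    options: { gateway: { id: 'playground', cacheTtl: 60 * 60 } }
  }
})
```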
19 changes: 10 additions & 9 deletions src/runtime/ai/server/utils/ai.ts
@@ -2,7 +2,7 @@ import { ofetch } from 'ofetch'
import { joinURL } from 'ufo'
import { createError } from 'h3'
import type { H3Error } from 'h3'
import type { Ai } from '@cloudflare/workers-types/experimental'
import type { Ai, AiOptions } from '@cloudflare/workers-types/experimental'
import { requireNuxtHubFeature } from '../../../utils/features'
import { getCloudflareAccessHeaders } from '../../../utils/cloudflareAccess'
import { useRuntimeConfig } from '#imports'
@@ -36,7 +36,7 @@ export function hubAI(): Ai {
} else if (import.meta.dev) {
// Mock _ai to call NuxtHub Admin API to proxy CF account & API token
_ai = {
async run(model: string, params?: Record<string, unknown>) {
async run(model: string, params?: Record<string, unknown>, options?: AiOptions) {
if (!hub.projectKey) {
throw createError({
statusCode: 500,
@@ -55,7 +55,7 @@ headers: {
headers: {
authorization: `Bearer ${hub.userToken}`
},
body: { model, params },
body: { model, params, options },
responseType: params?.stream ? 'stream' : undefined
}).catch(handleProxyError)
}
@@ -97,9 +97,9 @@ export function proxyHubAI(projectUrl: string, secretKey?: string, headers?: Hea
}
})
return {
async run(model: string, params?: Record<string, unknown>) {
async run(model: string, params?: Record<string, unknown>, options?: AiOptions) {
return aiAPI('/run', {
body: { model, params },
body: { model, params, options },
responseType: params?.stream ? 'stream' : undefined
}).catch(handleProxyError)
}
@@ -111,14 +111,15 @@ async function handleProxyError(err: H3Error) {
if (import.meta.dev && err.statusCode === 403) {
console.warn('It seems that your Cloudflare API token does not have the `Worker AI` permission.\nOpen `https://dash.cloudflare.com/profile/api-tokens` and edit your NuxtHub token.\nAdd the `Account > Worker AI > Read` permission to your token and save it.')
}
let data = err.data || {}
if (typeof (err as any).response?.json === 'function') {
let data = err.data
if (!err.data && typeof (err as any).response?.json === 'function') {
data = (await (err as any).response.json())?.data || {}
}
throw createError({
statusCode: err.statusCode,
statusCode: data?.statusCode || err.statusCode,
statusMessage: data?.statusMessage || err.statusMessage,
// @ts-expect-error not aware of data property
message: err.data?.message || err.message,
message: data?.message || err.message,
data
})
}
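Because `handleProxyError` now prefers the upstream payload's `statusCode`, `statusMessage`, and `message` over the proxy's own, callers see the real Cloudflare error rather than a generic proxy failure. A hedged sketch of what that looks like from the calling side (the prompt is an arbitrary example):

```ts
// Sketch: errors rethrown by the proxy now carry the upstream status and
// message when the response payload provides them, with the proxy's own
// error fields as a fallback.
try {
  await hubAI().run('@cf/meta/llama-3.1-8b-instruct', { prompt: 'Hello!' })
} catch (err: any) {
  // e.g. a 403 with Cloudflare's explanation instead of a bare proxy 403
  console.error(err.statusCode, err.statusMessage, err.message)
}
```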