65 changes: 65 additions & 0 deletions test/src/ai.rs
@@ -0,0 +1,65 @@
use serde::{Deserialize, Serialize};
use worker::{Env, Model, Request, Response, Result, StreamableModel};

use crate::SomeSharedData;

pub struct Llama4Scout17b16eInstruct;
Collaborator
Agreed this seems like the right approach!


#[derive(Serialize)]
pub struct DefaultTextGenerationInput {
    pub prompt: String,
}

#[derive(Deserialize)]
pub struct DefaultTextGenerationOutput {
    pub response: String,
}

impl From<DefaultTextGenerationOutput> for Vec<u8> {
    fn from(value: DefaultTextGenerationOutput) -> Self {
        value.response.into_bytes()
    }
}

impl Model for Llama4Scout17b16eInstruct {
    const MODEL_NAME: &str = "@cf/meta/llama-4-scout-17b-16e-instruct";

    type Input = DefaultTextGenerationInput;

    type Output = DefaultTextGenerationOutput;
}

impl StreamableModel for Llama4Scout17b16eInstruct {}

const AI_TEST: &str = "AI_TEST";

#[worker::send]
pub async fn simple_ai_text_generation(
    _: Request,
    env: Env,
    _data: SomeSharedData,
) -> Result<Response> {
    let ai = env
        .ai(AI_TEST)?
        .run::<Llama4Scout17b16eInstruct>(DefaultTextGenerationInput {
            prompt: "What is the answer to life the universe and everything?".to_owned(),
        })
Comment on lines +44 to +46
Collaborator
Instead of DefaultTextGenerationInput, could we implement AiTextGenerationInput directly in core?

Along the lines of:

pub struct AiTextGenerationInput {
    pub prompt: Option<String>,
    pub raw: Option<bool>,
    pub stream: Option<bool>,
    pub max_tokens: Option<u32>,
    pub temperature: Option<f32>,
    pub top_p: Option<f32>,
    pub top_k: Option<u32>,
    pub seed: Option<u32>,
    pub repetition_penalty: Option<f32>,
    pub frequency_penalty: Option<f32>,
    pub presence_penalty: Option<f32>,
    pub messages: Option<Vec<RoleScopedChatInput>>,
    pub response_format: Option<AiTextGenerationResponseFormat>,
    pub tools: Option<serde_json::Value>, // For flexible union type
    pub functions: Option<Vec<AiTextGenerationFunctionsInput>>,
}

We don't even need all fields initially.

And similarly for https://workers-types.pages.dev/#AiTextGenerationOutput?
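
For reference, the output side might look along these lines (a sketch mirroring the linked workers-types page; the exact field set here is an assumption, not code from the PR):

use serde::Deserialize;

#[derive(Deserialize)]
pub struct AiTextGenerationOutput {
    // Generated text; may be absent, e.g. for tool-call responses.
    pub response: Option<String>,
    // Token accounting, when the runtime reports it.
    pub usage: Option<AiTextGenerationUsage>,
}

// Hypothetical usage struct, following the usual prompt/completion/total convention.
#[derive(Deserialize)]
pub struct AiTextGenerationUsage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}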

Contributor Author

Sure! That works. I was trying to match the type definitions emitted by wrangler types, but that does seem more legible. Also, would it make sense to mark the fields as private and use a builder pattern to construct the input?
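
Something like this, say (purely illustrative, not part of the PR; it assumes AiTextGenerationInput derives Default):

#[derive(Default)]
pub struct AiTextGenerationInputBuilder {
    inner: AiTextGenerationInput,
}

impl AiTextGenerationInputBuilder {
    // Each setter consumes and returns the builder so calls can be chained.
    pub fn prompt(mut self, prompt: impl Into<String>) -> Self {
        self.inner.prompt = Some(prompt.into());
        self
    }

    pub fn max_tokens(mut self, max_tokens: u32) -> Self {
        self.inner.max_tokens = Some(max_tokens);
        self
    }

    pub fn build(self) -> AiTextGenerationInput {
        self.inner
    }
}

// Usage: AiTextGenerationInputBuilder::default().prompt("Hi").max_tokens(256).build()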

Contributor Author (@parzivale), Oct 6, 2025

Also, for tools I have a separate PR I would like to make, based on this one, which handles tools as a new sub-trait of Model. The gist of it is that you define tools like the following:

type ToolRun = Rc<
    dyn Fn(
        &Env,
        Option<serde_json::Map<String, serde_json::Value>>,
    ) -> Pin<Box<dyn Future<Output = Result<String, Error>>>>,
>;

#[derive(Clone)]
pub struct Tool {
    pub name: &'static str,
    pub description: &'static str,
    pub run: ToolRun,
}

and then you would call them from Ai with the following method:

impl Ai {
  pub async fn run_with_tools<M: ToolsModel>(&self, input: &M::Input, tools: &[Tool]) -> Result<M::Output, Error> {
  ...
  }
}

This would require a new trait called ToolsModel, like this:

pub trait ToolsModel: Model {}
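
For concreteness, a tool under this scheme might be defined like so (an illustrative sketch, not code from the PR; the echo tool and the explicit cast are assumptions):

use std::{future::Future, pin::Pin, rc::Rc};

let echo = Tool {
    name: "echo",
    description: "Returns its `text` argument unchanged.",
    run: Rc::new(|_env, args| {
        // Pull the model-supplied `text` argument out of the JSON arguments map.
        let text = args
            .and_then(|map| map.get("text").and_then(|v| v.as_str()).map(str::to_owned))
            .unwrap_or_default();
        // The cast erases the concrete async-block type so the closure's return
        // type matches the Pin<Box<dyn Future<...>>> expected by ToolRun.
        Box::pin(async move { Ok(text) })
            as Pin<Box<dyn Future<Output = Result<String, Error>>>>
    }),
};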

The problem is that I'm not sure this approach is generic enough to work for everyone. For instance, I know that when working with axum you are normally expected to wrap Env in an Arc, which means you no longer have access to this functionality.

Collaborator

As much as possible it would be good to follow wasm-bindgen semantics. I know we wrap all types currently, but I'm trying to move to a model where, when possible, we use worker-sys types as the high-level types.

So the wasm-bindgen pattern would be to define an imported type, and define all properties as getters and setters:

#[wasm_bindgen]
extern "C" {
    #[wasm_bindgen(extends = ::js_sys::Object, js_name = AiTextGenerationInput)]
    #[derive(Debug, Clone, PartialEq, Eq)]
    #[doc = "Ai Text Generation"]
    pub type AiTextGenerationInput;
    #[doc = "Prompt for generation"]
    #[wasm_bindgen(method, setter = "prompt")]
    pub fn set_prompt(this: &AiTextGenerationInput, prompt: Option<String>);
    #[wasm_bindgen(method, getter = "prompt")]
    pub fn get_prompt(this: &AiTextGenerationInput) -> Option<String>;
}

Then functions would just take &worker_sys::AiTextGenerationInput directly, and we'd likely alias it as worker::ai::AiTextGenerationInput.
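
Construction on the Rust side would then presumably be a bare object plus the setters (sketch, assuming the binding above):

use js_sys::Object;
use wasm_bindgen::JsCast;

// The TS type is structural, so a plain JS object can be cast to the
// imported type without any constructor.
let input: AiTextGenerationInput = Object::new().unchecked_into();
input.set_prompt(Some("What is the answer to life the universe and everything?".to_owned()));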

Contributor Author (@parzivale), Oct 6, 2025

Right, gotcha! I didn't realise you were referring to types in JS. Thanks for clarifying!

Contributor Author

I just looked into AiTextGeneration in the docs you provided, and it seems AiTextGeneration is a TypeScript type, not a class, so there is no imported type that we can bind to in the JS namespace.

Collaborator

That should be fine: it's a structural type, not a formal import.

Contributor Author

Ah, sorry, I only just saw your comment. I made a commit where I kept the AiTextGenerationInput struct in worker rather than binding to JS; if you don't like the API, I can change it.

Collaborator

I'm trying hard to formalize our conventions, so let me think about this some more and follow up soon. My initial sense is that serde performance is equivalent, so we should prefer the direct-interpretation convention, but I want to ensure I've got that right.

        .await?;
    Response::ok(ai.response)
}

#[worker::send]
pub async fn streaming_ai_text_generation(
    _: Request,
    env: Env,
    _data: SomeSharedData,
) -> Result<Response> {
    let stream = env
        .ai(AI_TEST)?
        .run_streaming::<Llama4Scout17b16eInstruct>(DefaultTextGenerationInput {
            prompt: "What is the answer to life the universe and everything?".to_owned(),
        })
        .await?;

    Response::from_stream(stream)
}
1 change: 1 addition & 0 deletions test/src/lib.rs
@@ -12,6 +12,7 @@ use worker::{console_log, event, js_sys, wasm_bindgen, Env, Result};
#[cfg(not(feature = "http"))]
use worker::{Request, Response};

mod ai;
mod alarm;
mod analytics_engine;
mod assets;
8 changes: 5 additions & 3 deletions test/src/router.rs
@@ -1,7 +1,7 @@
use crate::{
    alarm, analytics_engine, assets, auto_response, cache, container, counter, d1, durable, fetch,
    form, js_snippets, kv, put_raw, queue, r2, request, secret_store, service, socket, sql_counter,
    sql_iterator, user, ws, SomeSharedData, GLOBAL_STATE,
    ai, alarm, analytics_engine, assets, auto_response, cache, container, counter, d1, durable,
    fetch, form, js_snippets, kv, put_raw, queue, r2, request, secret_store, service, socket,
    sql_counter, sql_iterator, user, ws, SomeSharedData, GLOBAL_STATE,
};
#[cfg(feature = "http")]
use std::convert::TryInto;
@@ -112,6 +112,8 @@ macro_rules! add_route (

macro_rules! add_routes (
    ($obj:ident) => {
        add_route!($obj, get, "/ai", ai::simple_ai_text_generation);
        add_route!($obj, get, "/ai/streaming", ai::streaming_ai_text_generation);
        add_route!($obj, get, sync, "/request", request::handle_a_request);
        add_route!($obj, get, "/analytics-engine", analytics_engine::handle_analytics_event);
        add_route!($obj, get, "/async-request", request::handle_async_request);
11 changes: 11 additions & 0 deletions test/tests/ai.spec.ts
@@ -0,0 +1,11 @@
import { describe, expect, test } from "vitest";
import { mf, mfUrl } from "./mf";

async function runTest() {
  let normal_response = await mf.dispatchFetch(`${mfUrl}/ai`);
  expect(normal_response.status).toBe(200);

Check failure on line 6 in test/tests/ai.spec.ts (GitHub Actions / Test):

AssertionError: expected 404 to be 200 // Object.is equality
- Expected: 200
+ Received: 404
❯ runTest tests/ai.spec.ts:6:34

  let streaming_response = await mf.dispatchFetch(`${mfUrl}/ai/streaming`);
  expect(streaming_response.status).toBe(200);
}
describe("ai", runTest);
25 changes: 14 additions & 11 deletions test/wrangler.toml
@@ -1,11 +1,11 @@
name = "testing-rust-worker"
workers_dev = true
compatibility_date = "2025-09-23" # required
compatibility_date = "2025-09-23" # required
main = "build/worker/shim.mjs"

kv_namespaces = [
{ binding = "SOME_NAMESPACE", id = "SOME_NAMESPACE", preview_id = "SOME_NAMESPACE" },
{ binding = "FILE_SIZES", id = "FILE_SIZES", preview_id = "FILE_SIZES" },
{ binding = "SOME_NAMESPACE", id = "SOME_NAMESPACE", preview_id = "SOME_NAMESPACE" },
{ binding = "FILE_SIZES", id = "FILE_SIZES", preview_id = "FILE_SIZES" },
]

[vars]
@@ -22,14 +22,14 @@ service = "remote-service"

[durable_objects]
bindings = [
{ name = "COUNTER", class_name = "Counter" },
{ name = "ALARM", class_name = "AlarmObject" },
{ name = "PUT_RAW_TEST_OBJECT", class_name = "PutRawTestObject" },
{ name = "AUTO", class_name = "AutoResponseObject" },
{ name = "SQL_COUNTER", class_name = "SqlCounter" },
{ name = "SQL_ITERATOR", class_name = "SqlIterator" },
{ name = "MY_CLASS", class_name = "MyClass" },
{ name = "ECHO_CONTAINER", class_name = "EchoContainer" },
{ name = "COUNTER", class_name = "Counter" },
{ name = "ALARM", class_name = "AlarmObject" },
{ name = "PUT_RAW_TEST_OBJECT", class_name = "PutRawTestObject" },
{ name = "AUTO", class_name = "AutoResponseObject" },
{ name = "SQL_COUNTER", class_name = "SqlCounter" },
{ name = "SQL_ITERATOR", class_name = "SqlIterator" },
{ name = "MY_CLASS", class_name = "MyClass" },
{ name = "ECHO_CONTAINER", class_name = "EchoContainer" },
]

[[analytics_engine_datasets]]
@@ -84,3 +84,6 @@ secret_name = "secret-name"
class_name = "EchoContainer"
image = "./container-echo/Dockerfile"
max_instances = 1

[ai]
binding = "AI_TEST"
105 changes: 97 additions & 8 deletions worker/src/ai.rs
@@ -1,11 +1,21 @@
use std::marker::PhantomData;
use std::pin::Pin;
use std::task::{self, Poll};

use crate::{env::EnvBinding, send::SendFuture};
use crate::{Error, Result};
use serde::de::DeserializeOwned;
use serde::Serialize;
use futures_util::io::{BufReader, Lines};
use futures_util::{ready, AsyncBufReadExt as _, Stream, StreamExt as _};
use js_sys::Reflect;
use pin_project::pin_project;
use serde::{de::DeserializeOwned, Serialize};
use wasm_bindgen::{JsCast, JsValue};
use wasm_bindgen_futures::JsFuture;
use wasm_streams::readable::IntoAsyncRead;
use worker_sys::Ai as AiSys;

pub mod models;

/// Enables access to Workers AI functionality.
#[derive(Debug)]
pub struct Ai(AiSys);
@@ -14,20 +14,27 @@ impl Ai {
    /// Execute a Workers AI operation using the specified model.
    /// Various forms of the input are documented in the Workers
    /// AI documentation.
    pub async fn run<T: Serialize, U: DeserializeOwned>(
        &self,
        model: impl AsRef<str>,
        input: T,
    ) -> Result<U> {
    pub async fn run<M: Model>(&self, input: M::Input) -> Result<M::Output> {
        let fut = SendFuture::new(JsFuture::from(
            self.0
                .run(model.as_ref(), serde_wasm_bindgen::to_value(&input)?),
                .run(M::MODEL_NAME, serde_wasm_bindgen::to_value(&input)?),
        ));
        match fut.await {
            Ok(output) => Ok(serde_wasm_bindgen::from_value(output)?),
            Err(err) => Err(Error::from(err)),
        }
    }

    pub async fn run_streaming<M: StreamableModel>(&self, input: M::Input) -> Result<AiStream<M>> {
        let input = serde_wasm_bindgen::to_value(&input)?;
        Reflect::set(&input, &JsValue::from_str("stream"), &JsValue::TRUE)?;

        let fut = SendFuture::new(JsFuture::from(self.0.run(M::MODEL_NAME, input)));
        let raw_stream = fut.await?.dyn_into::<web_sys::ReadableStream>()?;
        let stream = wasm_streams::ReadableStream::from_raw(raw_stream).into_async_read();

        Ok(AiStream::new(stream))
    }
}

unsafe impl Sync for Ai {}
@@ -82,3 +99,75 @@ impl EnvBinding for Ai {
}
}
}

pub trait Model: 'static {
    const MODEL_NAME: &str;
    type Input: Serialize;
    type Output: DeserializeOwned;
}

pub trait StreamableModel: Model {}

#[derive(Debug)]
#[pin_project]
pub struct AiStream<T: StreamableModel> {
    #[pin]
    inner: Lines<BufReader<IntoAsyncRead<'static>>>,
    phantom: PhantomData<T>,
}

impl<T: StreamableModel> AiStream<T> {
    pub fn new(stream: IntoAsyncRead<'static>) -> Self {
        Self {
            inner: BufReader::new(stream).lines(),
            phantom: PhantomData,
        }
    }
}

impl<T: StreamableModel> Stream for AiStream<T> {
    type Item = Result<T::Output>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Option<Self::Item>> {
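        // The Workers AI stream is SSE-shaped: each event arrives as a
        // "data: <json>" line, events are separated by blank lines, and the
        // stream terminates with "data: [DONE]". The match below skips one
        // blank separator line and yields each JSON payload as a T::Output.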
        let mut this = self.project();
        let string = match ready!(this.inner.poll_next_unpin(cx)) {
            Some(item) => match item {
                Ok(item) => {
                    if item.is_empty() {
                        match ready!(this.inner.poll_next_unpin(cx)) {
                            Some(item) => match item {
                                Ok(item) => item,
                                Err(err) => {
                                    return Poll::Ready(Some(Err(err.into())));
                                }
                            },
                            None => {
                                return Poll::Ready(None);
                            }
                        }
                    } else {
                        item
                    }
                }
                Err(err) => {
                    return Poll::Ready(Some(Err(err.into())));
                }
            },
            None => {
                return Poll::Ready(None);
            }
        };

        let string = if let Some(string) = string.strip_prefix("data: ") {
            string
        } else {
            string.as_str()
        };

        if string == "[DONE]" {
            return Poll::Ready(None);
        }

        Poll::Ready(Some(Ok(serde_json::from_str(string)?)))
    }
}
102 changes: 102 additions & 0 deletions worker/src/ai/models.rs
@@ -0,0 +1,102 @@
use serde::{Deserialize, Serialize};

use self::scoped_chat::RoleScopedChatInput;

pub mod llama_4_scout_17b_16e_instruct;

pub mod scoped_chat {
    use serde::Serialize;

    // `rename_all` lowercases the named variants; marking only `Any` as
    // untagged lets an arbitrary role string serialize as-is. (Untagged on
    // the whole enum would serialize the unit variants as null.)
    #[derive(Default, Serialize)]
    #[serde(rename_all = "lowercase")]
    pub enum Role {
        #[default]
        User,
        Assistant,
        System,
        Tool,
        #[serde(untagged)]
        Any(String),
    }

    #[derive(Default, Serialize)]
    pub struct RoleScopedChatInput {
        pub role: Role,
        pub content: String,
        pub name: Option<String>,
    }

    pub fn user(content: &str) -> RoleScopedChatInput {
        RoleScopedChatInput {
            role: Role::User,
            content: content.to_owned(),
            name: None,
        }
    }

    pub fn assistant(content: &str) -> RoleScopedChatInput {
        RoleScopedChatInput {
            role: Role::Assistant,
            content: content.to_owned(),
            name: None,
        }
    }

    pub fn system(content: &str) -> RoleScopedChatInput {
        RoleScopedChatInput {
            role: Role::System,
            content: content.to_owned(),
            name: None,
        }
    }

    pub fn tool(content: &str) -> RoleScopedChatInput {
        RoleScopedChatInput {
            role: Role::Tool,
            content: content.to_owned(),
            name: None,
        }
    }
}
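
The helpers make assembling a chat fairly terse; for example (sketch, assuming the module path worker::ai::models::scoped_chat):

use worker::ai::models::scoped_chat::{system, user};

let messages = vec![
    system("You are a terse assistant."),
    user("What is the answer to life the universe and everything?"),
];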

/// Default input object for text-generating AI models.
///
/// The type implements [`Default`], so you do not have to specify all fields:
///
/// ```
/// # use worker::ai::models::AiTextGenerationInput;
/// AiTextGenerationInput {
///     prompt: Some("What is the answer to life the universe and everything?".to_owned()),
///     ..Default::default()
/// };
/// ```
// TODO add response_json, tool calling and function calling to the input
#[derive(Default, Serialize)]
pub struct AiTextGenerationInput {
    pub prompt: Option<String>,
    pub raw: Option<bool>,
    pub max_tokens: Option<u32>,
    pub temperature: Option<f32>,
    pub top_p: Option<f32>,
    pub top_k: Option<u32>,
    pub seed: Option<u32>,
    pub repetition_penalty: Option<f32>,
    pub frequency_penalty: Option<f32>,
    pub presence_penalty: Option<f32>,
    pub messages: Option<Vec<RoleScopedChatInput>>,
}

#[derive(Default, Deserialize)]
pub struct UsageTags {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}

/// Default output object for text-generating AI models.
// TODO add tool call output support
#[derive(Default, Deserialize)]
pub struct AiTextGenerationOutput {
    pub response: Option<String>,
    pub usage: Option<UsageTags>,
}