feat: create trait definitions for model and streamable model #833
@@ -0,0 +1,65 @@

```rust
use serde::{Deserialize, Serialize};
use worker::{Env, Model, Request, Response, Result, StreamableModel};

use crate::SomeSharedData;

pub struct Llama4Scout17b16eInstruct;

#[derive(Serialize)]
pub struct DefaultTextGenerationInput {
    pub prompt: String,
}

#[derive(Deserialize)]
pub struct DefaultTextGenerationOutput {
    pub response: String,
}

impl From<DefaultTextGenerationOutput> for Vec<u8> {
    fn from(value: DefaultTextGenerationOutput) -> Self {
        value.response.into_bytes()
    }
}

impl Model for Llama4Scout17b16eInstruct {
    const MODEL_NAME: &str = "@cf/meta/llama-4-scout-17b-16e-instruct";

    type Input = DefaultTextGenerationInput;

    type Output = DefaultTextGenerationOutput;
}

impl StreamableModel for Llama4Scout17b16eInstruct {}

const AI_TEST: &str = "AI_TEST";

#[worker::send]
pub async fn simple_ai_text_generation(
    _: Request,
    env: Env,
    _data: SomeSharedData,
) -> Result<Response> {
    let ai = env
        .ai(AI_TEST)?
        .run::<Llama4Scout17b16eInstruct>(DefaultTextGenerationInput {
            prompt: "What is the answer to life the universe and everything?".to_owned(),
        })
```
Comment on lines +44 to +46:

Instead of …, something along the lines of:

```rust
pub struct AiTextGenerationInput {
    pub prompt: Option<String>,
    pub raw: Option<bool>,
    pub stream: Option<bool>,
    pub max_tokens: Option<u32>,
    pub temperature: Option<f32>,
    pub top_p: Option<f32>,
    pub top_k: Option<u32>,
    pub seed: Option<u32>,
    pub repetition_penalty: Option<f32>,
    pub frequency_penalty: Option<f32>,
    pub presence_penalty: Option<f32>,
    pub messages: Option<Vec<RoleScopedChatInput>>,
    pub response_format: Option<AiTextGenerationResponseFormat>,
    pub tools: Option<serde_json::Value>, // For flexible union type
    pub functions: Option<Vec<AiTextGenerationFunctionsInput>>,
}
```
We don't even need all fields initially. And similarly for https://workers-types.pages.dev/#AiTextGenerationOutput?

---

Sure! That works. I was trying to match the type definitions emitted by wrangler types, but that does seem more legible. Also, would it make sense to mark the fields as private and use a builder pattern to construct the input?

---

Also, for tools I have a separate PR I would like to make which is based on this one, which handles tools as a new subtrait of `Model`. The gist of it is that you define tools like the following:

```rust
type ToolRun = Rc<
    dyn Fn(
        &Env,
        Option<serde_json::Map<String, serde_json::Value>>,
    ) -> Pin<Box<dyn Future<Output = Result<String, Error>>>>,
>;

#[derive(Clone)]
pub struct Tool {
    pub name: &'static str,
    pub description: &'static str,
    pub run: ToolRun,
}
```

and then we would call them from `Ai` with the following method:

```rust
impl Ai {
    pub async fn run_with_tools<M: ToolsModel>(&self, input: &M::Input, tools: &[Tool]) -> Result<M::Output, Error> {
        ...
    }
}
```
This would require a new trait called `ToolsModel`, like this:

```rust
pub trait ToolsModel: Model {}
```

The problem is I'm not sure this approach is generic enough to work for everyone. For instance, I know that when working with axum you are normally expected to wrap `Env` in an `Arc`, which means you no longer have access to this functionality.

---

As much as possible it would be good to follow wasm-bindgen semantics. I know we wrap all types currently, but I'm trying to move to a model where, when possible, we use worker-sys types as the high-level types. So the wasm-bindgen pattern would be to define an imported type and define all properties as getters and setters:

```rust
#[wasm_bindgen]
extern "C" {
    #[wasm_bindgen(extends = ::js_sys::Object, js_name = AiTextGenerationInput)]
    #[derive(Debug, Clone, PartialEq, Eq)]
    #[doc = "Ai Text Generation"]
    pub type AiTextGenerationInput;

    #[doc = "Prompt for generation"]
    #[wasm_bindgen(method, setter = "prompt")]
    pub fn set_prompt(this: &AiTextGenerationInput, prompt: Option<String>);

    #[wasm_bindgen(method, getter = "prompt")]
    pub fn get_prompt(this: &AiTextGenerationInput) -> Option<String>;
}
```
Then functions would just take …

---

Right, gotcha! I didn't realise you were referring to types in JS, thanks for clarifying.

---

I just looked into AiTextGeneration in the docs you provided, and it seems AiTextGeneration is a TypeScript type, not a class, so there is no imported type that we can bind to in the js namespace.

---

That should be fine - it's a structural type, not a formal import.

---

Ah sorry, I just saw your comment now. I just made a commit where I kept the AiTextGenerationInput struct in worker rather than binding to JS; if you don't like the API I can change it.

---

I'm strongly trying to formalize our conventions, so let me think about this some more and follow up soon. My initial sense is that serde perf is equivalent, so we should prefer the direct interpretation convention, but I want to ensure I've got that right.
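The builder-pattern question raised in this thread could be sketched roughly as follows. This is an illustration only, not part of the PR: the field set mirrors a subset of the proposed `AiTextGenerationInput`, and the builder type name is hypothetical.

```rust
// Hypothetical builder sketch: fields are private, so the only way to
// construct the input is through the builder's chained setters.
#[derive(Default, Debug)]
pub struct AiTextGenerationInput {
    prompt: Option<String>,
    max_tokens: Option<u32>,
    temperature: Option<f32>,
}

#[derive(Default)]
pub struct AiTextGenerationInputBuilder {
    inner: AiTextGenerationInput,
}

impl AiTextGenerationInputBuilder {
    pub fn new() -> Self {
        Self::default()
    }

    // Each setter takes `self` by value and returns it, allowing chaining.
    pub fn prompt(mut self, prompt: impl Into<String>) -> Self {
        self.inner.prompt = Some(prompt.into());
        self
    }

    pub fn max_tokens(mut self, max_tokens: u32) -> Self {
        self.inner.max_tokens = Some(max_tokens);
        self
    }

    pub fn temperature(mut self, temperature: f32) -> Self {
        self.inner.temperature = Some(temperature);
        self
    }

    pub fn build(self) -> AiTextGenerationInput {
        self.inner
    }
}

fn main() {
    let input = AiTextGenerationInputBuilder::new()
        .prompt("What is the answer to life the universe and everything?")
        .max_tokens(256)
        .temperature(0.7)
        .build();
    assert_eq!(input.max_tokens, Some(256));
    println!("{:?}", input);
}
```

Unset fields stay `None`, so the builder composes naturally with the `Default`-based construction discussed elsewhere in the thread.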
```rust
        .await?;
    Response::ok(ai.response)
}

#[worker::send]
pub async fn streaming_ai_text_generation(
    _: Request,
    env: Env,
    _data: SomeSharedData,
) -> Result<Response> {
    let stream = env
        .ai(AI_TEST)?
        .run_streaming::<Llama4Scout17b16eInstruct>(DefaultTextGenerationInput {
            prompt: "What is the answer to life the universe and everything?".to_owned(),
        })
        .await?;

    Response::from_stream(stream)
}
```
@@ -0,0 +1,11 @@

```typescript
import { describe, expect, test } from "vitest";
import { mf, mfUrl } from "./mf";

describe("ai", () => {
  test("simple and streaming text generation", async () => {
    const normal_response = await mf.dispatchFetch(`${mfUrl}/ai`);
    expect(normal_response.status).toBe(200);

    const streaming_response = await mf.dispatchFetch(`${mfUrl}/ai/streaming`);
    expect(streaming_response.status).toBe(200);
  });
});
```
@@ -0,0 +1,102 @@

```rust
use serde::{Deserialize, Serialize};

use crate::models::scoped_chat::RoleScopedChatInput;

pub mod llama_4_scout_17b_16e_instruct;

pub mod scoped_chat {
    use serde::Serialize;

    #[derive(Default)]
    pub enum Role {
        #[default]
        User,
        Assistant,
        System,
        Tool,
        Any(String),
    }

    // `#[serde(untagged)]` would serialize the unit variants as `null`, so
    // the role strings are written out with a manual `Serialize` impl.
    impl Serialize for Role {
        fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
            serializer.serialize_str(match self {
                Role::User => "user",
                Role::Assistant => "assistant",
                Role::System => "system",
                Role::Tool => "tool",
                Role::Any(role) => role,
            })
        }
    }

    #[derive(Default, Serialize)]
    pub struct RoleScopedChatInput {
        pub role: Role,
        pub content: String,
        pub name: Option<String>,
    }

    pub fn user(content: &str) -> RoleScopedChatInput {
        RoleScopedChatInput {
            role: Role::User,
            content: content.to_owned(),
            name: None,
        }
    }

    pub fn assistant(content: &str) -> RoleScopedChatInput {
        RoleScopedChatInput {
            role: Role::Assistant,
            content: content.to_owned(),
            name: None,
        }
    }

    pub fn system(content: &str) -> RoleScopedChatInput {
        RoleScopedChatInput {
            role: Role::System,
            content: content.to_owned(),
            name: None,
        }
    }

    pub fn tool(content: &str) -> RoleScopedChatInput {
        RoleScopedChatInput {
            role: Role::Tool,
            content: content.to_owned(),
            name: None,
        }
    }
}
```
```rust
/// Default input object for text-generating AI models.
///
/// The type implements `Default`, so you do not have to specify all fields,
/// like so:
///
/// ```no_run
/// # fn main() {
/// AiTextGenerationInput {
///     prompt: Some("What is the answer to life the universe and everything?".to_owned()),
///     ..Default::default()
/// }
/// # ;}
/// ```
// TODO add response_json, tool calling and function calling to the input
#[derive(Default, Serialize)]
pub struct AiTextGenerationInput {
    pub prompt: Option<String>,
    pub raw: Option<bool>,
    pub max_tokens: Option<u32>,
    pub temperature: Option<f32>,
    pub top_p: Option<f32>,
    pub top_k: Option<u32>,
    pub seed: Option<u32>,
    pub repetition_penalty: Option<f32>,
    pub frequency_penalty: Option<f32>,
    pub presence_penalty: Option<f32>,
    pub messages: Option<Vec<RoleScopedChatInput>>,
}
```
```rust
/// Token usage statistics returned alongside a generation.
#[derive(Default, Deserialize)]
pub struct UsageTags {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}

/// Default output object for text-generating AI models.
// TODO add tool call output support
#[derive(Default, Deserialize)]
pub struct AiTextGenerationOutput {
    pub response: Option<String>,
    pub usage: Option<UsageTags>,
}
```
---

Agreed, this seems like the right approach!