Skip to content

[Example]: Use serde, schemars to make structure output code easy #301

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions examples/structured-outputs-schemars/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
[package]
name = "structured-outputs-schemars"
version = "0.1.0"
edition = "2021"
publish = false

[dependencies]
async-openai = {path = "../../async-openai"}
serde_json = "1.0.127"
tokio = { version = "1.39.3", features = ["full"] }
schemars = "0.8.21"
serde = "1.0.130"
39 changes: 39 additions & 0 deletions examples/structured-outputs-schemars/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
## Intro

Based on the 'Chain of thought' example from https://platform.openai.com/docs/guides/structured-outputs/introduction?lang=curl

Using `schemars` and `serde` reduces coding effort.

## Output

```
cargo run | jq .
```

```
{
"final_answer": "x = -3.75",
"steps": [
{
"explanation": "Start with the equation given in the problem.",
"output": "8x + 7 = -23"
},
{
"explanation": "Subtract 7 from both sides to begin isolating the term with the variable x.",
"output": "8x + 7 - 7 = -23 - 7"
},
{
"explanation": "Simplify both sides. On the left-hand side, 7 - 7 equals 0, cancelling out, leaving the equation as follows.",
"output": "8x = -30"
},
{
"explanation": "Now, divide both sides by 8 to fully isolate x.",
"output": "8x/8 = -30/8"
},
{
"explanation": "Simplify the right side by performing the division. -30 divided by 8 is -3.75.",
"output": "x = -3.75"
}
]
}
```
97 changes: 97 additions & 0 deletions examples/structured-outputs-schemars/src/main.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
use std::error::Error;

use async_openai::{
types::{
ChatCompletionRequestMessage, ChatCompletionRequestSystemMessage,
ChatCompletionRequestUserMessage, CreateChatCompletionRequestArgs, ResponseFormat,
ResponseFormatJsonSchema,
},
Client,
};
use schemars::{schema_for, JsonSchema};
use serde::{de::DeserializeOwned, Deserialize, Serialize};

#[derive(Debug, Serialize, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields)]
pub struct Step {
pub output: String,
pub explanation: String,
}

#[derive(Debug, Serialize, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields)]
pub struct MathReasoningResponse {
pub final_answer: String,
pub steps: Vec<Step>,
}

pub async fn structured_output<T: serde::Serialize + DeserializeOwned + JsonSchema>(
messages: Vec<ChatCompletionRequestMessage>,
) -> Result<Option<T>, Box<dyn Error>> {
let schema = schema_for!(T);
let schema_value = serde_json::to_value(&schema)?;
let response_format = ResponseFormat::JsonSchema {
json_schema: ResponseFormatJsonSchema {
description: None,
name: "math_reasoning".into(),
schema: Some(schema_value),
strict: Some(true),
},
};

let request = CreateChatCompletionRequestArgs::default()
.max_tokens(512u32)
.model("gpt-4o-mini")
.messages(messages)
.response_format(response_format)
.build()?;

let client = Client::new();
let response = client.chat().create(request).await?;

for choice in response.choices {
if let Some(content) = choice.message.content {
return Ok(Some(serde_json::from_str::<T>(&content)?));
}
}

Ok(None)
}

#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
// Expecting output schema
// let schema = json!({
// "type": "object",
// "properties": {
// "steps": {
// "type": "array",
// "items": {
// "type": "object",
// "properties": {
// "explanation": { "type": "string" },
// "output": { "type": "string" }
// },
// "required": ["explanation", "output"],
// "additionalProperties": false
// }
// },
// "final_answer": { "type": "string" }
// },
// "required": ["steps", "final_answer"],
// "additionalProperties": false
// });
if let Some(response) = structured_output::<MathReasoningResponse>(vec![
ChatCompletionRequestSystemMessage::from(
"You are a helpful math tutor. Guide the user through the solution step by step.",
)
.into(),
ChatCompletionRequestUserMessage::from("how can I solve 8x + 7 = -23").into(),
])
.await?
{
println!("{}", serde_json::to_string(&response).unwrap());
}

Ok(())
}