Feat: LLM Generation Client Returns Json #570

Open
wants to merge 2 commits into main
1 change: 0 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default.

1 change: 0 additions & 1 deletion Cargo.toml
@@ -108,7 +108,6 @@ bytes = "1.10.1"
rand = "0.9.0"
indoc = "2.0.6"
owo-colors = "4.2.0"
json5 = "0.4.1"
aws-config = "1.6.2"
aws-sdk-s3 = "1.85.0"
aws-sdk-sqs = "1.67.0"
76 changes: 18 additions & 58 deletions src/llm/anthropic.rs
@@ -1,10 +1,7 @@
use crate::llm::{
LlmGenerateRequest, LlmGenerateResponse, LlmGenerationClient, LlmSpec, OutputFormat,
ToJsonSchemaOptions,
};
use anyhow::{bail, Context, Result};
use async_trait::async_trait;
use json5;
use crate::llm::{LlmGenerationClient, LlmSpec, LlmGenerateRequest, LlmGenerateResponse, ToJsonSchemaOptions, OutputFormat};
use anyhow::{Result, bail, Context};
use crate::llm::prompt_utils::STRICT_JSON_PROMPT;
use serde_json::Value;

use crate::api_bail;
@@ -48,9 +45,11 @@ impl LlmGenerationClient for Client {
});

// Add system prompt as top-level field if present (required)
if let Some(system) = request.system_prompt {
payload["system"] = serde_json::json!(system);
let mut system_prompt = request.system_prompt.unwrap_or_default();
if matches!(request.output_format, Some(OutputFormat::JsonSchema { .. })) {
system_prompt = format!("{STRICT_JSON_PROMPT}\n\n{system_prompt}").into();
}
payload["system"] = serde_json::json!(system_prompt);

// Extract schema from output_format, error if not JsonSchema
let schema = match request.output_format.as_ref() {
@@ -67,69 +66,30 @@

let encoded_api_key = encode(&self.api_key);

let resp = self
.client
let resp = self.client
.post(url)
.header("x-api-key", encoded_api_key.as_ref())
.header("anthropic-version", "2023-06-01")
.json(&payload)
.send()
.await
.context("HTTP error")?;
let mut resp_json: Value = resp.json().await.context("Invalid JSON")?;
let resp_json: Value = resp.json().await.context("Invalid JSON")?;
if let Some(error) = resp_json.get("error") {
bail!("Anthropic API error: {:?}", error);
}

// Debug print full response
// println!("Anthropic API full response: {resp_json:?}");

let resp_content = &resp_json["content"];
let tool_name = "report_result";
let mut extracted_json: Option<Value> = None;
if let Some(array) = resp_content.as_array() {
for item in array {
if item.get("type") == Some(&Value::String("tool_use".to_string()))
&& item.get("name") == Some(&Value::String(tool_name.to_string()))
{
if let Some(input) = item.get("input") {
extracted_json = Some(input.clone());
break;
}
}
}
}
let text = if let Some(json) = extracted_json {
// Try strict JSON serialization first
serde_json::to_string(&json)?
} else {
// Fallback: try text if no tool output found
match &mut resp_json["content"][0]["text"] {
Value::String(s) => {
// Try strict JSON parsing first
match serde_json::from_str::<serde_json::Value>(s) {
Ok(_) => std::mem::take(s),
Err(e) => {
// Try permissive json5 parsing as fallback
match json5::from_str::<serde_json::Value>(s) {
Ok(value) => {
println!("[Anthropic] Used permissive JSON5 parser for output");
serde_json::to_string(&value)?
},
Err(e2) => return Err(anyhow::anyhow!(format!("No structured tool output or text found in response, and permissive JSON5 parsing also failed: {e}; {e2}")))
}
}
}
}
_ => {
return Err(anyhow::anyhow!(
"No structured tool output or text found in response"
))
}
}
// Extract the text response
let text = match resp_json["content"][0]["text"].as_str() {
Some(s) => s.to_string(),
None => bail!("No text in response"),
};

Ok(LlmGenerateResponse { text })
// Try to parse as JSON
match serde_json::from_str::<serde_json::Value>(&text) {
Ok(val) => Ok(LlmGenerateResponse::Json(val)),
Err(_) => Ok(LlmGenerateResponse::Text(text)),
}
}

fn json_schema_options(&self) -> ToJsonSchemaOptions {
42 changes: 24 additions & 18 deletions src/llm/gemini.rs
@@ -1,12 +1,10 @@
use crate::api_bail;
use crate::llm::{
LlmGenerateRequest, LlmGenerateResponse, LlmGenerationClient, LlmSpec, OutputFormat,
ToJsonSchemaOptions,
};
use anyhow::{bail, Context, Result};
use async_trait::async_trait;
use crate::llm::{LlmGenerationClient, LlmSpec, LlmGenerateRequest, LlmGenerateResponse, ToJsonSchemaOptions, OutputFormat};
use anyhow::{Result, bail, Context};
use serde_json::Value;
use crate::api_bail;
use urlencoding::encode;
use crate::llm::prompt_utils::STRICT_JSON_PROMPT;

pub struct Client {
model: String,
@@ -60,11 +58,14 @@ impl LlmGenerationClient for Client {

// Prepare payload
let mut payload = serde_json::json!({ "contents": contents });
if let Some(system) = request.system_prompt {
payload["systemInstruction"] = serde_json::json!({
"parts": [ { "text": system } ]
});
}
if let Some(mut system) = request.system_prompt {
if matches!(request.output_format, Some(OutputFormat::JsonSchema { .. })) {
system = format!("{STRICT_JSON_PROMPT}\n\n{system}").into();
}
payload["systemInstruction"] = serde_json::json!({
"parts": [ { "text": system } ]
});
}

// If structured output is requested, add schema and responseMimeType
if let Some(OutputFormat::JsonSchema { schema, .. }) = &request.output_format {
@@ -79,13 +80,10 @@
let api_key = &self.api_key;
let url = format!(
"https://generativelanguage.googleapis.com/v1beta/models/{}:generateContent?key={}",
encode(&self.model),
encode(api_key)
encode(&self.model), encode(api_key)
);

let resp = self
.client
.post(&url)
let resp = self.client.post(&url)
.json(&payload)
.send()
.await
@@ -102,7 +100,15 @@
_ => bail!("No text in response"),
};

Ok(LlmGenerateResponse { text })
// If output_format is JsonSchema, try to parse as JSON
if let Some(OutputFormat::JsonSchema { .. }) = request.output_format {
match serde_json::from_str::<serde_json::Value>(&text) {
Ok(val) => Ok(LlmGenerateResponse::Json(val)),
Err(_) => Ok(LlmGenerateResponse::Text(text)),
}
} else {
Ok(LlmGenerateResponse::Text(text))
}
}

fn json_schema_options(&self) -> ToJsonSchemaOptions {
@@ -113,4 +119,4 @@ impl LlmGenerationClient for Client {
top_level_must_be_object: true,
}
}
}
}
8 changes: 5 additions & 3 deletions src/llm/mod.rs
@@ -22,7 +22,7 @@ pub struct LlmSpec {
model: String,
}

#[derive(Debug)]
#[derive(Debug, Clone)]
pub enum OutputFormat<'a> {
JsonSchema {
name: Cow<'a, str>,
@@ -38,8 +38,9 @@ pub struct LlmGenerateRequest<'a> {
}

#[derive(Debug)]
pub struct LlmGenerateResponse {
pub text: String,
pub enum LlmGenerateResponse {
Json(serde_json::Value),
Text(String),
}

#[async_trait]
@@ -56,6 +57,7 @@ mod anthropic;
mod gemini;
mod ollama;
mod openai;
mod prompt_utils;

pub async fn new_llm_generation_client(spec: LlmSpec) -> Result<Box<dyn LlmGenerationClient>> {
let client = match spec.api_type {
19 changes: 15 additions & 4 deletions src/llm/ollama.rs
@@ -1,6 +1,7 @@
use super::LlmGenerationClient;
use anyhow::Result;
use async_trait::async_trait;
use crate::llm::prompt_utils::STRICT_JSON_PROMPT;
use schemars::schema::SchemaObject;
use serde::{Deserialize, Serialize};

@@ -52,6 +53,10 @@ impl LlmGenerationClient for Client {
&self,
request: super::LlmGenerateRequest<'req>,
) -> Result<super::LlmGenerateResponse> {
let mut system_prompt = request.system_prompt.unwrap_or_default();
if matches!(request.output_format, Some(super::OutputFormat::JsonSchema { .. })) {
system_prompt = format!("{STRICT_JSON_PROMPT}\n\n{system_prompt}").into();
}
let req = OllamaRequest {
model: &self.model,
prompt: request.user_prompt.as_ref(),
@@ -60,7 +65,7 @@
OllamaFormat::JsonSchema(schema.as_ref())
},
),
system: request.system_prompt.as_ref().map(|s| s.as_ref()),
system: Some(&system_prompt),
stream: Some(false),
};
let res = self
@@ -71,9 +76,15 @@
.await?;
let body = res.text().await?;
let json: OllamaResponse = serde_json::from_str(&body)?;
Ok(super::LlmGenerateResponse {
text: json.response,
})
// If output_format is JsonSchema, try to parse the response as JSON
if let Some(super::OutputFormat::JsonSchema { .. }) = request.output_format {
match serde_json::from_str::<serde_json::Value>(&json.response) {
Ok(val) => Ok(super::LlmGenerateResponse::Json(val)),
Err(_) => Ok(super::LlmGenerateResponse::Text(json.response)),

Member:

I think we want to just pop up the error in this case. When the caller wants JSON, returning a text will leave more burden on the caller side: more cases need to be handled, and when they perform exactly the same way to parse (using serde_json::from_str()) they will fail again.
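
A minimal sketch of what that suggestion could look like in ollama.rs (illustrative only, not taken from this PR; assumes the surrounding code and the anyhow dependency shown above):

// Sketch: when the caller asked for JSON, surface the parse error
// instead of silently downgrading to a Text response.
if let Some(super::OutputFormat::JsonSchema { .. }) = request.output_format {
    let val: serde_json::Value = serde_json::from_str(&json.response)
        .map_err(|e| anyhow::anyhow!("LLM did not return valid JSON: {e}"))?;
    Ok(super::LlmGenerateResponse::Json(val))
} else {
    Ok(super::LlmGenerateResponse::Text(json.response))
}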

}
} else {
Ok(super::LlmGenerateResponse::Text(json.response))
}
}

fn json_schema_options(&self) -> super::ToJsonSchemaOptions {
16 changes: 13 additions & 3 deletions src/llm/openai.rs
@@ -64,8 +64,10 @@ impl LlmGenerationClient for Client {
},
));

// Save output_format before it is moved.
let output_format = request.output_format.clone();

Member:

I think we only need to save a boolean here to indicate if JSON is expected. No need to clone the entire thing.
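
A minimal sketch of that alternative (illustrative only, not taken from this PR):

// Sketch: keep only a boolean flag instead of cloning the whole
// OutputFormat before `request` is moved into the OpenAI request.
let expects_json = matches!(
    request.output_format,
    Some(super::OutputFormat::JsonSchema { .. })
);

The output handling at the end of the function can then branch on expects_json instead of matching on the saved clone.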

// Create the chat completion request
let request = CreateChatCompletionRequest {
let openai_request = CreateChatCompletionRequest {
model: self.model.clone(),
messages,
response_format: match request.output_format {
@@ -85,7 +87,7 @@
};

// Send request and get response
let response = self.client.chat().create(request).await?;
let response = self.client.chat().create(openai_request).await?;

// Extract the response text from the first choice
let text = response
@@ -95,7 +97,15 @@
.and_then(|choice| choice.message.content)
.ok_or_else(|| anyhow::anyhow!("No response from OpenAI"))?;

Ok(super::LlmGenerateResponse { text })
// If output_format is JsonSchema, try to parse as JSON
if let Some(super::OutputFormat::JsonSchema { .. }) = output_format {
match serde_json::from_str::<serde_json::Value>(&text) {
Ok(val) => Ok(super::LlmGenerateResponse::Json(val)),
Err(_) => Ok(super::LlmGenerateResponse::Text(text)),
}
} else {
Ok(super::LlmGenerateResponse::Text(text))
}
}

fn json_schema_options(&self) -> super::ToJsonSchemaOptions {
4 changes: 4 additions & 0 deletions src/llm/prompt_utils.rs
@@ -0,0 +1,4 @@
// Shared prompt utilities for LLM clients
// Only import this in clients that require strict JSON output instructions (e.g., Anthropic, Gemini, Ollama)

pub const STRICT_JSON_PROMPT: &str = "IMPORTANT: Output ONLY valid JSON that matches the schema. Do NOT say anything else. Do NOT explain. Do NOT preface. Do NOT add comments. If you cannot answer, output an empty JSON object: {}.";
9 changes: 6 additions & 3 deletions src/ops/functions/extract_by_llm.rs
@@ -1,7 +1,7 @@
use crate::prelude::*;

use crate::llm::{
new_llm_generation_client, LlmGenerateRequest, LlmGenerationClient, LlmSpec, OutputFormat,
new_llm_generation_client, LlmGenerateRequest, LlmGenerationClient, LlmGenerateResponse, LlmSpec, OutputFormat,
};
use crate::ops::sdk::*;
use base::json_schema::build_json_schema;
@@ -83,7 +83,10 @@ impl SimpleFunctionExecutor for Executor {
}),
};
let res = self.client.generate(req).await?;
let json_value: serde_json::Value = serde_json::from_str(res.text.as_str())?;
let json_value = match res {
LlmGenerateResponse::Json(val) => val,
LlmGenerateResponse::Text(text) => serde_json::from_str(&text)?,
};
let value = self.value_extractor.extract_value(json_value)?;
Ok(value)
}
@@ -124,4 +127,4 @@
) -> Result<Box<dyn SimpleFunctionExecutor>> {
Ok(Box::new(Executor::new(spec, resolved_input_schema).await?))
}
}
}