Skip to content

Commit 4b4f6f4

Browse files
committed
feat: Add AWS Bedrock LLM Support
This commit adds support for AWS Bedrock for LLM parsing. The implementation follows the approach of other LLM providers and uses the `BEDROCK_API_KEY` and `BEDROCK_REGION` environment variables for authentication. This resolves issue #1162.
1 parent bd1fde3 commit 4b4f6f4

File tree

4 files changed

+215
-1
lines changed

4 files changed

+215
-1
lines changed

python/cocoindex/llm.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ class LlmApiType(Enum):
1414
OPEN_ROUTER = "OpenRouter"
1515
VOYAGE = "Voyage"
1616
VLLM = "Vllm"
17+
BEDROCK = "Bedrock"
1718

1819

1920
@dataclass

python/cocoindex/tests/test_engine_value.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1064,6 +1064,25 @@ def test_full_roundtrip_vector_numeric_types() -> None:
10641064
validate_full_roundtrip(value_u64, Vector[np.uint64, Literal[3]])
10651065

10661066

1067+
def test_llm_api_type_bedrock() -> None:
    """Verify the Bedrock variant of LlmApiType exists and LlmSpec accepts it."""
    from cocoindex.llm import LlmApiType, LlmSpec

    model_id = "us.anthropic.claude-3-5-haiku-20241022-v1:0"

    # The enum must expose a BEDROCK member whose wire value is "Bedrock".
    assert hasattr(LlmApiType, "BEDROCK")
    assert LlmApiType.BEDROCK.value == "Bedrock"

    # An LlmSpec built for Bedrock keeps the model id and leaves the
    # optional address/api_config fields unset.
    spec = LlmSpec(api_type=LlmApiType.BEDROCK, model=model_id)
    assert spec.api_type is LlmApiType.BEDROCK
    assert spec.model == model_id
    assert spec.address is None
    assert spec.api_config is None
1084+
1085+
10671086
def test_full_roundtrip_vector_of_vector() -> None:
10681087
"""Test full roundtrip for vector of vector."""
10691088
value_f32 = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=np.float32)

src/llm/bedrock.rs

Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
use crate::prelude::*;
2+
use base64::prelude::*;
3+
4+
use crate::llm::{
5+
LlmGenerateRequest, LlmGenerateResponse, LlmGenerationClient, OutputFormat,
6+
ToJsonSchemaOptions, detect_image_mime_type,
7+
};
8+
use anyhow::Context;
9+
use urlencoding::encode;
10+
11+
/// Client for the AWS Bedrock Runtime `converse` API.
pub struct Client {
    // Bearer token, read from the `BEDROCK_API_KEY` environment variable.
    api_key: String,
    // AWS region used to build the endpoint host; defaults to "us-east-1".
    region: String,
    // Reused reqwest client so connections are pooled across requests.
    client: reqwest::Client,
}
16+
17+
impl Client {
18+
pub async fn new(address: Option<String>) -> Result<Self> {
19+
if address.is_some() {
20+
api_bail!("Bedrock doesn't support custom API address");
21+
}
22+
23+
let api_key = match std::env::var("BEDROCK_API_KEY") {
24+
Ok(val) => val,
25+
Err(_) => api_bail!("BEDROCK_API_KEY environment variable must be set"),
26+
};
27+
28+
// Default to us-east-1 if no region specified
29+
let region = std::env::var("BEDROCK_REGION").unwrap_or_else(|_| "us-east-1".to_string());
30+
31+
Ok(Self {
32+
api_key,
33+
region,
34+
client: reqwest::Client::new(),
35+
})
36+
}
37+
}
38+
39+
#[async_trait]
40+
impl LlmGenerationClient for Client {
41+
async fn generate<'req>(
42+
&self,
43+
request: LlmGenerateRequest<'req>,
44+
) -> Result<LlmGenerateResponse> {
45+
let mut user_content_parts: Vec<serde_json::Value> = Vec::new();
46+
47+
// Add image part if present
48+
if let Some(image_bytes) = &request.image {
49+
let base64_image = BASE64_STANDARD.encode(image_bytes.as_ref());
50+
let mime_type = detect_image_mime_type(image_bytes.as_ref())?;
51+
user_content_parts.push(serde_json::json!({
52+
"image": {
53+
"format": mime_type.split('/').nth(1).unwrap_or("png"),
54+
"source": {
55+
"bytes": base64_image,
56+
}
57+
}
58+
}));
59+
}
60+
61+
// Add text part
62+
user_content_parts.push(serde_json::json!({
63+
"text": request.user_prompt
64+
}));
65+
66+
let messages = vec![serde_json::json!({
67+
"role": "user",
68+
"content": user_content_parts
69+
})];
70+
71+
let mut payload = serde_json::json!({
72+
"messages": messages,
73+
"inferenceConfig": {
74+
"maxTokens": 4096
75+
}
76+
});
77+
78+
// Add system prompt if present
79+
if let Some(system) = request.system_prompt {
80+
payload["system"] = serde_json::json!([{
81+
"text": system
82+
}]);
83+
}
84+
85+
// Handle structured output using tool schema
86+
if let Some(OutputFormat::JsonSchema { schema, name }) = request.output_format.as_ref() {
87+
let schema_json = serde_json::to_value(schema)?;
88+
payload["toolConfig"] = serde_json::json!({
89+
"tools": [{
90+
"toolSpec": {
91+
"name": name,
92+
"description": format!("Extract structured data according to the schema"),
93+
"inputSchema": {
94+
"json": schema_json
95+
}
96+
}
97+
}]
98+
});
99+
}
100+
101+
// Construct the Bedrock Runtime API URL
102+
let url = format!(
103+
"https://bedrock-runtime.{}.amazonaws.com/model/{}/converse",
104+
self.region, request.model
105+
);
106+
107+
let encoded_api_key = encode(&self.api_key);
108+
109+
let resp = retryable::run(
110+
|| async {
111+
self.client
112+
.post(&url)
113+
.header(
114+
"Authorization",
115+
format!("Bearer {}", encoded_api_key.as_ref()),
116+
)
117+
.header("Content-Type", "application/json")
118+
.json(&payload)
119+
.send()
120+
.await?
121+
.error_for_status()
122+
},
123+
&retryable::HEAVY_LOADED_OPTIONS,
124+
)
125+
.await
126+
.context("Bedrock API error")?;
127+
128+
let resp_json: serde_json::Value = resp.json().await.context("Invalid JSON")?;
129+
130+
// Check for errors in the response
131+
if let Some(error) = resp_json.get("error") {
132+
bail!("Bedrock API error: {:?}", error);
133+
}
134+
135+
// Debug print full response (uncomment for debugging)
136+
// println!("Bedrock API full response: {resp_json:?}");
137+
138+
// Extract the response content
139+
let output = &resp_json["output"];
140+
let message = &output["message"];
141+
let content = &message["content"];
142+
143+
let text = if let Some(content_array) = content.as_array() {
144+
// Look for tool use first (structured output)
145+
let mut extracted_json: Option<serde_json::Value> = None;
146+
for item in content_array {
147+
if let Some(tool_use) = item.get("toolUse") {
148+
if let Some(input) = tool_use.get("input") {
149+
extracted_json = Some(input.clone());
150+
break;
151+
}
152+
}
153+
}
154+
155+
if let Some(json) = extracted_json {
156+
// Return the structured output as JSON
157+
serde_json::to_string(&json)?
158+
} else {
159+
// Fall back to text content
160+
let mut text_parts = Vec::new();
161+
for item in content_array {
162+
if let Some(text) = item.get("text") {
163+
if let Some(text_str) = text.as_str() {
164+
text_parts.push(text_str);
165+
}
166+
}
167+
}
168+
text_parts.join("")
169+
}
170+
} else {
171+
return Err(anyhow::anyhow!("No content found in Bedrock response"));
172+
};
173+
174+
Ok(LlmGenerateResponse { text })
175+
}
176+
177+
fn json_schema_options(&self) -> ToJsonSchemaOptions {
178+
ToJsonSchemaOptions {
179+
fields_always_required: false,
180+
supports_format: false,
181+
extract_descriptions: false,
182+
top_level_must_be_object: true,
183+
}
184+
}
185+
}

src/llm/mod.rs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ pub enum LlmApiType {
1818
Voyage,
1919
Vllm,
2020
VertexAi,
21+
Bedrock,
2122
}
2223

2324
#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -106,6 +107,7 @@ pub trait LlmEmbeddingClient: Send + Sync {
106107
}
107108

108109
mod anthropic;
110+
mod bedrock;
109111
mod gemini;
110112
mod litellm;
111113
mod ollama;
@@ -134,6 +136,9 @@ pub async fn new_llm_generation_client(
134136
LlmApiType::Anthropic => {
135137
Box::new(anthropic::Client::new(address).await?) as Box<dyn LlmGenerationClient>
136138
}
139+
LlmApiType::Bedrock => {
140+
Box::new(bedrock::Client::new(address).await?) as Box<dyn LlmGenerationClient>
141+
}
137142
LlmApiType::LiteLlm => {
138143
Box::new(litellm::Client::new_litellm(address).await?) as Box<dyn LlmGenerationClient>
139144
}
@@ -169,7 +174,11 @@ pub async fn new_llm_embedding_client(
169174
}
170175
LlmApiType::VertexAi => Box::new(gemini::VertexAiClient::new(address, api_config).await?)
171176
as Box<dyn LlmEmbeddingClient>,
172-
LlmApiType::OpenRouter | LlmApiType::LiteLlm | LlmApiType::Vllm | LlmApiType::Anthropic => {
177+
LlmApiType::OpenRouter
178+
| LlmApiType::LiteLlm
179+
| LlmApiType::Vllm
180+
| LlmApiType::Anthropic
181+
| LlmApiType::Bedrock => {
173182
api_bail!("Embedding is not supported for API type {:?}", api_type)
174183
}
175184
};

0 commit comments

Comments
 (0)