feat: Add support for Mistral API in the Bedrock module
- Implemented MistralClient for interacting with the Mistral API in the Bedrock service
- Added MistralRequestBuilder and MistralResponse models
- Added examples for using the Mistral API
- Updated the README with information about the Mistral API support
- Bumped the version to 0.1.7
raphaelmansuy committed Apr 5, 2024
1 parent 49b0ed9 commit 92867a8
Showing 16 changed files with 566 additions and 305 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock


2 changes: 1 addition & 1 deletion Cargo.toml
@@ -1,6 +1,6 @@
[package]
name = "hiramu"
version = "0.1.6"
version = "0.1.7"
edition = "2021"
license = "MIT"
description = "A Rust AI Engineering Toolbox"
File renamed without changes.
120 changes: 84 additions & 36 deletions README.md
@@ -1,37 +1,93 @@
# Hiramu

Hiramu is a powerful and flexible Rust library that provides a high-level interface for interacting with various AI models and APIs, including Ollama and Bedrock. It simplifies the process of generating text, engaging in chat conversations, and working with different AI models.
Hiramu is a powerful and flexible Rust library that provides a high-level interface for interacting with various AI models and APIs, including Ollama and AWS Bedrock.

It simplifies the process of generating text, engaging in chat conversations, and working with different AI models.

## Features

- Easy-to-use interfaces for generating text and engaging in chat conversations with AI models
- Support for Ollama and Bedrock AI services
- Convenient interfaces for Claude and Mistral on AWS Bedrock
- Asynchronous and streaming responses for efficient handling of large outputs
- Customizable options for fine-tuning the behavior of AI models
- Comprehensive error handling and informative error messages
- Well-documented code with examples and explanations


## Getting Started

To start using Hiramu in your Rust project, add the following to your `Cargo.toml` file:

```toml
[dependencies]
hiramu = "0.1.X"
hiramu = "0.1.7"
```

Then, import the necessary modules and types in your Rust code:
## Examples

### Generating Text with Mistral

```rust
use hiramu::ollama::ollama_client::OllamaClient;
use hiramu::ollama::model::{GenerateRequest, GenerateRequestBuilder, GenerateResponse};
use hiramu::bedrock::bedrock_client::{BedrockClient, BedrockClientOptions};
use hiramu::bedrock::models::claude::claude_client::{ClaudeClient, ClaudeOptions};
use hiramu::bedrock::models::claude::claude_request_message::{ChatOptions, ConversationRequest, Message};
use hiramu::bedrock::models::mistral::mistral_client::{MistralClient, MistralOptions};
use hiramu::bedrock::models::mistral::mistral_request_message::MistralRequestBuilder;
use hiramu::bedrock::model_info::{ModelInfo, ModelName};

#[tokio::main]
async fn main() {
    let mistral_options = MistralOptions::new()
        .profile_name("bedrock")
        .region("us-west-2");

    let client = MistralClient::new(mistral_options).await;

    let request = MistralRequestBuilder::new("<s>[INST] What is the capital of France?[/INST]".to_string())
        .max_tokens(200)
        .temperature(0.8)
        .build();

    let model_id = ModelInfo::from_model_name(ModelName::MistralMixtral8X7BInstruct0x);
    let response = client.generate(model_id, &request).await.unwrap();

    println!("Response: {:?}", response.outputs[0].text);
}
```

## Examples
### Streaming Text Generation with Mistral

```rust
use futures::stream::StreamExt;
use hiramu::bedrock::models::mistral::mistral_client::{MistralClient, MistralOptions};
use hiramu::bedrock::models::mistral::mistral_request_message::MistralRequestBuilder;
use hiramu::bedrock::model_info::{ModelInfo, ModelName};

#[tokio::main]
async fn main() {
    let mistral_options = MistralOptions::new()
        .profile_name("bedrock")
        .region("us-west-2");

    let client = MistralClient::new(mistral_options).await;

    let request = MistralRequestBuilder::new("<s>[INST] What is the capital of France?[/INST]".to_string())
        .max_tokens(200)
        .temperature(0.8)
        .build();

    let model_id = ModelInfo::from_model_name(ModelName::MistralMixtral8X7BInstruct0x);
    let mut stream = client.generate_with_stream(model_id, &request).await.unwrap();

    while let Some(result) = stream.next().await {
        match result {
            Ok(response) => {
                println!("Response: {:?}", response.outputs[0].text);
            }
            Err(err) => {
                eprintln!("Error: {:?}", err);
            }
        }
    }
}
```

### Generating Text with Ollama

@@ -63,6 +119,7 @@ async fn main() {
```rust
use hiramu::bedrock::models::claude::claude_client::{ClaudeClient, ClaudeOptions};
use hiramu::bedrock::models::claude::claude_request_message::{ChatOptions, ConversationRequest, Message};
use hiramu::bedrock::model_info::{ModelInfo, ModelName};

#[tokio::main]
async fn main() {
Expand All @@ -79,7 +136,8 @@ async fn main() {

    let chat_options = ChatOptions::default()
        .with_temperature(0.7)
        .with_max_tokens(100);
        .with_max_tokens(100)
        .with_model_id(ModelInfo::from_model_name(ModelName::AnthropicClaudeHaiku1x));

    let response_stream = client
        .chat_with_stream(&conversation_request, &chat_options)
@@ -231,32 +289,21 @@ async fn main() {
}
```


## Embeddings

The Ollama library provides functionality to generate embeddings for a given text prompt. Embeddings are dense vector representations of text that capture semantic meaning and can be used for various downstream tasks such as semantic search, clustering, and classification. To generate embeddings, you can use the `OllamaClient::embeddings` method. First, create an instance of `EmbeddingsRequestBuilder` by providing the model name and the text prompt. Optionally, you can specify additional options and a keep-alive duration. Then, call the `build` method to create an `EmbeddingsRequest` and pass it to the `embeddings` method of the `OllamaClient`. The method returns an `EmbeddingsResponse` containing the generated embedding as a vector of floating-point values. Here's an example:

```rust
// Module paths follow the crate layout used elsewhere in this README;
// adjust the import if EmbeddingsRequestBuilder lives in a different module.
use hiramu::ollama::ollama_client::OllamaClient;
use hiramu::ollama::model::EmbeddingsRequestBuilder;

#[tokio::main]
async fn main() {
    let client = OllamaClient::new("http://localhost:11434".to_string());

    let request = EmbeddingsRequestBuilder::new(
        "nomic-embed-text".to_string(),
        "Here is an article about llamas...".to_string(),
    )
    .options(serde_json::json!({ "temperature": 0.8 }))
    .keep_alive("10m".to_string())
    .build();

    let response = client.embeddings(request).await.unwrap();
    println!("Embeddings: {:?}", response.embedding);
}
```

This code snippet demonstrates how to create an `EmbeddingsRequestBuilder`, set the model name, prompt, options, and keep-alive duration, and then build the request. The `embeddings` method is called with the request, and the resulting `EmbeddingsResponse` contains the generated embedding.

## Examples

Here is a table with a description for each example:

| Example | Path | Description |
|---------|------|--------------|
| `demo_ollama` | [src/examples/demo_ollama.rs](src/examples/demo_ollama.rs) | A simple example that demonstrates how to use the Ollama API to generate responses to chat messages. |
| `demo_bedrock_raw_generate` | [src/examples/demo_bedrock_raw_generate.rs](src/examples/demo_bedrock_raw_generate.rs) | Demonstrates how to generate a raw response from the Bedrock service using the `generate_raw` method. |
| `demo_bedrock_raw_stream` | [src/examples/demo_bedrock_raw_stream.rs](src/examples/demo_bedrock_raw_stream.rs) | Demonstrates how to generate a raw stream of responses from the Bedrock service using the `generate_raw_stream` method. |
| `demo_bedrock_raw_mistral` | [src/examples/demo_bedrock_raw_mistral.rs](src/examples/demo_bedrock_raw_mistral.rs) | Demonstrates how to generate a raw stream of responses from the Mistral model in the Bedrock service. |
| `demo_claude_chat` | [src/examples/demo_claude_chat.rs](src/examples/demo_claude_chat.rs) | Demonstrates how to use the Claude model in the Bedrock service to generate a chat response. |
| `demo_claude_chat_stream` | [src/examples/demo_claude_chat_stream.rs](src/examples/demo_claude_chat_stream.rs) | Demonstrates how to use the Claude model in the Bedrock service to generate a stream of chat responses. |
| `demo_claude_multimedia` | [src/examples/demo_claude_multimedia.rs](src/examples/demo_claude_multimedia.rs) | Demonstrates how to use the Claude model in the Bedrock service to generate a response based on text and an image. |
| `demo_ollama_embedding` | [src/examples/demo_ollama_embedding.rs](src/examples/demo_ollama_embedding.rs) | Demonstrates how to use the Ollama API to generate text embeddings. |
| `demo_mistral_stream` | [src/examples/demo_mistral_stream.rs](src/examples/demo_mistral_stream.rs) | Demonstrates how to use the Mistral model in the Bedrock service to generate a stream of responses. |

## Contributing

Expand All @@ -272,13 +319,14 @@ To contribute to the project, follow these steps:

## License

Hiramu is licensed under the [MIT License].
Hiramu is licensed under the [MIT License](./LICENCE).

## Acknowledgements

Hiramu is built on top of the following libraries and APIs:

- [Ollama](https://ollama.com/)
- [AWS Bedrock](https://aws.amazon.com/bedrock/)
- [reqwest](https://docs.rs/reqwest)
- [tokio](https://tokio.rs/)
- [serde](https://serde.rs/)
6 changes: 6 additions & 0 deletions REVISION.md
@@ -0,0 +1,6 @@
# Revision History

## [2024-04-05]

### 0.1.7
- Added support for the Mistral API in the Bedrock module.
2 changes: 1 addition & 1 deletion documentation/TODO.Md
@@ -2,7 +2,7 @@

[X] - Implement Bedrock Client
[X] - Implement Bedrock MultiModal in Chat, Image support
[ ] - Add support to embedding models with Ollama
[X] - Add support to embedding models with Ollama
[ ] - Add support to embedding models with Bedrock
[ ] - Add more Tests and examples
[ ] - Expose the Library for Python / NodeJs
5 changes: 5 additions & 0 deletions src/bedrock/model_info.rs
@@ -19,6 +19,7 @@ pub enum ModelName {
    MetaLlama2Chat70B1x,
    MistralMistral7BInstruct0x,
    MistralMixtral8X7BInstruct0x,
    MistralLarge,
    StabilityStableDiffusionXL0x,
    StabilityStableDiffusionXL1x,
}
@@ -107,6 +108,10 @@ impl ModelInfo {
            name: ModelName::MistralMixtral8X7BInstruct0x,
            text: "mistral.mixtral-8x7b-instruct-v0:1",
        },
        ModelInfo {
            name: ModelName::MistralLarge,
            text: "mistral.mistral-large-2402-v1:0",
        },
        ModelInfo {
            name: ModelName::StabilityStableDiffusionXL0x,
            text: "stability.stable-diffusion-xl-v0",
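
A minimal, hypothetical sketch (not part of the commit) of how the new `MistralLarge` entry can be used: `ModelInfo::from_model_name` resolves the variant to the Bedrock model identifier registered above, which is then passed to the Mistral client. It assumes the client and builder signatures shown elsewhere in this diff.

```rust
use hiramu::bedrock::model_info::{ModelInfo, ModelName};
use hiramu::bedrock::models::mistral::mistral_client::{MistralClient, MistralOptions};
use hiramu::bedrock::models::mistral::mistral_request_message::MistralRequestBuilder;

#[tokio::main]
async fn main() {
    // Resolves to "mistral.mistral-large-2402-v1:0", the identifier added above.
    let model_id = ModelInfo::from_model_name(ModelName::MistralLarge);

    let options = MistralOptions::new().profile_name("bedrock").region("us-west-2");
    let client = MistralClient::new(options).await;

    let request = MistralRequestBuilder::new("<s>[INST] What is the capital of France?[/INST]".to_string())
        .max_tokens(200)
        .build();

    let response = client.generate(model_id, &request).await.unwrap();
    println!("{}", response.outputs[0].text);
}
```
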
27 changes: 27 additions & 0 deletions src/bedrock/models/mistral/error.rs
@@ -0,0 +1,27 @@
use crate::bedrock::error::BedrockError;
use serde_json::Error as SerdeJsonError;
use thiserror::Error;

#[derive(Error, Debug)]
pub enum MistralError {
#[error("HTTP error: {0}")]
Http(#[from] reqwest::Error),

#[error("JSON error: {0}")]
Json(#[from] SerdeJsonError),

#[error("I/O error: {0}")]
Io(#[from] std::io::Error),

#[error("UTF-8 error: {0}")]
Utf8(#[from] std::str::Utf8Error),

#[error("Invalid response: {0}")]
InvalidResponse(String),

#[error("Unknown error: {0}")]
Unknown(String),

#[error("Bedrock error: {0}")]
Bedrock(#[from] BedrockError),
}
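
As a hypothetical usage sketch (not part of the commit), a caller could branch on these variants to separate JSON (de)serialization problems from Bedrock-level failures. It assumes the `MistralClient::generate` signature and the `MistralResponse.outputs` shape that appear later in this diff.

```rust
use hiramu::bedrock::model_info::{ModelInfo, ModelName};
use hiramu::bedrock::models::mistral::error::MistralError;
use hiramu::bedrock::models::mistral::mistral_client::{MistralClient, MistralOptions};
use hiramu::bedrock::models::mistral::mistral_request_message::MistralRequestBuilder;

#[tokio::main]
async fn main() {
    let options = MistralOptions::new().profile_name("bedrock").region("us-west-2");
    let client = MistralClient::new(options).await;

    let request = MistralRequestBuilder::new("<s>[INST] Hello[/INST]".to_string()).build();
    let model_id = ModelInfo::from_model_name(ModelName::MistralMixtral8X7BInstruct0x);

    match client.generate(model_id, &request).await {
        Ok(response) => println!("{}", response.outputs[0].text),
        // Serialization and deserialization failures surface as Json errors.
        Err(MistralError::Json(err)) => eprintln!("JSON error: {err}"),
        // Errors from the underlying Bedrock client are wrapped in Bedrock.
        Err(MistralError::Bedrock(err)) => eprintln!("Bedrock error: {err:?}"),
        Err(err) => eprintln!("Other error: {err:?}"),
    }
}
```
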
125 changes: 125 additions & 0 deletions src/bedrock/models/mistral/mistral_client.rs
@@ -0,0 +1,125 @@
use crate::bedrock::bedrock_client::{BedrockClient, BedrockClientOptions};
use crate::bedrock::models::mistral::error::MistralError;
use crate::bedrock::models::mistral::mistral_request_message::{MistralRequest, MistralResponse};
use futures::stream::Stream;
use futures::TryStreamExt;

pub type MistralOptions = BedrockClientOptions;

pub struct MistralClient {
    client: BedrockClient,
}

impl MistralClient {
    /// Constructs a new `MistralClient`.
    pub async fn new(options: MistralOptions) -> Self {
        Self {
            client: BedrockClient::new(options).await,
        }
    }

    /// Generates a response from the Mistral model.
    pub async fn generate(
        &self,
        model_id: String,
        request: &MistralRequest,
    ) -> Result<MistralResponse, MistralError> {
        let payload = serde_json::to_value(request).map_err(MistralError::Json)?;

        let response = self.client.generate_raw(model_id, payload).await?;

        let mistral_response = serde_json::from_value(response).map_err(MistralError::Json)?;
        Ok(mistral_response)
    }

    /// Generates a stream of responses from the Mistral model.
    pub async fn generate_with_stream(
        &self,
        model_id: String,
        request: &MistralRequest,
    ) -> Result<impl Stream<Item = Result<MistralResponse, MistralError>>, MistralError> {
        let payload = serde_json::to_value(request).map_err(MistralError::Json)?;

        let response = self.client.generate_raw_stream(model_id, payload).await?;

        // Deserialize each streamed chunk, convert Bedrock errors into MistralError,
        // and flatten the nested Result produced by map_ok.
        Ok(response
            .map_ok(|value| serde_json::from_value(value).map_err(MistralError::Json))
            .map_err(MistralError::Bedrock)
            .and_then(futures::future::ready))
    }
}


#[cfg(test)]
mod tests {
    use super::*;
    use crate::bedrock::{models::mistral::mistral_request_message::MistralRequestBuilder, ModelInfo};
    use futures::stream::StreamExt;

    #[tokio::test]
    async fn test_generate() {
        let options = MistralOptions::new().profile_name("bedrock").region("us-west-2");
        let client = MistralClient::new(options).await;

        let request = MistralRequestBuilder::new("<s>[INST] What is the capital of France ?[/INST]".to_string())
            .max_tokens(200)
            .temperature(0.8)
            .build();

        let model_name = ModelInfo::from_model_name(crate::bedrock::ModelName::MistralMixtral8X7BInstruct0x);

        let response = client.generate(model_name, &request).await;

        let response = match response {
            Ok(response) => response,
            Err(err) => panic!("Error: {:?}", err),
        };

        println!("Response: {:?}", response.outputs[0].text.to_string());

        assert!(!response.outputs.is_empty());
    }

    #[tokio::test]
    async fn test_generate_with_stream() {
        let options = MistralOptions::new().profile_name("bedrock").region("us-west-2");
        let client = MistralClient::new(options).await;

        let request = MistralRequestBuilder::new("<s>[INST] What is the capital of France ?[/INST]".to_string())
            .max_tokens(200)
            .temperature(0.8)
            .build();

        let model_name = ModelInfo::from_model_name(crate::bedrock::ModelName::MistralMixtral8X7BInstruct0x);

        // Display the request as a pretty-printed JSON string.
        let display_request = serde_json::to_string_pretty(&request).unwrap();
        println!("Request: {}", display_request);

        let mut stream = client
            .generate_with_stream("mistral.mistral-7b-instruct-v0:2".to_string(), &request)
            .await
            .unwrap();

        let mut response_text = String::new();
        while let Some(result) = stream.next().await {
            match result {
                Ok(response) => {
                    println!("Response: {:?}", response.outputs[0].text.to_string());
                    response_text.push_str(&response.outputs[0].text);
                }
                Err(err) => {
                    panic!("Error: {:?}", err);
                }
            }
        }

        assert!(!response_text.is_empty());
    }
}
