From 92867a8a53df8b38a42e35d21786bce8728b72df Mon Sep 17 00:00:00 2001 From: Raphael MANSUY Date: Fri, 5 Apr 2024 21:07:50 +0200 Subject: [PATCH] feat: Add support for Mistral API in the Bedrock module - Implemented MistralClient for interacting with the Mistral API in the Bedrock service - Added MistralRequestBuilder and MistralResponse models - Added examples for using the Mistral API - Updated the README with information about the Mistral API support - Bumped the version to 0.1.7 --- Cargo.lock | 2 +- Cargo.toml | 2 +- LICENCE.md => LICENCE | 0 README.md | 120 +++++--- REVISION.md | 6 + documentation/TODO.Md | 2 +- src/bedrock/model_info.rs | 5 + src/bedrock/models/mistral/error.rs | 27 ++ src/bedrock/models/mistral/mistral_client.rs | 125 ++++++++ .../models/mistral/mistral_request_message.rs | 149 ++++++++++ src/bedrock/models/mistral/mod.rs | 12 + src/bedrock/models/mod.rs | 1 + src/examples/demo_bedrock_raw_mistral.rs | 85 ++++++ src/examples/demo_mistral_stream.rs | 65 +++++ src/examples/mod.rs | 4 +- src/lib.rs | 266 +----------------- 16 files changed, 566 insertions(+), 305 deletions(-) rename LICENCE.md => LICENCE (100%) create mode 100644 REVISION.md create mode 100644 src/bedrock/models/mistral/error.rs create mode 100644 src/bedrock/models/mistral/mistral_client.rs create mode 100644 src/bedrock/models/mistral/mistral_request_message.rs create mode 100644 src/bedrock/models/mistral/mod.rs create mode 100644 src/examples/demo_bedrock_raw_mistral.rs create mode 100644 src/examples/demo_mistral_stream.rs diff --git a/Cargo.lock b/Cargo.lock index 453f508..1c45f08 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -806,7 +806,7 @@ checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" [[package]] name = "hiramu" -version = "0.1.6" +version = "0.1.7" dependencies = [ "async-stream", "aws-config", diff --git a/Cargo.toml b/Cargo.toml index 8f6c598..cc9126c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "hiramu" -version = "0.1.6" +version = "0.1.7" edition = "2021" license = "MIT" description = "A Rust AI Engineering Toolbox" diff --git a/LICENCE.md b/LICENCE similarity index 100% rename from LICENCE.md rename to LICENCE diff --git a/README.md b/README.md index 0e7e4d2..cfacf84 100644 --- a/README.md +++ b/README.md @@ -1,37 +1,93 @@ # Hiramu -Hiramu is a powerful and flexible Rust library that provides a high-level interface for interacting with various AI models and APIs, including Ollama and Bedrock. It simplifies the process of generating text, engaging in chat conversations, and working with different AI models. +Hiramu is a powerful and flexible Rust library that provides a high-level interface for interacting with various AI models and APIs, including Ollama and AWS Bedrock. + +It simplifies the process of generating text, engaging in chat conversations, and working with different AI models. 
## Features - Easy-to-use interfaces for generating text and engaging in chat conversations with AI models - Support for Ollama and Bedrock AI services +- Convenient interface for Claude and Mistral for AWS Bedrock - Asynchronous and streaming responses for efficient handling of large outputs - Customizable options for fine-tuning the behavior of AI models - Comprehensive error handling and informative error messages - Well-documented code with examples and explanations - ## Getting Started To start using Hiramu in your Rust project, add the following to your `Cargo.toml` file: ```toml [dependencies] -hiramu = "0.1.X" +hiramu = "0.1.7" ``` -Then, import the necessary modules and types in your Rust code: +## Examples + +### Generating Text with Mistral ```rust -use hiramu::ollama::ollama_client::OllamaClient; -use hiramu::ollama::model::{GenerateRequest, GenerateRequestBuilder, GenerateResponse}; -use hiramu::bedrock::bedrock_client::{BedrockClient, BedrockClientOptions}; -use hiramu::bedrock::models::claude::claude_client::{ClaudeClient, ClaudeOptions}; -use hiramu::bedrock::models::claude::claude_request_message::{ChatOptions, ConversationRequest, Message}; +use hiramu::bedrock::models::mistral::mistral_client::{MistralClient, MistralOptions}; +use hiramu::bedrock::models::mistral::mistral_request_message::MistralRequestBuilder; +use hiramu::bedrock::model_info::{ModelInfo, ModelName}; + +#[tokio::main] +async fn main() { + let mistral_options = MistralOptions::new() + .profile_name("bedrock") + .region("us-west-2"); + + let client = MistralClient::new(mistral_options).await; + + let request = MistralRequestBuilder::new("[INST] What is the capital of France?[/INST]".to_string()) + .max_tokens(200) + .temperature(0.8) + .build(); + + let model_id = ModelInfo::from_model_name(ModelName::MistralMixtral8X7BInstruct0x); + let response = client.generate(model_id, &request).await.unwrap(); + + println!("Response: {:?}", response.outputs.text); +} ``` -## Examples +### Streaming Text Generation with Mistral + +```rust +use futures::stream::StreamExt; +use hiramu::bedrock::models::mistral::mistral_client::{MistralClient, MistralOptions}; +use hiramu::bedrock::models::mistral::mistral_request_message::MistralRequestBuilder; +use hiramu::bedrock::model_info::{ModelInfo, ModelName}; + +#[tokio::main] +async fn main() { + let mistral_options = MistralOptions::new() + .profile_name("bedrock") + .region("us-west-2"); + + let client = MistralClient::new(mistral_options).await; + + let request = MistralRequestBuilder::new("[INST] What is the capital of France?[/INST]".to_string()) + .max_tokens(200) + .temperature(0.8) + .build(); + + let model_id = ModelInfo::from_model_name(ModelName::MistralMixtral8X7BInstruct0x); + let mut stream = client.generate_with_stream(model_id, &request).await.unwrap(); + + while let Some(result) = stream.next().await { + match result { + Ok(response) => { + println!("Response: {:?}", response.outputs.text); + } + Err(err) => { + eprintln!("Error: {:?}", err); + } + } + } +} +``` ### Generating Text with Ollama @@ -63,6 +119,7 @@ async fn main() { ```rust use hiramu::bedrock::models::claude::claude_client::{ClaudeClient, ClaudeOptions}; use hiramu::bedrock::models::claude::claude_request_message::{ChatOptions, ConversationRequest, Message}; +use hiramu::bedrock::model_info::{ModelInfo, ModelName}; #[tokio::main] async fn main() { @@ -79,7 +136,8 @@ async fn main() { let chat_options = ChatOptions::default() .with_temperature(0.7) - .with_max_tokens(100); + 
.with_max_tokens(100) + .with_model_id(ModelInfo::from_model_name(ModelName::AnthropicClaudeHaiku1x)); let response_stream = client .chat_with_stream(&conversation_request, &chat_options) @@ -231,32 +289,21 @@ async fn main() { } ``` -Here's a paragraph explaining how to use Embeddings in the Ollama Rust library for a README.md file: - -## Embeddings - -The Ollama library provides functionality to generate embeddings for a given text prompt. Embeddings are dense vector representations of text that capture semantic meaning and can be used for various downstream tasks such as semantic search, clustering, and classification. To generate embeddings, you can use the `OllamaClient::embeddings` method. First, create an instance of `EmbeddingsRequestBuilder` by providing the model name and the text prompt. Optionally, you can specify additional options and a keep-alive duration. Then, call the `build` method to create an `EmbeddingsRequest` and pass it to the `embeddings` method of the `OllamaClient`. The method returns an `EmbeddingsResponse` containing the generated embedding as a vector of floating-point values. Here's an example: - -```rust -use ollama::{OllamaClient, EmbeddingsRequestBuilder}; - -let client = OllamaClient::new("http://localhost:11434".to_string()); -let request = EmbeddingsRequestBuilder::new( - "nomic-embed-text".to_string(), - "Here is an article about llamas...".to_string(), -) -.options(serde_json::json!({ "temperature": 0.8 })) -.keep_alive("10m".to_string()) -.build(); - -let response = client.embeddings(request).await.unwrap(); -println!("Embeddings: {:?}", response.embedding); -``` - -This code snippet demonstrates how to create an `EmbeddingsRequestBuilder`, set the model name, prompt, options, and keep-alive duration, and then build the request. The `embeddings` method is called with the request, and the resulting `EmbeddingsResponse` contains the generated embedding. - +## Examples +Here is a table with a description for each example: +| Example | Path | Description | +|---------|------|--------------| +| `demo_ollama` | [src/examples/demo_ollama.rs](src/examples/demo_ollama.rs) | A simple example that demonstrates how to use the Ollama API to generate responses to chat messages. | +| `demo_bedrock_raw_generate` | [src/examples/demo_bedrock_raw_generate.rs](src/examples/demo_bedrock_raw_generate.rs) | Demonstrates how to generate a raw response from the Bedrock service using the `generate_raw` method. | +| `demo_bedrock_raw_stream` | [src/examples/demo_bedrock_raw_stream.rs](src/examples/demo_bedrock_raw_stream.rs) | Demonstrates how to generate a raw stream of responses from the Bedrock service using the `generate_raw_stream` method. | +| `demo_bedrock_raw_mistral` | [src/examples/demo_bedrock_raw_mistral.rs](src/examples/demo_bedrock_raw_mistral.rs) | Demonstrates how to generate a raw stream of responses from the Mistral model in the Bedrock service. | +| `demo_claude_chat` | [src/examples/demo_claude_chat.rs](src/examples/demo_claude_chat.rs) | Demonstrates how to use the Claude model in the Bedrock service to generate a chat response. | +| `demo_claude_chat_stream` | [src/examples/demo_claude_chat_stream.rs](src/examples/demo_claude_chat_stream.rs) | Demonstrates how to use the Claude model in the Bedrock service to generate a stream of chat responses. 
| +| `demo_claude_multimedia` | [src/examples/demo_claude_multimedia.rs](src/examples/demo_claude_multimedia.rs) | Demonstrates how to use the Claude model in the Bedrock service to generate a response based on text and an image. | +| `demo_ollama_embedding` | [src/examples/demo_ollama_embedding.rs](src/examples/demo_ollama_embedding.rs) | Demonstrates how to use the Ollama API to generate text embeddings. | +| `demo_mistral_stream` | [src/examples/demo_mistral_stream.rs](src/examples/demo_mistral_stream.rs) | Demonstrates how to use the Mistral model in the Bedrock service to generate a stream of responses. ## Contributing @@ -272,13 +319,14 @@ To contribute to the project, follow these steps: ## License -Hiramu is licensed under the [MIT License]. +Hiramu is licensed under the [MIT License](./LICENCE). ## Acknowledgements Hiramu is built on top of the following libraries and APIs: - [Ollama](https://ollama.com/) +- [Bedrock](https://bedrock.com/) - [reqwest](https://docs.rs/reqwest) - [tokio](https://tokio.rs/) - [serde](https://serde.rs/) diff --git a/REVISION.md b/REVISION.md new file mode 100644 index 0000000..1699526 --- /dev/null +++ b/REVISION.md @@ -0,0 +1,6 @@ +# Revision History + +## [2024-04-05] + +### 0.1.7 +- Added support for the Mistral API in the Bedrock module. diff --git a/documentation/TODO.Md b/documentation/TODO.Md index c5812db..a3fbc3b 100644 --- a/documentation/TODO.Md +++ b/documentation/TODO.Md @@ -2,7 +2,7 @@ [X] - Implement Bedrock Client [X] - Implement Bedrock MultiModal in Chat, Image support -[ ] - Add support to embedding models with Ollama +[X] - Add support to embedding models with Ollama [ ] - Add support to embedding models with Bedrocks [ ] - Add more Tests and examples [ ] - Expose the Library for Python / NodeJs diff --git a/src/bedrock/model_info.rs b/src/bedrock/model_info.rs index f7c15e9..46aa023 100644 --- a/src/bedrock/model_info.rs +++ b/src/bedrock/model_info.rs @@ -19,6 +19,7 @@ pub enum ModelName { MetaLlama2Chat70B1x, MistralMistral7BInstruct0x, MistralMixtral8X7BInstruct0x, + MistralLarge, StabilityStableDiffusionXL0x, StabilityStableDiffusionXL1x, } @@ -107,6 +108,10 @@ impl ModelInfo { name: ModelName::MistralMixtral8X7BInstruct0x, text: "mistral.mixtral-8x7b-instruct-v0:1", }, + ModelInfo { + name: ModelName::MistralLarge, + text: "mistral.mistral-large-2402-v1:0", + }, ModelInfo { name: ModelName::StabilityStableDiffusionXL0x, text: "stability.stable-diffusion-xl-v0", diff --git a/src/bedrock/models/mistral/error.rs b/src/bedrock/models/mistral/error.rs new file mode 100644 index 0000000..e93465a --- /dev/null +++ b/src/bedrock/models/mistral/error.rs @@ -0,0 +1,27 @@ +use crate::bedrock::error::BedrockError; +use serde_json::Error as SerdeJsonError; +use thiserror::Error; + +#[derive(Error, Debug)] +pub enum MistralError { + #[error("HTTP error: {0}")] + Http(#[from] reqwest::Error), + + #[error("JSON error: {0}")] + Json(#[from] SerdeJsonError), + + #[error("I/O error: {0}")] + Io(#[from] std::io::Error), + + #[error("UTF-8 error: {0}")] + Utf8(#[from] std::str::Utf8Error), + + #[error("Invalid response: {0}")] + InvalidResponse(String), + + #[error("Unknown error: {0}")] + Unknown(String), + + #[error("Bedrock error: {0}")] + Bedrock(#[from] BedrockError), +} \ No newline at end of file diff --git a/src/bedrock/models/mistral/mistral_client.rs b/src/bedrock/models/mistral/mistral_client.rs new file mode 100644 index 0000000..f2175de --- /dev/null +++ b/src/bedrock/models/mistral/mistral_client.rs @@ -0,0 +1,125 @@ +use 
crate::bedrock::bedrock_client::{BedrockClient, BedrockClientOptions}; +use crate::bedrock::models::mistral::error::MistralError; +use crate::bedrock::models::mistral::mistral_request_message::{MistralRequest, MistralResponse}; +use futures::stream::Stream; +use futures::TryStreamExt; + +pub type MistralOptions = BedrockClientOptions; + +pub struct MistralClient { + client: BedrockClient, +} + +impl MistralClient { + /// Constructs a new `MistralClient`. + pub async fn new(options: MistralOptions) -> Self { + Self { + client: BedrockClient::new(options).await, + } + } + + /// Generates a response from the Mistral model. + pub async fn generate( + &self, + model_id: String, + request: &MistralRequest, + ) -> Result { + let payload = serde_json::to_value(request).map_err(MistralError::Json)?; + + let response = self.client.generate_raw(model_id, payload).await?; + + let mistral_response = serde_json::from_value(response).map_err(MistralError::Json)?; + Ok(mistral_response) + } + + /// Generates a stream of responses from the Mistral model. + pub async fn generate_with_stream( + &self, + model_id: String, + request: &MistralRequest, + ) -> Result>, MistralError> { + let payload = serde_json::to_value(request).map_err(MistralError::Json)?; + + let response = self.client.generate_raw_stream(model_id, payload).await?; + + + Ok(response + .map_ok(|value| serde_json::from_value(value).map_err(MistralError::Json)) + .map_err(|err| MistralError::Bedrock(err)) + .and_then(futures::future::ready)) + } +} + + +#[cfg(test)] +mod tests { + use super::*; + use crate::bedrock::{models::mistral::mistral_request_message::MistralRequestBuilder, ModelInfo}; + use futures::stream::StreamExt; + + #[tokio::test] + async fn test_generate() { + let options = MistralOptions::new().profile_name("bedrock").region("us-west-2"); + let client = MistralClient::new(options).await; + + let request = MistralRequestBuilder::new("[INST] What is the capital of France ?[/INST]".to_string()) + .max_tokens(200) + .temperature(0.8) + .build(); + + let model_name = ModelInfo::from_model_name(crate::bedrock::ModelName::MistralMixtral8X7BInstruct0x); + + let response = client.generate(model_name, &request).await; + + let response = match response { + Ok(response) => response, + Err(err) => panic!("Error: {:?}", err), + }; + + println!("Response: {:?}", response.outputs[0].text.to_string()); + + assert!(!response.outputs.is_empty()); + } + + #[tokio::test] + async fn test_generate_with_stream() { + let options = MistralOptions::new().profile_name("bedrock").region("us-west-2"); + let client = MistralClient::new(options).await; + + let request = MistralRequestBuilder::new("[INST] What is the capital of France ?[/INST]".to_string()) + .max_tokens(200) + .temperature(0.8) + .build(); + + let model_name = ModelInfo::from_model_name(crate::bedrock::ModelName::MistralMixtral8X7BInstruct0x); + + // display the request as a pretty-printed JSON string + let display_request = serde_json::to_string_pretty(&request).unwrap(); + println!("Request: {}", display_request); + + + + let mut stream = client + .generate_with_stream("mistral.mistral-7b-instruct-v0:2".to_string(), &request) + .await + .unwrap(); + + let mut response_text = String::new(); + while let Some(result) = stream.next().await { + match result { + Ok(response) => { + println!("Response: {:?}", response.outputs[0].text.to_string()); + response_text.push_str(&response.outputs[0].text); + } + Err(err) => { + panic!("Error: {:?}", err); + } + } + } + + assert!(!response_text.is_empty()); 
+ + } + + +} \ No newline at end of file diff --git a/src/bedrock/models/mistral/mistral_request_message.rs b/src/bedrock/models/mistral/mistral_request_message.rs new file mode 100644 index 0000000..1e77616 --- /dev/null +++ b/src/bedrock/models/mistral/mistral_request_message.rs @@ -0,0 +1,149 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Serialize, Deserialize)] +pub struct MistralRequest { + pub prompt: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub max_tokens: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub temperature: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub top_p: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub top_k: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub stop: Option>, +} + + + +#[derive(Debug, Deserialize,Serialize)] +pub struct MistralResponse { + pub outputs: Vec, +} + +#[derive(Debug, Deserialize,Serialize)] +pub struct MistralOutput { + pub text: String, + pub stop_reason: Option, +} + +pub struct MistralRequestBuilder { + prompt: String, + max_tokens: Option, + temperature: Option, + top_p: Option, + top_k: Option, + stop: Option>, +} + +impl MistralRequestBuilder { + pub fn new(prompt: String) -> Self { + Self { + prompt, + max_tokens: None, + temperature: None, + top_p: None, + top_k: None, + stop: None, + } + } + + pub fn max_tokens(mut self, max_tokens: u32) -> Self { + self.max_tokens = Some(max_tokens); + self + } + + pub fn temperature(mut self, temperature: f32) -> Self { + self.temperature = Some(temperature); + self + } + + pub fn top_p(mut self, top_p: f32) -> Self { + self.top_p = Some(top_p); + self + } + + pub fn top_k(mut self, top_k: u32) -> Self { + self.top_k = Some(top_k); + self + } + + pub fn stop(mut self, stop: Vec) -> Self { + self.stop = Some(stop); + self + } + + pub fn build(self) -> MistralRequest { + MistralRequest { + prompt: self.prompt, + max_tokens: self.max_tokens, + temperature: self.temperature, + top_p: self.top_p, + top_k: self.top_k, + stop: self.stop, + } + } +} + +pub struct MistralOptionsBuilder { + max_tokens: Option, + temperature: Option, + top_p: Option, + top_k: Option, + stop: Option>, +} + +impl Default for MistralOptionsBuilder { + fn default() -> Self { + Self { + max_tokens: Some(400), + temperature: Some(0.7), + top_p: Some(0.7), + top_k: Some(50), + stop: None, + } + } +} + +impl MistralOptionsBuilder { + pub fn new() -> Self { + Self::default() + } + + pub fn max_tokens(mut self, max_tokens: u32) -> Self { + self.max_tokens = Some(max_tokens); + self + } + + pub fn temperature(mut self, temperature: f32) -> Self { + self.temperature = Some(temperature); + self + } + + pub fn top_p(mut self, top_p: f32) -> Self { + self.top_p = Some(top_p); + self + } + + pub fn top_k(mut self, top_k: u32) -> Self { + self.top_k = Some(top_k); + self + } + + pub fn stop(mut self, stop: Vec) -> Self { + self.stop = Some(stop); + self + } + + pub fn build(self) -> MistralRequest { + MistralRequest { + prompt: String::new(), + max_tokens: self.max_tokens, + temperature: self.temperature, + top_p: self.top_p, + top_k: self.top_k, + stop: self.stop, + } + } +} \ No newline at end of file diff --git a/src/bedrock/models/mistral/mod.rs b/src/bedrock/models/mistral/mod.rs new file mode 100644 index 0000000..a8b7566 --- /dev/null +++ b/src/bedrock/models/mistral/mod.rs @@ -0,0 +1,12 @@ +pub mod error; +pub mod mistral_client; +pub mod mistral_request_message; + + +pub use mistral_client::MistralClient; +pub use 
mistral_client::MistralOptions; +pub use error::MistralError; +pub use mistral_request_message::MistralRequest; +pub use mistral_request_message::MistralResponse; +pub use mistral_request_message::MistralOptionsBuilder; +pub use mistral_request_message::MistralRequestBuilder; diff --git a/src/bedrock/models/mod.rs b/src/bedrock/models/mod.rs index e098b0f..ece9d94 100644 --- a/src/bedrock/models/mod.rs +++ b/src/bedrock/models/mod.rs @@ -1 +1,2 @@ pub mod claude; +pub mod mistral; \ No newline at end of file diff --git a/src/examples/demo_bedrock_raw_mistral.rs b/src/examples/demo_bedrock_raw_mistral.rs new file mode 100644 index 0000000..52873b6 --- /dev/null +++ b/src/examples/demo_bedrock_raw_mistral.rs @@ -0,0 +1,85 @@ +use futures::TryStreamExt; +use std::io; +use std::io::Write; + +use crate::bedrock::bedrock_client::{BedrockClient, BedrockClientOptions}; + + +pub async fn demo_bedrock_mistral_raw_stream(model_id: &str, prompt: &str) { + let profile_name = "bedrock"; + let region = "us-west-2"; + + + let payload = serde_json::json!({ + "prompt": prompt, + "max_tokens" : 200, + "stop" : ["[INST]"], + "temperature": 0.5, + "top_p": 0.9, + "top_k": 100, + }); + + let options = BedrockClientOptions::new() + .profile_name(profile_name) + .region(region); + + + let client = BedrockClient::new(options).await; + + let stream = client + .generate_raw_stream( + model_id.to_string(), + payload, + ) + .await; + + let stream = match stream { + Ok(stream) => stream, + Err(err) => { + println!("Error: {:?}", err); + return; + } + }; + + // consumme the stream and print the response + stream + .try_for_each(|chunk| async move { + println!("{:?}", chunk); + // Flush the output to ensure the prompt is displayed. + io::stdout().flush().unwrap(); + Ok(()) + }) + .await + .unwrap(); +} + +// Write a test + +#[cfg(test)] +mod tests { + use super::*; + use crate::bedrock::model_info::{ModelInfo, ModelName}; + + #[tokio::test] + async fn test_demo_bedrock_mistral_raw_stream_8x7() { + let model_id = ModelInfo::from_model_name(ModelName::MistralMixtral8X7BInstruct0x); + let prompt = "[INST] What is your favourite condiment? [/INST]"; + demo_bedrock_mistral_raw_stream(&model_id,&prompt).await; + } + + #[tokio::test] + async fn test_demo_bedrock_mistral_raw_stream_7b() { + let model_id = ModelInfo::from_model_name(ModelName::MistralMistral7BInstruct0x); + let prompt = "[INST] What is your favourite condiment? [/INST]"; + demo_bedrock_mistral_raw_stream(&model_id,&prompt).await; + } + + #[tokio::test] + async fn test_demo_bedrock_mistral_raw_stream_large() { + let model_id = ModelInfo::from_model_name(ModelName::MistralLarge); + let prompt = "[INST] What is your favourite condiment? 
[/INST]"; + demo_bedrock_mistral_raw_stream(&model_id,&prompt).await; + } + + +} diff --git a/src/examples/demo_mistral_stream.rs b/src/examples/demo_mistral_stream.rs new file mode 100644 index 0000000..f72c21f --- /dev/null +++ b/src/examples/demo_mistral_stream.rs @@ -0,0 +1,65 @@ +use futures::TryStreamExt; + +use crate::bedrock::model_info::{ModelInfo, ModelName}; +use crate::bedrock::models::mistral::MistralClient; +use crate::bedrock::models::mistral::MistralOptions; +use crate::bedrock::models::mistral::MistralRequestBuilder; + + +pub async fn demo_mistra_with_stream(model_id: &str, prompt: &str) { + + let mistral_otions + = MistralOptions::new() + .profile_name("bedrock") + .region("us-west-2"); + + let client = MistralClient::new(mistral_otions).await; + + + + let request = MistralRequestBuilder::new(prompt.to_owned()) + .max_tokens(200) + .temperature(0.5) + .top_p(0.9) + .top_k(100) + .build(); + + let response_stream = client + .generate_with_stream( + model_id.to_string(), + &request + ) + .await; + + let response_stream = match response_stream { + Ok(response_stream) => response_stream, + Err(e) => { + println!("Error: {:?}", e); + return; + } + }; + + // consumme the stream and print the response + response_stream + .try_for_each(|chunk| async move { + let json_display = serde_json::to_string_pretty(&chunk).unwrap(); + println!("{:?}", json_display); + Ok(()) + }) + .await + .unwrap(); + +} + +// Test +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_demo_chat_mistral_with_stream() { + let model_id = ModelInfo::from_model_name(ModelName::MistralMixtral8X7BInstruct0x); + let prompt = "[INST] What is the capital of France ?[/INST]"; + demo_mistra_with_stream(&model_id, prompt).await; + } +} diff --git a/src/examples/mod.rs b/src/examples/mod.rs index 3aef67a..554b4fb 100644 --- a/src/examples/mod.rs +++ b/src/examples/mod.rs @@ -1,7 +1,9 @@ pub mod demo_ollama; pub mod demo_bedrock_raw_generate; pub mod demo_bedrock_raw_stream; +pub mod demo_bedrock_raw_mistral; pub mod demo_claude_chat; pub mod demo_claude_chat_stream; pub mod demo_claude_multimedia; -pub mod demo_ollama_embedding; \ No newline at end of file +pub mod demo_ollama_embedding; +pub mod demo_mistral_stream; diff --git a/src/lib.rs b/src/lib.rs index 884ea0d..409986b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,275 +1,11 @@ //! # Hiramu //! -//! Hiramu is a powerful and flexible Rust library that provides a high-level interface for interacting with various AI models and APIs, including Ollama and Bedrock. It simplifies the process of generating text, engaging in chat conversations, and working with different AI models. -//! -//! ## Features -//! -//! - Easy-to-use interfaces for generating text and engaging in chat conversations with AI models -//! - Support for Ollama and Bedrock AI services -//! - Asynchronous and streaming responses for efficient handling of large outputs -//! - Customizable options for fine-tuning the behavior of AI models -//! - Comprehensive error handling and informative error messages -//! - Well-documented code with examples and explanations -//! -//! ## Getting Started -//! -//! To start using Hiramu in your Rust project, add the following to your `Cargo.toml` file: -//! -//! ```toml -//! [dependencies] -//! hiramu = "0.1.X" -//! ``` -//! -//! Then, import the necessary modules and types in your Rust code: -//! -//! ```rust -//! use hiramu::ollama::ollama_client::OllamaClient; -//! 
use hiramu::ollama::model::{GenerateRequest, GenerateRequestBuilder, GenerateResponse}; -//! use hiramu::bedrock::bedrock_client::{BedrockClient, BedrockClientOptions}; -//! use hiramu::bedrock::models::claude::claude_client::{ClaudeClient, ClaudeOptions}; -//! use hiramu::bedrock::models::claude::claude_request_message::{ChatOptions, ConversationRequest, Message}; -//! ``` -//! -//! ## Examples -//! -//! ### Generating Text with Ollama -//! -//! ```rust -//! use hiramu::ollama::ollama_client::OllamaClient; -//! use hiramu::ollama::model::{GenerateRequest, GenerateRequestBuilder}; -//! -//! #[tokio::main] -//! async fn main() { -//! let client = OllamaClient::new("http://localhost:11434".to_string()); -//! let request = GenerateRequestBuilder::new("mistral".to_string()) -//! .prompt("Once upon a time".to_string()) -//! .build(); -//! -//! let response_stream = client.generate(request).await.unwrap(); -//! -//! response_stream -//! .try_for_each(|chunk| async move { -//! println!("{}", chunk.response); -//! Ok(()) -//! }) -//! .await -//! .unwrap(); -//! } -//! ``` -//! -//! ### Chatting with Claude using Bedrock -//! -//! ```rust -//! use hiramu::bedrock::models::claude::claude_client::{ClaudeClient, ClaudeOptions}; -//! use hiramu::bedrock::models::claude::claude_request_message::{ChatOptions, ConversationRequest, Message}; -//! -//! #[tokio::main] -//! async fn main() { -//! let claude_options = ClaudeOptions::new() -//! .profile_name("bedrock") -//! .region("us-west-2"); -//! -//! let client = ClaudeClient::new(claude_options).await; -//! -//! let mut conversation_request = ConversationRequest::default(); -//! conversation_request -//! .messages -//! .push(Message::new_user_message("Hello, Claude!")); -//! -//! let chat_options = ChatOptions::default() -//! .with_temperature(0.7) -//! .with_max_tokens(100); -//! -//! let response_stream = client -//! .chat_with_stream(&conversation_request, &chat_options) -//! .await -//! .unwrap(); -//! -//! response_stream -//! .try_for_each(|chunk| async move { -//! println!("{:?}", chunk); -//! Ok(()) -//! }) -//! .await -//! .unwrap(); -//! } -//! ``` -//! -//! ### Sending Images with Claude -//! -//! ```rust -//! use hiramu::bedrock::models::claude::claude_client::{ClaudeClient, ClaudeOptions}; -//! use hiramu::bedrock::models::claude::claude_request_message::{ChatOptions, ConversationRequest, Message}; -//! use hiramu::fetch_and_base64_encode_image; -//! -//! #[tokio::main] -//! async fn main() { -//! let claude_options = ClaudeOptions::new() -//! .profile_name("bedrock") -//! .region("us-west-2"); -//! -//! let client = ClaudeClient::new(claude_options).await; -//! -//! let image_url = "./data/mario.png"; -//! let input_text = "What's in this image?".to_string(); -//! let image = fetch_and_base64_encode_image(image_url).await.unwrap().to_string(); -//! let mime_type = "image/png".to_string(); -//! -//! let message = Message::new_user_message_with_image(&input_text, &image, &mime_type); -//! -//! let mut conversation_request = ConversationRequest::default(); -//! conversation_request.messages.push(message); -//! -//! let chat_options = ChatOptions::default() -//! .with_temperature(0.7) -//! .with_max_tokens(100); -//! -//! let response_stream = client -//! .chat_with_stream(&conversation_request, &chat_options) -//! .await -//! .unwrap(); -//! -//! response_stream -//! .try_for_each(|chunk| async move { -//! println!("{:?}", chunk); -//! Ok(()) -//! }) -//! .await -//! .unwrap(); -//! } -//! ``` -//! -//! ### Using the Raw Bedrock API -//! -//! 
#### Generating a Raw Response -//! -//! ```rust -//! use hiramu::bedrock::bedrock_client::{BedrockClient, BedrockClientOptions}; -//! use hiramu::bedrock::model_info::{ModelInfo, ModelName}; -//! -//! #[tokio::main] -//! async fn main() { -//! let model_id = ModelInfo::from_model_name(ModelName::AnthropicClaudeHaiku1x); -//! let profile_name = "bedrock"; -//! let region = "us-west-2"; -//! -//! let prompt = "Hi. In a short paragraph, explain what you can do."; -//! -//! let payload = serde_json::json!({ -//! "anthropic_version": "bedrock-2023-05-31", -//! "max_tokens": 1000, -//! "messages": [{ -//! "role": "user", -//! "content": [{ -//! "type": "text", -//! "text": prompt -//! }] -//! }] -//! }); -//! -//! let options = BedrockClientOptions::new() -//! .profile_name(profile_name) -//! .region(region); -//! -//! let client = BedrockClient::new(options).await; -//! -//! let result = client -//! .generate_raw(model_id.to_string(), payload) -//! .await -//! .unwrap(); -//! -//! println!("{:?}", result); -//! } -//! ``` -//! -//! #### Generating a Raw Stream Response -//! -//! ```rust -//! use futures::TryStreamExt; -//! use hiramu::bedrock::bedrock_client::{BedrockClient, BedrockClientOptions}; -//! use hiramu::bedrock::model_info::{ModelInfo, ModelName}; -//! -//! #[tokio::main] -//! async fn main() { -//! let model_id = ModelInfo::from_model_name(ModelName::AnthropicClaudeHaiku1x); -//! let profile_name = "bedrock"; -//! let region = "us-west-2"; -//! -//! let prompt = "Hi. In a short paragraph, explain what you can do."; -//! -//! let payload = serde_json::json!({ -//! "anthropic_version": "bedrock-2023-05-31", -//! "max_tokens": 1000, -//! "messages": [{ -//! "role": "user", -//! "content": [{ -//! "type": "text", -//! "text": prompt -//! }] -//! }] -//! }); -//! -//! let options = BedrockClientOptions::new() -//! .profile_name(profile_name) -//! .region(region); -//! -//! let client = BedrockClient::new(options).await; -//! -//! let stream = client -//! .generate_raw_stream(model_id.to_string(), payload) -//! .await -//! .unwrap(); -//! -//! stream -//! .try_for_each(|chunk| async move { -//! println!("{:?}", chunk); -//! Ok(()) -//! }) -//! .await -//! .unwrap(); -//! } -//! ``` -//! -//! ## Error Handling -//! -//! Hiramu provides comprehensive error handling through the `HiramuError` enum, which covers various error scenarios such as HTTP errors, JSON parsing errors, I/O errors, and more. Each error variant provides detailed information about the cause of the error, making it easier to diagnose and handle issues. -//! -//! When an error occurs, the corresponding variant of `HiramuError` is returned, allowing you to match on the error and take appropriate action. Hiramu also integrates with the `thiserror` crate, providing convenient error propagation and formatting. -//! -//! ## Contributing -//! -//! Contributions to Hiramu are welcome! If you encounter any issues, have suggestions for improvements, or want to add new features, please open an issue or submit a pull request on the [GitHub repository](https://github.com/raphaelmansuy/hiramu). -//! -//! To contribute to the project, follow these steps: -//! -//! 1. Fork the repository and create a new branch for your changes. -//! 2. Make your modifications and ensure that the code compiles successfully. -//! 3. Write tests to cover your changes and ensure that all existing tests pass. -//! 4. Update the documentation, including the README and API docs, if necessary. -//! 5. 
Submit a pull request with a clear description of your changes and the problem they solve. -//! -//! ## License -//! -//! Hiramu is licensed under the [MIT License](LICENSE). -//! -//! ## Acknowledgements -//! -//! Hiramu is built on top of the following libraries and APIs: -//! -//! - [Ollama](https://ollama.com/) -//! - [Bedrock](https://bedrock.com/) -//! - [reqwest](https://docs.rs/reqwest) -//! - [tokio](https://tokio.rs/) -//! - [serde](https://serde.rs/) -//! -//! We would like to express our gratitude to the developers and maintainers of these projects for their excellent work and contributions to the Rust ecosystem. -//! +#![doc = include_str!("../README.md")] pub mod ollama; pub mod bedrock; pub mod error; pub mod util; - pub mod examples; pub use error::HiramuError;
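
Beyond what the README snippets show, the `MistralRequestBuilder` introduced in this patch also carries `top_p`, `top_k` and `stop`, and every optional field is tagged `skip_serializing_if = "Option::is_none"`, so unset options are simply omitted from the Bedrock payload. A minimal sketch of the full builder surface (not part of the diff above; it assumes the re-exports in `src/bedrock/models/mistral/mod.rs` and a `serde_json` dependency in the calling crate):

```rust
use hiramu::bedrock::models::mistral::MistralRequestBuilder;

fn main() {
    // Build a request with every optional parameter set explicitly.
    let request = MistralRequestBuilder::new("[INST] Write a haiku about Rust. [/INST]".to_string())
        .max_tokens(200)
        .temperature(0.5)
        .top_p(0.9)
        .top_k(50)
        .stop(vec!["[INST]".to_string()])
        .build();

    // Serializes to the same JSON shape that demo_bedrock_raw_mistral.rs builds by hand;
    // any field left as None would be dropped from the payload entirely.
    println!("{}", serde_json::to_string_pretty(&request).unwrap());
}
```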
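
The examples in the patch unwrap or panic on errors; in application code the `MistralError` enum from `src/bedrock/models/mistral/error.rs` lets failures propagate with `?`, since Bedrock, JSON, I/O and UTF-8 errors all convert into it via `#[from]`. A minimal sketch of that pattern (again not part of the diff), pointed at the newly added `ModelName::MistralLarge` entry, assuming that model is enabled for the account and region:

```rust
use hiramu::bedrock::model_info::{ModelInfo, ModelName};
use hiramu::bedrock::models::mistral::{MistralClient, MistralError, MistralOptions, MistralRequestBuilder};

async fn ask_mistral(prompt: &str) -> Result<String, MistralError> {
    let options = MistralOptions::new().profile_name("bedrock").region("us-west-2");
    let client = MistralClient::new(options).await;

    let request = MistralRequestBuilder::new(prompt.to_string())
        .max_tokens(200)
        .temperature(0.5)
        .build();

    // generate() already returns Result<_, MistralError>, so `?` is all that is needed;
    // Bedrock and serde_json failures are converted inside the client.
    let response = client
        .generate(ModelInfo::from_model_name(ModelName::MistralLarge), &request)
        .await?;

    // The model replies with a list of outputs; take the text of the first one.
    Ok(response.outputs.first().map(|o| o.text.clone()).unwrap_or_default())
}
```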