Skip to content

Commit

Permalink
Add a message, role class
Browse files Browse the repository at this point in the history
  • Loading branch information
EricLBuehler committed Apr 21, 2024
1 parent ce76eb9 commit 10c4742
Show file tree
Hide file tree
Showing 9 changed files with 75 additions and 55 deletions.
18 changes: 6 additions & 12 deletions examples/python/cookbook.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
"metadata": {},
"outputs": [],
"source": [
"from mistralrs import Runner, Which, ChatCompletionRequest\n",
"from mistralrs import Runner, Which, ChatCompletionRequest, Message, Role\n",
"\n",
"runner = Runner(\n",
" which=Which.MistralGGUF(\n",
Expand All @@ -28,9 +28,7 @@
"res = runner.send_chat_completion_request(\n",
" ChatCompletionRequest(\n",
" model=\"mistral\",\n",
" messages=[\n",
" {\"role\": \"user\", \"content\": \"Tell me a story about the Rust type system.\"}\n",
" ],\n",
" messages=[Message(Role.User, \"Tell me a story about the Rust type system.\")],\n",
" max_tokens=256,\n",
" presence_penalty=1.0,\n",
" top_p=0.1,\n",
Expand All @@ -46,7 +44,7 @@
"source": [
"Lets walk through this code.\n",
"```python\n",
"from mistralrs import Runner, Which, ChatCompletionRequest\n",
"from mistralrs import Runner, Which, ChatCompletionRequest, Message, Role\n",
"```\n",
"\n",
"This imports the required classes for our example. The `Runner` is a class which handles loading and running the model; the available models are enumerated by the `Which` class.\n",
Expand All @@ -69,9 +67,7 @@
"res = runner.send_chat_completion_request(\n",
" ChatCompletionRequest(\n",
" model=\"mistral\",\n",
" messages=[\n",
" {\"role\": \"user\", \"content\": \"Tell me a story about the Rust type system.\"}\n",
" ],\n",
" messages=[Message(Role.User, \"Tell me a story about the Rust type system.\")],\n",
" max_tokens=256,\n",
" presence_penalty=1.0,\n",
" top_p=0.1,\n",
Expand Down Expand Up @@ -176,14 +172,12 @@
"metadata": {},
"outputs": [],
"source": [
"from mistralrs import ChatCompletionRequest\n",
"from mistralrs import ChatCompletionRequest, Message, Role\n",
"\n",
"res = runner.send_chat_completion_request(\n",
" ChatCompletionRequest(\n",
" model=\"mistral\",\n",
" messages=[\n",
" {\"role\": \"user\", \"content\": \"Tell me a story about the Rust type system.\"}\n",
" ],\n",
" messages=[Message(Role.User, \"Tell me a story about the Rust type system.\")],\n",
" max_tokens=256,\n",
" presence_penalty=1.0,\n",
" top_p=0.1,\n",
Expand Down
6 changes: 2 additions & 4 deletions examples/python/python_api.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from mistralrs import Runner, Which, ChatCompletionRequest
from mistralrs import Runner, Which, ChatCompletionRequest, Message, Role

runner = Runner(
which=Which.MistralGGUF(
Expand All @@ -13,9 +13,7 @@
res = runner.send_chat_completion_request(
ChatCompletionRequest(
model="mistral",
messages=[
{"role": "user", "content": "Tell me a story about the Rust type system."}
],
messages=[Message(Role.User, "Tell me a story about the Rust type system.")],
max_tokens=256,
presence_penalty=1.0,
top_p=0.1,
Expand Down
6 changes: 2 additions & 4 deletions examples/python/streaming.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from mistralrs import Runner, Which, ChatCompletionRequest
from mistralrs import Runner, Which, ChatCompletionRequest, Message, Role

runner = Runner(
which=Which.MistralGGUF(
Expand All @@ -13,9 +13,7 @@
res = runner.send_chat_completion_request(
ChatCompletionRequest(
model="mistral",
messages=[
{"role": "user", "content": "Tell me a story about the Rust type system."}
],
messages=[Message(Role.User, "Tell me a story about the Rust type system.")],
max_tokens=256,
presence_penalty=1.0,
top_p=0.1,
Expand Down
4 changes: 2 additions & 2 deletions examples/python/xlora_gemma.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from mistralrs import Runner, Which, ChatCompletionRequest
from mistralrs import Runner, Which, ChatCompletionRequest, Message, Role

runner = Runner(
which=Which.XLoraGemma(
Expand All @@ -14,7 +14,7 @@
res = runner.send_chat_completion_request(
ChatCompletionRequest(
model="mistral",
messages=[{"role": "user", "content": "What is graphene?"}],
messages=[Message(Role.User, "Tell me a story about the Rust type system.")],
max_tokens=256,
presence_penalty=1.0,
top_p=0.1,
Expand Down
4 changes: 2 additions & 2 deletions examples/python/xlora_zephyr.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from mistralrs import Runner, Which, ChatCompletionRequest
from mistralrs import Runner, Which, ChatCompletionRequest, Message, Role

runner = Runner(
which=Which.XLoraMistralGGUF(
Expand All @@ -16,7 +16,7 @@
res = runner.send_chat_completion_request(
ChatCompletionRequest(
model="mistral",
messages=[{"role": "user", "content": "What is graphene?"}],
messages=[Message(Role.User, "Tell me a story about the Rust type system.")],
max_tokens=256,
presence_penalty=1.0,
top_p=0.1,
Expand Down
6 changes: 2 additions & 4 deletions mistralrs-pyo3/API.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ Request is a class with a constructor which accepts the following arguments. It

## Example
```python
from mistralrs import Runner, Which, ChatCompletionRequest
from mistralrs import Runner, Which, ChatCompletionRequest, Message, Role

runner = Runner(
which=Which.MistralGGUF(
Expand All @@ -69,9 +69,7 @@ runner = Runner(
res = runner.send_chat_completion_request(
ChatCompletionRequest(
model="mistral",
messages=[
{"role": "user", "content": "Tell me a story about the Rust type system."}
],
messages=[Message(Role.User, "Tell me a story about the Rust type system.")],
max_tokens=256,
presence_penalty=1.0,
top_p=0.1,
Expand Down
16 changes: 12 additions & 4 deletions mistralrs-pyo3/mistralrs.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ class ChatCompletionRequest:
about input data, sampling, and how to return the response.
"""

messages: list[dict[str, str]] | str
messages: list[Message] | str
model: str
logit_bias: dict[int, float] | None = None
logprobs: bool = False
Expand Down Expand Up @@ -66,13 +66,13 @@ class _Quantized(_Base):
quantized_filename: str

@dataclass
class _XLoraQuantized(_Base, _Quantized):
class _XLoraQuantized(_Quantized):
xlora_model_id: str
order: str
tgt_non_granular_index: int | None

@dataclass
class _XLoraNormal(_Base, _Normal):
class _XLoraNormal(_Normal):
xlora_model_id: str
order: str
tgt_non_granular_index: int | None
Expand Down Expand Up @@ -138,7 +138,7 @@ class Runner:
prefix_cache_n: int = 16,
token_source="cache",
chat_template=None,
) -> Runner:
) -> None:
"""
Load a model.
Expand Down Expand Up @@ -166,3 +166,11 @@ class Runner:
"""
Send a chat completion request to the mistral.rs engine, returning the response object.
"""

class Role(Enum):
User = 1
Assistant = 2

class Message:
role: Role
content: str
46 changes: 23 additions & 23 deletions mistralrs-pyo3/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
use candle_core::Result;
use either::Either;
use indexmap::IndexMap;
use message::{Message, Role};
use std::{
cell::RefCell,
collections::HashMap,
Expand All @@ -23,12 +24,13 @@ use mistralrs_core::{
use pyo3::{
exceptions::{PyTypeError, PyValueError},
prelude::*,
types::{PyDict, PyList, PyString},
types::{PyList, PyString},
};
use std::fs::File;
mod stream;
mod which;
use which::Which;
mod message;

#[cfg(not(feature = "metal"))]
static CUDA_DEVICE: std::sync::Mutex<Option<Device>> = std::sync::Mutex::new(None);
Expand Down Expand Up @@ -882,7 +884,20 @@ impl Runner {
last_v
},
messages: match request.messages {
Either::Left(ref messages) => RequestMessage::Chat(messages.clone()),
Either::Left(ref messages) => {
let mut messages_vec = Vec::new();
for message in messages {
let mut message_map = IndexMap::new();
let role = match message.role {
Role::Assistant => "assistant",
Role::User => "user",
};
message_map.insert("role".to_string(), role.to_string());
message_map.insert("content".to_string(), message.content.clone());
messages_vec.push(message_map);
}
RequestMessage::Chat(messages_vec)
}
Either::Right(ref prompt) => {
let mut messages = Vec::new();
let mut message_map = IndexMap::new();
Expand Down Expand Up @@ -1106,7 +1121,7 @@ impl CompletionRequest {
#[derive(Debug)]
/// An OpenAI API compatible chat completion request.
struct ChatCompletionRequest {
messages: Either<Vec<IndexMap<String, String>>, String>,
messages: Either<Vec<Message>, String>,
_model: String,
logit_bias: Option<HashMap<u32, f32>>,
logprobs: bool,
Expand Down Expand Up @@ -1167,29 +1182,12 @@ impl ChatCompletionRequest {
if let Ok(messages) = messages.bind(py).downcast_exact::<PyList>() {
let mut messages_vec = Vec::new();
for message in messages {
let mapping = message.downcast::<PyDict>()?.as_mapping();
let mut messages_map = IndexMap::new();
for i in 0..mapping.len()? {
let k = mapping
.keys()?
.get_item(i)?
.downcast::<PyString>()?
.extract::<String>()?;
let v = mapping
.values()?
.get_item(i)?
.downcast::<PyString>()?
.extract::<String>()?;
messages_map.insert(k, v);
}
messages_vec.push(messages_map);
messages_vec.push(message.extract::<Message>()?);
}
Ok::<Either<Vec<IndexMap<String, String>>, String>, PyErr>(Either::Left(
messages_vec,
))
Ok::<Either<Vec<Message>, String>, PyErr>(Either::Left(messages_vec))
} else if let Ok(messages) = messages.bind(py).downcast_exact::<PyString>() {
let prompt = messages.extract::<String>()?;
Ok::<Either<Vec<IndexMap<String, String>>, String>, PyErr>(Either::Right(prompt))
Ok::<Either<Vec<Message>, String>, PyErr>(Either::Right(prompt))
} else {
return Err(PyTypeError::new_err("Expected a string or list of dicts."));
}
Expand Down Expand Up @@ -1221,5 +1219,7 @@ fn mistralrs(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<Which>()?;
m.add_class::<ChatCompletionRequest>()?;
m.add_class::<CompletionRequest>()?;
m.add_class::<Message>()?;
m.add_class::<Role>()?;
Ok(())
}
24 changes: 24 additions & 0 deletions mistralrs-pyo3/src/message.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
use pyo3::{pyclass, pymethods};

/// Author role of a chat message, exposed to Python as `mistralrs.Role`
/// (used when constructing `Message` values for a chat completion request).
#[pyclass]
#[derive(Clone, Debug)]
pub enum Role {
    /// A message written by the end user.
    User,
    /// A message written by the model/assistant.
    Assistant,
}

/// A single chat message, exposed to Python as `mistralrs.Message`.
/// Pairs a `Role` with the message text; lists of these are accepted by
/// `ChatCompletionRequest(messages=...)`.
#[pyclass]
#[derive(Clone, Debug)]
pub struct Message {
    // Who authored the message (user or assistant).
    pub role: Role,
    // The message text content.
    pub content: String,
}

#[pymethods]
impl Message {
    /// Python constructor: `Message(role, content)`.
    ///
    /// `role` is a `Role` enum value; `content` is the message text.
    #[new]
    #[pyo3(signature = (role, content))]
    fn new(role: Role, content: String) -> Self {
        Self { role, content }
    }
}

0 comments on commit 10c4742

Please sign in to comment.