Merge pull request #95 from mistralai/release/v0.3.0
release 0.3.0: add support for completion
jean-malo authored May 29, 2024
2 parents 32ec8b6 + aa010d2 commit 9f6e920
Showing 12 changed files with 442 additions and 27 deletions.
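The headline change in this release is fill-in-the-middle (FIM) code completion: both MistralClient and MistralAsyncClient gain completion and completion_stream methods that POST to v1/fim/completions. A minimal synchronous sketch, condensed from the examples added in this commit (it assumes the MISTRAL_API_KEY environment variable is set):

import os

from mistralai.client import MistralClient

client = MistralClient(api_key=os.environ["MISTRAL_API_KEY"])

# Complete the body of a function, constrained by a known suffix (fill-in-the-middle).
prompt = "def fibonacci(n: int):"
suffix = "n = int(input('Enter a number: '))\nprint(fibonacci(n))"

response = client.completion(
    model="codestral-latest",
    prompt=prompt,
    suffix=suffix,
)
print(response.choices[0].message.content)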
33 changes: 33 additions & 0 deletions examples/async_completion.py
@@ -0,0 +1,33 @@
#!/usr/bin/env python

import asyncio
import os

from mistralai.async_client import MistralAsyncClient


async def main():
    api_key = os.environ["MISTRAL_API_KEY"]

    client = MistralAsyncClient(api_key=api_key)

    prompt = "def fibonacci(n: int):"
    suffix = "n = int(input('Enter a number: '))\nprint(fibonacci(n))"

    response = await client.completion(
        model="codestral-latest",
        prompt=prompt,
        suffix=suffix,
    )

    print(
        f"""
{prompt}
{response.choices[0].message.content}
{suffix}
"""
    )


if __name__ == "__main__":
    asyncio.run(main())
9 changes: 5 additions & 4 deletions examples/chatbot_with_streaming.py
@@ -12,11 +12,12 @@
from mistralai.models.chat_completion import ChatMessage

MODEL_LIST = [
-    "mistral-tiny",
-    "mistral-small",
-    "mistral-medium",
+    "mistral-small-latest",
+    "mistral-medium-latest",
+    "mistral-large-latest",
+    "codestral-latest",
]
-DEFAULT_MODEL = "mistral-small"
+DEFAULT_MODEL = "mistral-small-latest"
DEFAULT_TEMPERATURE = 0.7
LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"
# A dictionary of all commands and their arguments, used for tab completion.
33 changes: 33 additions & 0 deletions examples/code_completion.py
@@ -0,0 +1,33 @@
#!/usr/bin/env python

import asyncio
import os

from mistralai.client import MistralClient


async def main():
    api_key = os.environ["MISTRAL_API_KEY"]

    client = MistralClient(api_key=api_key)

    prompt = "def fibonacci(n: int):"
    suffix = "n = int(input('Enter a number: '))\nprint(fibonacci(n))"

    response = client.completion(
        model="codestral-latest",
        prompt=prompt,
        suffix=suffix,
    )

    print(
        f"""
{prompt}
{response.choices[0].message.content}
{suffix}
"""
    )


if __name__ == "__main__":
    asyncio.run(main())
29 changes: 29 additions & 0 deletions examples/completion_with_streaming.py
@@ -0,0 +1,29 @@
#!/usr/bin/env python

import asyncio
import os

from mistralai.client import MistralClient


async def main():
    api_key = os.environ["MISTRAL_API_KEY"]

    client = MistralClient(api_key=api_key)

    prompt = "def fibonacci(n: int):"
    suffix = "n = int(input('Enter a number: '))\nprint(fibonacci(n))"

    print(prompt)
    for chunk in client.completion_stream(
        model="codestral-latest",
        prompt=prompt,
        suffix=suffix,
    ):
        if chunk.choices[0].delta.content is not None:
            print(chunk.choices[0].delta.content, end="")
    print(suffix)


if __name__ == "__main__":
    asyncio.run(main())
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "mistralai"
-version = "0.2.0"
+version = "0.3.0"
description = ""
authors = ["Bam4d <bam4d@mistral.ai>"]
readme = "README.md"
73 changes: 72 additions & 1 deletion src/mistralai/async_client.py
@@ -92,7 +92,7 @@ async def _check_response(self, response: Response) -> Dict[str, Any]:
    async def _request(
        self,
        method: str,
-        json: Dict[str, Any],
+        json: Optional[Dict[str, Any]],
        path: str,
        stream: bool = False,
        attempt: int = 1,
@@ -291,3 +291,74 @@ async def list_models(self) -> ModelList:
            return ModelList(**response)

        raise MistralException("No response received")

    async def completion(
        self,
        model: str,
        prompt: str,
        suffix: Optional[str] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        top_p: Optional[float] = None,
        random_seed: Optional[int] = None,
        stop: Optional[List[str]] = None,
    ) -> ChatCompletionResponse:
        """An asynchronous completion endpoint that returns a single response.
        Args:
            model (str): the name of the model to get completions with, e.g. codestral-latest
            prompt (str): the prompt to complete
            suffix (Optional[str]): the suffix to append to the prompt for fill-in-the-middle completion
            temperature (Optional[float], optional): the temperature to use for sampling, e.g. 0.5. Defaults to None.
            max_tokens (Optional[int], optional): the maximum number of tokens to generate, e.g. 100. Defaults to None.
            top_p (Optional[float], optional): the cumulative probability of tokens to generate, e.g. 0.9.
                Defaults to None.
            random_seed (Optional[int], optional): the random seed to use for sampling, e.g. 42. Defaults to None.
            stop (Optional[List[str]], optional): a list of tokens to stop generation at, e.g. ['\n\n']. Defaults to None.
        Returns:
            ChatCompletionResponse: a response object containing the generated text.
        """
        request = self._make_completion_request(
            prompt, model, suffix, temperature, max_tokens, top_p, random_seed, stop
        )
        single_response = self._request("post", request, "v1/fim/completions")

        async for response in single_response:
            return ChatCompletionResponse(**response)

        raise MistralException("No response received")

    async def completion_stream(
        self,
        model: str,
        prompt: str,
        suffix: Optional[str] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        top_p: Optional[float] = None,
        random_seed: Optional[int] = None,
        stop: Optional[List[str]] = None,
    ) -> AsyncGenerator[ChatCompletionStreamResponse, None]:
        """An asynchronous completion endpoint that returns a streaming response.
        Args:
            model (str): the name of the model to get completions with, e.g. codestral-latest
            prompt (str): the prompt to complete
            suffix (Optional[str]): the suffix to append to the prompt for fill-in-the-middle completion
            temperature (Optional[float], optional): the temperature to use for sampling, e.g. 0.5. Defaults to None.
            max_tokens (Optional[int], optional): the maximum number of tokens to generate, e.g. 100. Defaults to None.
            top_p (Optional[float], optional): the cumulative probability of tokens to generate, e.g. 0.9.
                Defaults to None.
            random_seed (Optional[int], optional): the random seed to use for sampling, e.g. 42. Defaults to None.
            stop (Optional[List[str]], optional): a list of tokens to stop generation at, e.g. ['\n\n']. Defaults to None.
        Returns:
            AsyncGenerator[ChatCompletionStreamResponse, None]: an async generator that yields response objects
                containing the generated text.
        """
        request = self._make_completion_request(
            prompt, model, suffix, temperature, max_tokens, top_p, random_seed, stop, stream=True
        )
        async_response = self._request("post", request, "v1/fim/completions", stream=True)

        async for json_response in async_response:
            yield ChatCompletionStreamResponse(**json_response)
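For reference, a minimal sketch of consuming the new asynchronous streaming endpoint; it mirrors the synchronous streaming example added above, and the prompt/suffix values are illustrative only:

import asyncio
import os

from mistralai.async_client import MistralAsyncClient


async def main():
    client = MistralAsyncClient(api_key=os.environ["MISTRAL_API_KEY"])

    # completion_stream is an async generator, so it is consumed with `async for`.
    async for chunk in client.completion_stream(
        model="codestral-latest",
        prompt="def fibonacci(n: int):",
        suffix="print(fibonacci(10))",
    ):
        if chunk.choices[0].delta.content is not None:
            print(chunk.choices[0].delta.content, end="")


if __name__ == "__main__":
    asyncio.run(main())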
76 changes: 75 additions & 1 deletion src/mistralai/client.py
@@ -85,7 +85,7 @@ def _check_response(self, response: Response) -> Dict[str, Any]:
    def _request(
        self,
        method: str,
-        json: Dict[str, Any],
+        json: Optional[Dict[str, Any]],
        path: str,
        stream: bool = False,
        attempt: int = 1,
@@ -285,3 +285,77 @@ def list_models(self) -> ModelList:
            return ModelList(**response)

        raise MistralException("No response received")

    def completion(
        self,
        model: str,
        prompt: str,
        suffix: Optional[str] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        top_p: Optional[float] = None,
        random_seed: Optional[int] = None,
        stop: Optional[List[str]] = None,
    ) -> ChatCompletionResponse:
        """A completion endpoint that returns a single response.
        Args:
            model (str): the name of the model to get completions with, e.g. codestral-latest
            prompt (str): the prompt to complete
            suffix (Optional[str]): the suffix to append to the prompt for fill-in-the-middle completion
            temperature (Optional[float], optional): the temperature to use for sampling, e.g. 0.5. Defaults to None.
            max_tokens (Optional[int], optional): the maximum number of tokens to generate, e.g. 100. Defaults to None.
            top_p (Optional[float], optional): the cumulative probability of tokens to generate, e.g. 0.9.
                Defaults to None.
            random_seed (Optional[int], optional): the random seed to use for sampling, e.g. 42. Defaults to None.
            stop (Optional[List[str]], optional): a list of tokens to stop generation at, e.g. ['\n\n']. Defaults to None.
        Returns:
            ChatCompletionResponse: a response object containing the generated text.
        """
        request = self._make_completion_request(
            prompt, model, suffix, temperature, max_tokens, top_p, random_seed, stop
        )

        single_response = self._request("post", request, "v1/fim/completions", stream=False)

        for response in single_response:
            return ChatCompletionResponse(**response)

        raise MistralException("No response received")

    def completion_stream(
        self,
        model: str,
        prompt: str,
        suffix: Optional[str] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        top_p: Optional[float] = None,
        random_seed: Optional[int] = None,
        stop: Optional[List[str]] = None,
    ) -> Iterable[ChatCompletionStreamResponse]:
        """A completion endpoint that returns a streaming response.
        Args:
            model (str): the name of the model to get completions with, e.g. codestral-latest
            prompt (str): the prompt to complete
            suffix (Optional[str]): the suffix to append to the prompt for fill-in-the-middle completion
            temperature (Optional[float], optional): the temperature to use for sampling, e.g. 0.5. Defaults to None.
            max_tokens (Optional[int], optional): the maximum number of tokens to generate, e.g. 100. Defaults to None.
            top_p (Optional[float], optional): the cumulative probability of tokens to generate, e.g. 0.9.
                Defaults to None.
            random_seed (Optional[int], optional): the random seed to use for sampling, e.g. 42. Defaults to None.
            stop (Optional[List[str]], optional): a list of tokens to stop generation at, e.g. ['\n\n']. Defaults to None.
        Returns:
            Iterable[ChatCompletionStreamResponse]: a generator that yields response objects containing the
                generated text.
        """
        request = self._make_completion_request(
            prompt, model, suffix, temperature, max_tokens, top_p, random_seed, stop, stream=True
        )

        response = self._request("post", request, "v1/fim/completions", stream=True)

        for json_streamed_response in response:
            yield ChatCompletionStreamResponse(**json_streamed_response)
71 changes: 63 additions & 8 deletions src/mistralai/client_base.py
@@ -73,6 +73,63 @@ def _parse_messages(self, messages: List[Any]) -> List[Dict[str, Any]]:

        return parsed_messages

    def _make_completion_request(
        self,
        prompt: str,
        model: Optional[str] = None,
        suffix: Optional[str] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        top_p: Optional[float] = None,
        random_seed: Optional[int] = None,
        stop: Optional[List[str]] = None,
        stream: Optional[bool] = False,
    ) -> Dict[str, Any]:
        request_data: Dict[str, Any] = {
            "prompt": prompt,
            "suffix": suffix,
            "model": model,
            "stream": stream,
        }

        if stop is not None:
            request_data["stop"] = stop

        if model is not None:
            request_data["model"] = model
        else:
            if self._default_model is None:
                raise MistralException(message="model must be provided")
            request_data["model"] = self._default_model

        request_data.update(
            self._build_sampling_params(
                temperature=temperature, max_tokens=max_tokens, top_p=top_p, random_seed=random_seed
            )
        )

        self._logger.debug(f"Completion request: {request_data}")

        return request_data

    def _build_sampling_params(
        self,
        max_tokens: Optional[int],
        random_seed: Optional[int],
        temperature: Optional[float],
        top_p: Optional[float],
    ) -> Dict[str, Any]:
        params = {}
        if temperature is not None:
            params["temperature"] = temperature
        if max_tokens is not None:
            params["max_tokens"] = max_tokens
        if top_p is not None:
            params["top_p"] = top_p
        if random_seed is not None:
            params["random_seed"] = random_seed
        return params

    def _make_chat_request(
        self,
        messages: List[Any],
@@ -99,16 +156,14 @@ def _make_chat_request(
                raise MistralException(message="model must be provided")
            request_data["model"] = self._default_model

+        request_data.update(
+            self._build_sampling_params(
+                temperature=temperature, max_tokens=max_tokens, top_p=top_p, random_seed=random_seed
+            )
+        )
+
        if tools is not None:
            request_data["tools"] = self._parse_tools(tools)
-        if temperature is not None:
-            request_data["temperature"] = temperature
-        if max_tokens is not None:
-            request_data["max_tokens"] = max_tokens
-        if top_p is not None:
-            request_data["top_p"] = top_p
-        if random_seed is not None:
-            request_data["random_seed"] = random_seed
        if stream is not None:
            request_data["stream"] = stream

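For reference, _make_completion_request (added in src/mistralai/client_base.py above) assembles a payload along these lines before it is POSTed to v1/fim/completions; the values below are illustrative, and stop plus the sampling keys (temperature, max_tokens, top_p, random_seed) are only included when the corresponding arguments are not None:

request_data = {
    "prompt": "def fibonacci(n: int):",
    "suffix": "n = int(input('Enter a number: '))\nprint(fibonacci(n))",
    "model": "codestral-latest",
    "stream": False,
    # optional fields, included only when provided:
    "stop": ["\n\n"],
    "temperature": 0.5,
    "max_tokens": 100,
    "top_p": 0.9,
    "random_seed": 42,
}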