feat: add support for mistral
Add support for LLMs hosted by Mistral using Mistral's Python client.
Realiserad committed Jul 4, 2024
1 parent 42098c7 commit 5655054
Showing 5 changed files with 40 additions and 6 deletions.
1 change: 1 addition & 0 deletions .devcontainer/requirements-dev.txt
@@ -6,3 +6,4 @@ pytest
 pyfakefs
 iterfzf
 hugchat
+mistralai
13 changes: 12 additions & 1 deletion README.md
@@ -8,7 +8,7 @@ should run on [any system with Python and git installed](https://github.com/Real
 Originally based on [Tom Dörr's `fish.codex` repository](https://github.com/tom-doerr/codex.fish),
 but with some additional functionality.
 
-It can be hooked up to OpenAI, Azure OpenAI, Google, Hugging Face, or a
+It can be hooked up to OpenAI, Azure OpenAI, Google, HuggingFace, Mistral, or a
 self-hosted LLM behind any OpenAI-compatible API.
 
 If you like it, please add a ⭐. If you don't like it, create a PR. 😆
@@ -105,6 +105,17 @@ model = meta-llama/Meta-Llama-3-70B-Instruct
 Available models are listed [here](https://huggingface.co/chat/models).
 Note that 2FA must be disabled on the account.
 
+If you use [Mistral](https://mistral.ai):
+
+```ini
+[fish-ai]
+configuration = mistral
+
+[mistral]
+provider = mistral
+api_key = <your API key>
+```
+
 ### Install `fish-ai`
 
 Install the plugin. You can install it using [`fisher`](https://github.com/jorgebucaran/fisher).
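For reference, fish-ai reads this file with Python's `ConfigParser` (see the `engine.py` changes further down). Below is a minimal sketch of how the new `[mistral]` section is resolved, assuming the same `~/.config/fish-ai.ini` location; `get_config` here is a simplified stand-in for the plugin's own helper, not its actual implementation:

```python
# Minimal sketch: resolve keys from the section selected by
# [fish-ai] -> configuration, mirroring how engine.py loads the file.
# get_config is a simplified stand-in, not the plugin's real helper.
from configparser import ConfigParser
from os import path

config = ConfigParser()
config.read(path.expanduser('~/.config/fish-ai.ini'))


def get_config(key):
    # The active section is named by the 'configuration' key.
    section = config.get('fish-ai', 'configuration')
    return config.get(section, key, fallback=None)


print(get_config('provider'))  # 'mistral'
```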
3 changes: 1 addition & 2 deletions bump-version.fish
@@ -6,8 +6,7 @@ if ! type -q what-bump
     exit 1
 end
 
-set current_branch (git rev-parse --abbrev-ref HEAD)
-set start_hash (git show-ref --hash refs/remotes/origin/$current_branch)
+set start_hash (git show-ref --hash refs/remotes/origin/main)
 set current_version (git show $start_hash:pyproject.toml | grep version | head -n 1 | cut -d'"' -f2)
 set next_version (what-bump --from $current_version $start_hash)
 if test "$current_version" = "$next_version"
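With this change, `start_hash` always points at `origin/main`, so the bump is computed from commits since the mainline tip rather than the current branch's remote. For illustration only, here is the same version lookup translated to Python (a sketch assuming `git` is on the PATH and the `origin/main` ref exists locally; `what-bump` remains an external tool):

```python
# Sketch: read version = "..." from pyproject.toml as committed on
# origin/main, mirroring the fish pipeline (grep | head | cut) above.
import re
import subprocess

start_hash = subprocess.check_output(
    ['git', 'show-ref', '--hash', 'refs/remotes/origin/main'],
    text=True).strip()
pyproject = subprocess.check_output(
    ['git', 'show', f'{start_hash}:pyproject.toml'], text=True)
current_version = re.search(
    r'version\s*=\s*"([^"]+)"', pyproject).group(1)
print(current_version)  # e.g. 0.8.0
```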
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "fish_ai"
-version = "0.7.1"
+version = "0.8.0"
 authors = [{ name = "Bastian Fredriksson", email = "realiserad@gmail.com" }]
 description = "Provides core functionality for fish-ai, an AI plugin for the fish shell."
 readme = "README.md"
@@ -20,6 +20,7 @@ dependencies = [
     "simple-term-menu==1.6.4",
     "iterfzf==1.4.0.51.0",
     "hugchat==0.4.8",
+    "mistralai==0.4.2",
 ]
 
 [project.urls]
26 changes: 24 additions & 2 deletions src/fish_ai/engine.py
@@ -16,6 +16,8 @@
 import sys
 from hugchat import hugchat
 from hugchat.login import Login
+from mistralai.client import MistralClient
+from mistralai.models.chat_completion import ChatMessage
 
 config = ConfigParser()
 config.read(path.expanduser('~/.config/fish-ai.ini'))
@@ -142,7 +144,7 @@ def get_openai_client():
         .format(get_config('provider')))
 
 
-def create_message_history(messages):
+def get_messages_for_gemini(messages):
     """
     Create message history which can be used with Gemini.
     Google uses a different chat history format than OpenAI.
@@ -170,6 +172,15 @@ def create_message_history(messages):
     return outputs
 
 
+def get_messages_for_mistral(messages):
+    output = []
+    for message in messages:
+        output.append(
+            ChatMessage(role=message.get('role'),
+                        content=message.get('content')))
+    return output
+
+
 def create_system_prompt(messages):
     return '\n\n'.join(
         list(
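What the new helper produces, in isolation (a sketch assuming `mistralai==0.4.2`, the version pinned in `pyproject.toml` above):

```python
# Sketch: OpenAI-style message dicts become mistralai ChatMessage
# objects, which is exactly what get_messages_for_mistral() does.
from mistralai.models.chat_completion import ChatMessage

messages = [
    {'role': 'system', 'content': 'You are a fish shell assistant.'},
    {'role': 'user', 'content': 'Show disk usage for the home folder.'},
]
converted = [ChatMessage(role=m.get('role'), content=m.get('content'))
             for m in messages]
print(converted[1].role, converted[1].content)
```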
@@ -192,7 +203,7 @@ def get_response(messages):
         genai.configure(api_key=get_config('api_key'))
         model = genai.GenerativeModel(
             get_config('model') or 'gemini-1.5-flash')
-        chat = model.start_chat(history=create_message_history(messages))
+        chat = model.start_chat(history=get_messages_for_gemini(messages))
         generation_config = GenerationConfig(
             candidate_count=1,
             temperature=float(get_config('temperature') or '0.2'))
@@ -215,6 +226,17 @@ def get_response(messages):
 
         response = bot.chat(messages[-1].get('content')).wait_until_done()
         bot.delete_conversation(bot.get_conversation_info())
+    elif get_config('provider') == 'mistral':
+        client = MistralClient(
+            api_key=get_config('api_key')
+        )
+        completions = client.chat(
+            model=get_config('model') or 'mistral-large-latest',
+            messages=get_messages_for_mistral(messages),
+            max_tokens=1024,
+            temperature=float(get_config('temperature') or '0.2'),
+        )
+        response = completions.choices[0].message.content.strip(' `')
     else:
         completions = get_openai_client().chat.completions.create(
             model=get_config('model') or 'gpt-4',
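Taken together, the new branch amounts to the request below (a sketch with the API key and model hard-coded; the real code reads both from `fish-ai.ini`, falling back to `mistral-large-latest`, and the final `strip` call trims the backticks models often wrap commands in):

```python
# Sketch of the new Mistral path in get_response(), assuming
# mistralai==0.4.2. Key and model are hard-coded here for brevity;
# engine.py pulls them from fish-ai.ini instead.
from mistralai.client import MistralClient
from mistralai.models.chat_completion import ChatMessage

client = MistralClient(api_key='<your API key>')
completions = client.chat(
    model='mistral-large-latest',
    messages=[ChatMessage(role='user', content='Print the current date.')],
    max_tokens=1024,
    temperature=0.2,
)
# strip(' `') removes stray spaces/backticks around the returned command.
response = completions.choices[0].message.content.strip(' `')
print(response)
```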
