From 565505493757073b21007e606f419a6922b34983 Mon Sep 17 00:00:00 2001
From: Bastian Fredriksson <realiserad@gmail.com>
Date: Thu, 4 Jul 2024 20:32:04 +0200
Subject: [PATCH] feat: add support for mistral

Add support for LLMs hosted by Mistral using Mistral's Python client.
---
 .devcontainer/requirements-dev.txt |  1 +
 README.md                          | 13 ++++++++++++-
 bump-version.fish                  |  3 +--
 pyproject.toml                     |  3 ++-
 src/fish_ai/engine.py              | 26 ++++++++++++++++++++++++--
 5 files changed, 40 insertions(+), 6 deletions(-)

diff --git a/.devcontainer/requirements-dev.txt b/.devcontainer/requirements-dev.txt
index 6faac28..d1ae8c0 100644
--- a/.devcontainer/requirements-dev.txt
+++ b/.devcontainer/requirements-dev.txt
@@ -6,3 +6,4 @@ pytest
 pyfakefs
 iterfzf
 hugchat
+mistralai
diff --git a/README.md b/README.md
index 8947ca7..9bc2873 100644
--- a/README.md
+++ b/README.md
@@ -8,7 +8,7 @@ should run on [any system with Python and git installed](https://github.com/Real
 Originally based on [Tom Dörr's `fish.codex` repository](https://github.com/tom-doerr/codex.fish),
 but with some additional functionality.
 
-It can be hooked up to OpenAI, Azure OpenAI, Google, Hugging Face, or a
+It can be hooked up to OpenAI, Azure OpenAI, Google, Hugging Face, Mistral, or a
 self-hosted LLM behind any OpenAI-compatible API.
 
 If you like it, please add a ⭐. If you don't like it, create a PR. 😆
@@ -105,6 +105,17 @@ model = meta-llama/Meta-Llama-3-70B-Instruct
 Available models are listed [here](https://huggingface.co/chat/models).
 Note that 2FA must be disabled on the account.
 
+If you use [Mistral](https://mistral.ai):
+
+```ini
+[fish-ai]
+configuration = mistral
+
+[mistral]
+provider = mistral
+api_key = <your API key>
+```
+
 ### Install `fish-ai`
 
 Install the plugin. You can install it using [`fisher`](https://github.com/jorgebucaran/fisher).
diff --git a/bump-version.fish b/bump-version.fish
index b5a7032..342504d 100755
--- a/bump-version.fish
+++ b/bump-version.fish
@@ -6,8 +6,7 @@ if ! type -q what-bump
     exit 1
 end
 
-set current_branch (git rev-parse --abbrev-ref HEAD)
-set start_hash (git show-ref --hash refs/remotes/origin/$current_branch)
+set start_hash (git show-ref --hash refs/remotes/origin/main)
 set current_version (git show $start_hash:pyproject.toml | grep version | head -n 1 | cut -d'"' -f2)
 set next_version (what-bump --from $current_version $start_hash)
 if test "$current_version" = "$next_version"
diff --git a/pyproject.toml b/pyproject.toml
index 8fae375..86d848b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "fish_ai"
-version = "0.7.1"
+version = "0.8.0"
 authors = [{ name = "Bastian Fredriksson", email = "realiserad@gmail.com" }]
 description = "Provides core functionality for fish-ai, an AI plugin for the fish shell."
 readme = "README.md"
@@ -20,6 +20,7 @@ dependencies = [
     "simple-term-menu==1.6.4",
     "iterfzf==1.4.0.51.0",
     "hugchat==0.4.8",
+    "mistralai==0.4.2",
 ]
 
 [project.urls]
diff --git a/src/fish_ai/engine.py b/src/fish_ai/engine.py
index e7bd0e3..d6566e4 100644
--- a/src/fish_ai/engine.py
+++ b/src/fish_ai/engine.py
@@ -16,6 +16,8 @@ import sys
 from hugchat import hugchat
 from hugchat.login import Login
+from mistralai.client import MistralClient
+from mistralai.models.chat_completion import ChatMessage
 
 config = ConfigParser()
 config.read(path.expanduser('~/.config/fish-ai.ini'))
 
@@ -142,7 +144,7 @@ def get_openai_client():
                         .format(get_config('provider')))
 
 
-def create_message_history(messages):
+def get_messages_for_gemini(messages):
     """
     Create message history which can be used with Gemini. Google uses a
     different chat history format than OpenAI.
@@ -170,6 +172,15 @@
     return outputs
 
 
+def get_messages_for_mistral(messages):
+    output = []
+    for message in messages:
+        output.append(
+            ChatMessage(role=message.get('role'),
+                        content=message.get('content')))
+    return output
+
+
 def create_system_prompt(messages):
     return '\n\n'.join(
         list(
@@ -192,7 +203,7 @@ def get_response(messages):
         genai.configure(api_key=get_config('api_key'))
         model = genai.GenerativeModel(
             get_config('model') or 'gemini-1.5-flash')
-        chat = model.start_chat(history=create_message_history(messages))
+        chat = model.start_chat(history=get_messages_for_gemini(messages))
         generation_config = GenerationConfig(
             candidate_count=1,
             temperature=float(get_config('temperature') or '0.2'))
@@ -215,6 +226,17 @@
 
         response = bot.chat(messages[-1].get('content')).wait_until_done()
         bot.delete_conversation(bot.get_conversation_info())
+    elif get_config('provider') == 'mistral':
+        client = MistralClient(
+            api_key=get_config('api_key')
+        )
+        completions = client.chat(
+            model=get_config('model') or 'mistral-large-latest',
+            messages=get_messages_for_mistral(messages),
+            max_tokens=1024,
+            temperature=float(get_config('temperature') or '0.2'),
+        )
+        response = completions.choices[0].message.content.strip(' `')
     else:
         completions = get_openai_client().chat.completions.create(
             model=get_config('model') or 'gpt-4',
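
For reviewers, the new provider branch in `get_response()` reduces to the following standalone sketch. It is a minimal example assuming `mistralai==0.4.2` (the version pinned in `pyproject.toml` above); the API key and the example messages are placeholders, and only calls that appear in the diff are used.

```python
# Standalone sketch of the Mistral code path added in src/fish_ai/engine.py.
# Assumes mistralai==0.4.2; the API key and messages below are placeholders.
from mistralai.client import MistralClient
from mistralai.models.chat_completion import ChatMessage

client = MistralClient(api_key='<your API key>')

completions = client.chat(
    # Defaults chosen by the patch when no model/temperature is configured.
    model='mistral-large-latest',
    messages=[
        ChatMessage(role='system', content='You are a fish shell assistant.'),
        ChatMessage(role='user', content='List all files sorted by size.'),
    ],
    max_tokens=1024,
    temperature=0.2,
)

# The patch strips spaces and backticks so a fenced answer
# comes back as a plain shell command.
print(completions.choices[0].message.content.strip(' `'))
```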