Commit

feat: support for gemini

Realiserad committed Apr 6, 2024
1 parent ac61007 commit 67c4754
Showing 3 changed files with 63 additions and 14 deletions.
15 changes: 13 additions & 2 deletions README.md
@@ -6,8 +6,8 @@

 Originally based on [Tom Dörr's `fish.codex` repository](https://github.com/tom-doerr/codex.fish),
 but with some additional functionality. It uses the [chat completions API endpoint](https://platform.openai.com/docs/api-reference/chat/create)
-and can be hooked up to OpenAI, Azure OpenAI or a self-hosted LLM behind any
-OpenAI-compatible API.
+and can be hooked up to Google, OpenAI, Azure OpenAI or a self-hosted LLM
+behind any OpenAI-compatible API.
 
 Continuous integration is performed against Azure OpenAI.

@@ -60,6 +60,17 @@ model = <your deployment name>
 api_key = <your API key>
 ```
 
+If you use [Gemini](https://deepmind.google/technologies/gemini):
+
+```ini
+[fish-ai]
+configuration = gemini
+
+[gemini]
+provider = google
+api_key = <your API key>
+```
+
 ### Install `fish-ai`
 
 Install the plugin. You can install it using [`fisher`](https://github.com/jorgebucaran/fisher).
4 changes: 2 additions & 2 deletions conf.d/fish_ai.fish
@@ -24,11 +24,11 @@ end
 ##
 function _fish_ai_install --on-event fish_ai_install
     python3 -m venv ~/.fish-ai
-    ~/.fish-ai/bin/pip install -qq openai
+    ~/.fish-ai/bin/pip install -qq openai google-generativeai
 end
 
 function _fish_ai_update --on-event fish_ai_update
-    ~/.fish-ai/bin/pip install -qq --upgrade openai
+    ~/.fish-ai/bin/pip install -qq --upgrade openai google-generativeai
 end
 
 function __fish_ai_uninstall --on-event fish_ai_uninstall
58 changes: 48 additions & 10 deletions functions/_fish_ai_engine.py
@@ -2,6 +2,7 @@

 from openai import OpenAI
 from openai import AzureOpenAI
+import google.generativeai as genai
 from configparser import ConfigParser
 from os import path
 import logging
@@ -49,7 +50,7 @@ def get_config(key):
     return config.get(section=active_section, option=key)
 
 
-def get_client():
+def get_openai_client():
     if (get_config('provider') == 'azure'):
         return AzureOpenAI(
             azure_endpoint=get_config('server'),
@@ -72,18 +73,55 @@ def get_client():
         .format(get_config('provider')))
 
 
+def create_message_history(messages):
+    """
+    Create message history which can be used with Gemini.
+    Google uses a different chat history format than OpenAI.
+    The message content should be put in a parts array and
+    system messages are not supported.
+    """
+    outputs = []
+    system_messages = []
+    for message in messages:
+        if message.get('role') == 'system':
+            system_messages.append(message.get('content'))
+    for i in range(len(messages) - 1):
+        message = messages[i]
+        if message.get('role') == 'user':
+            outputs.append({
+                'role': 'user',
+                'parts': system_messages + [message.get('content')] if i == 0
+                else [message.get('content')]
+            })
+        elif message.get('role') == 'assistant':
+            outputs.append({
+                'role': 'model',
+                'parts': [message.get('content')]
+            })
+    return outputs
+
+
 def get_response(messages):
     start_time = time_ns()
-    completions = get_client().chat.completions.create(
-        model=get_config('model'),
-        max_tokens=4096,
-        messages=messages,
-        stream=False,
-        temperature=float(get_config('temperature') or '0.2'),
-        n=1,
-    )
 
+    if get_config('provider') == 'google':
+        genai.configure(api_key=get_config('api_key'))
+        model = genai.GenerativeModel(get_config('model') or 'gemini-pro')
+        chat = model.start_chat(history=create_message_history(messages))
+        response = (chat.send_message(messages[-1].get('content'))
+                    .text.strip(' `'))
+    else:
+        completions = get_openai_client().chat.completions.create(
+            model=get_config('model'),
+            max_tokens=4096,
+            messages=messages,
+            stream=False,
+            temperature=float(get_config('temperature') or '0.2'),
+            n=1,
+        )
+        response = completions.choices[0].message.content.strip(' `')
+
     end_time = time_ns()
-    response = completions.choices[0].message.content.strip(' `')
     get_logger().debug('Response received from backend: ' + response)
     get_logger().debug('Processing time: ' +
                        str(round((end_time - start_time) / 1000000)) + ' ms.')
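
The history conversion in this commit can be sketched as a standalone function. The function name and the sample messages below are illustrative, not part of the plugin, and the sketch folds the system prompts into the first user turn wherever it occurs, a slight generalisation of the `i == 0` check in the committed code:

```python
# Standalone sketch of the OpenAI-to-Gemini chat history conversion.
# The name to_gemini_history and the sample messages are illustrative.

def to_gemini_history(messages):
    """Convert OpenAI-style chat messages into Gemini's history format.

    Gemini wraps content in a 'parts' list, names the assistant role
    'model', and has no 'system' role, so system prompts are folded
    into the first user turn. The last message is excluded because it
    is sent separately via chat.send_message().
    """
    system = [m['content'] for m in messages if m['role'] == 'system']
    history = []
    for message in messages[:-1]:
        if message['role'] == 'user':
            parts = system + [message['content']]
            system = []  # only the first user turn carries system prompts
            history.append({'role': 'user', 'parts': parts})
        elif message['role'] == 'assistant':
            history.append({'role': 'model', 'parts': [message['content']]})
    return history


messages = [
    {'role': 'system', 'content': 'Reply with a fish command only.'},
    {'role': 'user', 'content': 'list files'},
    {'role': 'assistant', 'content': 'ls'},
    {'role': 'user', 'content': 'include hidden files'},
]
print(to_gemini_history(messages))
```

The final user message ('include hidden files') does not appear in the returned history; it is the prompt that `get_response` passes to `chat.send_message`.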
