Anthropic message normalization (#537)
jakethekoenig authored Mar 7, 2024
1 parent fe167fb commit 2446e0a
Showing 2 changed files with 104 additions and 10 deletions.
37 changes: 27 additions & 10 deletions docs/source/user/alternative_models.rst
@@ -13,12 +13,36 @@ In addition, Mentat uses the :code:`gpt-4-1106-preview` model by default. When u
.. warning::

    Due to changes in the OpenAI Python SDK, you can no longer use :code:`OPENAI_API_BASE` to access the Azure API with Mentat.

Anthropic
---------

Mentat uses the OpenAI SDK to retrieve chat completions, so setting the `OPENAI_API_BASE` environment variable is enough to use any model that has the same response schema as OpenAI. To use models with different response schemas, we recommend setting up a litellm proxy as described `here <https://docs.litellm.ai/docs/proxy/quick_start>`__ and pointing `OPENAI_API_BASE` at the proxy. For example, with Anthropic:

.. code-block:: bash

    pip install 'litellm[proxy]'
    export ANTHROPIC_API_KEY=sk-*************
    litellm --model claude-3-opus-20240229 --drop_params
    # Should see: Uvicorn running on http://0.0.0.0:8000

.. code-block:: bash

    # In ~/.mentat/.env
    OPENAI_API_BASE=http://localhost:8000
    # In ~/.mentat/.mentat_config.json
    { "model": "claude" }
    # or
    export OPENAI_API_BASE=http://localhost:8000
    mentat

.. note::

    Anthropic has slightly different requirements for system messages, so you must set your model to a string containing "claude". Beyond that the exact value does not matter: the actual model is chosen by the litellm proxy's :code:`--model` flag.
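
For a quick sanity check of this setup, you can call the proxy directly through the OpenAI SDK. The sketch below is illustrative only: it assumes the litellm proxy from the steps above is listening on :code:`http://localhost:8000`, and the API key value and prompt are placeholders.

.. code-block:: python

    from openai import OpenAI

    # Placeholder key; the OpenAI SDK requires one even if the proxy is not configured to check it.
    client = OpenAI(base_url="http://localhost:8000", api_key="anything")
    response = client.chat.completions.create(
        model="claude",  # any string containing "claude"; litellm's --model flag selects the real model
        messages=[{"role": "user", "content": "Say hello in one sentence."}],
        max_tokens=64,
    )
    print(response.choices[0].message.content)
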
🦙 Local Models
---------------

In our experiments we have not found any non-OpenAI models to be as good with Mentat as even gpt-3.5-turbo. That said, it is possible to use Mentat with other models in just a few steps. Mentat uses the OpenAI SDK to retrieve chat completions, so setting the `OPENAI_API_BASE` environment variable is enough to use any model that has the same response schema as OpenAI. To use models with different response schemas, we recommend setting up a litellm proxy as described `here <https://docs.litellm.ai/docs/proxy/quick_start>`__ and pointing `OPENAI_API_BASE` at the proxy. You can use local models run with ollama by following these steps:

First run ollama. Replace mixtral with whichever model you want to use.
This works the same way as in the previous section, but you must install ollama first. Replace mixtral with whichever model you want to use.
.. code-block:: bash
@@ -36,13 +60,6 @@ Next run the litellm proxy. In another terminal run:
Finally set the OPENAI_API_BASE in the terminal before running mentat.

.. code-block:: bash

    # In ~/.mentat/.env
    OPENAI_API_BASE=http://localhost:8000
    # or
    export OPENAI_API_BASE=http://localhost:8000
    mentat

.. note::
77 changes: 77 additions & 0 deletions mentat/llm_api_handler.py
@@ -35,9 +35,12 @@
)
from openai.types.chat import (
    ChatCompletion,
    ChatCompletionAssistantMessageParam,
    ChatCompletionChunk,
    ChatCompletionContentPartParam,
    ChatCompletionMessageParam,
    ChatCompletionSystemMessageParam,
    ChatCompletionUserMessageParam,
)
from openai.types.chat.completion_create_params import ResponseFormat
from PIL import Image
@@ -134,6 +137,74 @@ def count_tokens(message: str, model: str, full_message: bool) -> int:
)


def normalize_messages_for_anthropic(
    messages: list[ChatCompletionMessageParam],
) -> list[ChatCompletionMessageParam]:
    """Claude expects the chat to start with at most one system message, after which user and assistant
    messages alternate. This function consolidates all system messages at the beginning of the conversation
    into one system message delimited by "\n" + "-" * 80 + "\n", turns later system messages into user
    messages annotated with "SYSTEM: ", and combines adjacent assistant or user messages into one message.
    """
    replace_non_leading_systems = list[ChatCompletionMessageParam]()
    for i, message in enumerate(messages):
        if message["role"] == "system":
            if i == 0 or messages[i - 1]["role"] == "system":
                replace_non_leading_systems.append(message)
            else:
                content = "SYSTEM: " + (message["content"] or "")
                replace_non_leading_systems.append(
                    ChatCompletionUserMessageParam(role="user", content=content)
                )
        else:
            replace_non_leading_systems.append(message)

    concatenate_adjacent = list[ChatCompletionMessageParam]()
    current_role: str = ""
    current_content: str = ""
    delimiter = "\n" + "-" * 80 + "\n"
    for message in replace_non_leading_systems:
        if message["role"] == current_role:
            current_content += delimiter + str(message["content"])
        else:
            if current_role == "user":
                concatenate_adjacent.append(
                    ChatCompletionUserMessageParam(
                        role=current_role, content=current_content
                    )
                )
            elif current_role == "system":
                concatenate_adjacent.append(
                    ChatCompletionSystemMessageParam(
                        role=current_role, content=current_content
                    )
                )
            elif current_role == "assistant":
                concatenate_adjacent.append(
                    ChatCompletionAssistantMessageParam(
                        role=current_role, content=current_content
                    )
                )
            current_role = message["role"]
            current_content = str(message["content"])

    if current_role == "user":
        concatenate_adjacent.append(
            ChatCompletionUserMessageParam(role=current_role, content=current_content)
        )
    elif current_role == "system":
        concatenate_adjacent.append(
            ChatCompletionSystemMessageParam(role=current_role, content=current_content)
        )
    elif current_role == "assistant":
        concatenate_adjacent.append(
            ChatCompletionAssistantMessageParam(
                role=current_role, content=current_content
            )
        )

    return concatenate_adjacent


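For illustration only (this example is not part of the commit), the sketch below shows how a hypothetical conversation is reshaped by the function above; the sample messages are invented for the example.

# Illustrative usage, not part of mentat/llm_api_handler.py.
from mentat.llm_api_handler import normalize_messages_for_anthropic

messages = [
    {"role": "system", "content": "You are a coding assistant."},
    {"role": "system", "content": "Answer concisely."},
    {"role": "user", "content": "Rename this variable."},
    {"role": "system", "content": "The user saved the file."},
    {"role": "assistant", "content": "Done."},
]

for message in normalize_messages_for_anthropic(messages):
    print(message["role"])
# The two leading system messages are merged into one, the later system message
# becomes a "SYSTEM: "-prefixed user message and is concatenated with the
# preceding user message, so this prints: system, user, assistant
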
def prompt_tokens(messages: list[ChatCompletionMessageParam], model: str):
    """
    Returns the number of tokens used by a prompt if it was sent to OpenAI for a chat completion.
@@ -363,6 +434,10 @@ async def call_llm_api(
        session_context = SESSION_CONTEXT.get()
        config = session_context.config
        cost_tracker = session_context.cost_tracker
        session_context.stream.send("here")

        if "claude" in config.model:
            messages = normalize_messages_for_anthropic(messages)

        # Confirm that model has enough tokens remaining.
        tokens = prompt_tokens(messages, model)
@@ -390,6 +465,7 @@
                messages=messages,
                temperature=config.temperature,
                stream=stream,
                max_tokens=4096,
            )
        else:
            response = await self.async_client.chat.completions.create(
@@ -398,6 +474,7 @@
                temperature=config.temperature,
                stream=stream,
                response_format=response_format,
                max_tokens=4096,
            )

        # We have to cast response since pyright isn't smart enough to connect
