Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

[MRG] Add mistralai support #192

Merged
merged 6 commits into from
Sep 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions mle/agents/advisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from rich.console import Console

from mle.function import *
from mle.utils import get_config, print_in_box
from mle.utils import get_config, print_in_box, clean_json_string


def process_report(requirement: str, suggestions: dict):
Expand Down Expand Up @@ -118,7 +118,10 @@ def suggest(self, requirement):
)

self.chat_history.append({"role": "assistant", "content": text})
suggestions = json.loads(text)
try:
suggestions = json.loads(text)
except json.JSONDecodeError as e:
suggestions = clean_json_string(text)

return process_report(requirement, suggestions)

Expand Down Expand Up @@ -185,7 +188,7 @@ def clarify_dataset(self, dataset: str):
text = self.model.query(chat_history)
chat_history.append({"role": "assistant", "content": text})
if "yes" in text.lower():
return
return dataset

# recommend some datasets based on the users' description
user_prompt = f"""
Expand Down
5 changes: 2 additions & 3 deletions mle/agents/planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import questionary
from rich.console import Console

from mle.utils import print_in_box
from mle.utils import print_in_box, clean_json_string


def process_plan(plan_dict: dict):
Expand Down Expand Up @@ -100,8 +100,7 @@ def plan(self, user_prompt):
try:
return json.loads(text)
except json.JSONDecodeError as e:
print(f"Error parsing JSON response: {e}")
sys.exit(1)
return clean_json_string(text)

def interact(self, user_prompt):
"""
Expand Down
9 changes: 7 additions & 2 deletions mle/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def new(name):

platform = questionary.select(
"Which language model platform do you want to use?",
choices=['OpenAI', 'Ollama', 'Claude']
choices=['OpenAI', 'Ollama', 'Claude', 'MistralAI']
).ask()

api_key = None
Expand All @@ -178,6 +178,12 @@ def new(name):
if not api_key:
console.log("API key is required. Aborted.")
return

elif platform == 'MistralAI':
api_key = questionary.password("What is your MistralAI API key?").ask()
if not api_key:
console.log("API key is required. Aborted.")
return

search_api_key = questionary.password("What is your Tavily API key? (if no, the web search will be disabled)").ask()
if search_api_key:
Expand All @@ -193,7 +199,6 @@ def new(name):
'api_key': api_key,
'search_key': search_api_key
}, outfile, default_flow_style=False)

# init the memory
Memory(project_dir)

Expand Down
114 changes: 113 additions & 1 deletion mle/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
MODEL_OLLAMA = 'Ollama'
MODEL_OPENAI = 'OpenAI'
MODEL_CLAUDE = 'Claude'
MODEL_MISTRAL = 'MistralAI'

class Model(ABC):

Expand Down Expand Up @@ -255,7 +256,6 @@ def query(self, chat_history, **kwargs):
stream=False,
tools=tools,
)

if completion.stop_reason == "tool_use":
for func in completion.content:
if func.type != "tool_use":
Expand Down Expand Up @@ -302,6 +302,116 @@ def stream(self, chat_history, **kwargs):
for chunk in stream.text_stream:
yield chunk

class MistralModel(Model):
    """Mistral AI chat backend supporting OpenAI-style function calling."""

    def __init__(self, api_key, model, temperature=0.7):
        """
        Initialize the Mistral model.
        Args:
            api_key (str): The Mistral API key.
            model (str): The model with version.
            temperature (float): The temperature value.
        """
        super().__init__()

        # Import mistralai lazily so the package is only required when this
        # backend is actually selected by the user.
        dependency = "mistralai"
        spec = importlib.util.find_spec(dependency)
        if spec is not None:
            self.mistral = importlib.import_module(dependency).Mistral
        else:
            raise ImportError(
                "It seems you didn't install mistralai. In order to enable the Mistral AI client related features, "
                "please make sure mistralai Python package has been installed. "
                "More information, please refer to: https://github.com/mistralai/client-python"
            )

        # Fall back to the latest "large" model when no model name is given.
        self.model = model if model else 'mistral-large-latest'
        self.model_type = MODEL_MISTRAL
        self.temperature = temperature
        self.client = self.mistral(api_key=api_key)
        # Running log of executed tool calls; query() uses it to cap
        # repeated search-function invocations.
        self.func_call_history = []

    def _convert_functions_to_tools(self, functions):
        """
        Convert OpenAI-style functions to Mistral-style tools.

        Args:
            functions: Iterable of dicts with "name", "parameters" and an
                optional "description" key (OpenAI function schema).

        Returns:
            list: Tool dicts in the {"type": "function", "function": {...}}
                shape that the Mistral chat API accepts.
        """
        tools = []
        for func in functions:
            tool = {
                "type": "function",
                "function": {
                    "name": func["name"],
                    "description": func.get("description", ""),
                    "parameters": func["parameters"]
                }
            }
            tools.append(tool)
        return tools

    def query(self, chat_history, **kwargs):
        """
        Query the LLM model.

        Args:
            chat_history: The context (chat history).
        """
        functions = kwargs.get("functions",[])
        tools = self._convert_functions_to_tools(functions)
        # Default 'any' lets the model freely pick a tool unless the caller
        # (or the search cap below) overrides it.
        tool_choice = kwargs.get('tool_choice', 'any')
        # NOTE(review): 'parameters' aliases kwargs — mutating it below also
        # mutates kwargs; the mutated dict is forwarded to the recursive call.
        parameters = kwargs
        completion = self.client.chat.complete(
            model=self.model,
            messages=chat_history,
            temperature=self.temperature,
            stream=False,
            tools=tools,
            tool_choice=tool_choice,
        )
        resp = completion.choices[0].message
        if resp.tool_calls:
            # Execute every requested tool, append each result to the chat
            # history, then re-query the model with the enriched context.
            for tool_call in resp.tool_calls:
                chat_history.append({"role": "assistant", "content": '', "tool_calls": [tool_call], "prefix":False})
                function_name = process_function_name(tool_call.function.name)
                arguments = json.loads(tool_call.function.arguments)
                print("[MLE FUNC CALL]: ", function_name)
                self.func_call_history.append({"name": function_name, "arguments": arguments})
                # avoid the multiple search function calls
                search_attempts = [item for item in self.func_call_history if item['name'] in SEARCH_FUNCTIONS]
                if len(search_attempts) > 3:
                    # Disable further tool use on the next round once the
                    # search budget (3 attempts) is exhausted.
                    parameters['tool_choice'] = "none"
                result = get_function(function_name)(**arguments)
                chat_history.append({"role": "tool", "content": result, "name": function_name, "tool_call_id":tool_call.id})
            return self.query(chat_history, **parameters)
        else:
            # No tool requested: return the plain assistant message content.
            return resp.content

    def stream(self, chat_history, **kwargs):
        """
        Stream the output from the LLM model.
        Args:
            chat_history: The context (chat history).
        """
        functions = kwargs.get("functions",[])
        tools = self._convert_functions_to_tools(functions)
        tool_choice = kwargs.get('tool_choice', 'any')
        for chunk in self.client.chat.complete(
            model=self.model,
            messages=chat_history,
            temperature=self.temperature,
            stream=True,
            tools=tools,
            tool_choice=tool_choice
        ):
            if chunk.choices[0].delta.tool_calls:
                # A chunk carrying a tool call: execute it, record the result
                # in the history, then restart streaming with the new context.
                # NOTE(review): only the first tool call of the chunk is
                # handled — confirm the API never batches several per chunk.
                tool_call = chunk.choices[0].delta.tool_calls[0]
                if tool_call.function.name:
                    chat_history.append({"role": "assistant", "content": '', "tool_calls": [tool_call], "prefix":False})
                    function_name = process_function_name(tool_call.function.name)
                    arguments = json.loads(tool_call.function.arguments)
                    result = get_function(function_name)(**arguments)
                    chat_history.append({"role": "tool", "content": result, "name": function_name})
                yield from self.stream(chat_history, **kwargs)
            else:
                # Plain text delta: forward it to the consumer.
                yield chunk.choices[0].delta.content

def load_model(project_dir: str, model_name: str):
"""
Expand All @@ -317,4 +427,6 @@ def load_model(project_dir: str, model_name: str):
return ClaudeModel(api_key=config['api_key'], model=model_name)
if config['platform'] == MODEL_OLLAMA:
return OllamaModel(model=model_name)
if config['platform'] == MODEL_MISTRAL:
return MistralModel(api_key=config['api_key'], model=model_name)
return None
1 change: 1 addition & 0 deletions mle/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .system import *
from .cache import *
from .memory import *
from .data import *
13 changes: 13 additions & 0 deletions mle/utils/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import json
import re


def clean_json_string(input_string):
    """
    Strip a Markdown code fence from a JSON string and parse it.

    LLMs frequently wrap JSON answers in ```json ... ``` (or bare ```)
    fences; this removes the fences before parsing.

    Args:
        input_string (str): Raw model output, possibly fenced.

    Returns:
        The parsed JSON object (dict, list, etc.).

    Raises:
        json.JSONDecodeError: If the cleaned text is still not valid JSON.
    """
    cleaned = input_string.strip()
    # '(?:json)?' makes the whole language tag optional; the previous
    # pattern 'json?' only made the trailing 'n' optional, so a bare
    # ``` fence (no 'json' label) was never stripped.
    cleaned = re.sub(r'^```\s*(?:json)?\s*', '', cleaned)
    cleaned = re.sub(r'\s*```\s*$', '', cleaned)
    return json.loads(cleaned)
Loading