[MRG] deepseek support #212

Merged 2 commits on Sep 14, 2024
9 changes: 7 additions & 2 deletions mle/cli.py
@@ -146,7 +146,7 @@ def chat():
     if not check_config(console):
         return

-    model = load_model(os.getcwd(), "gpt-4o")
+    model = load_model(os.getcwd())
     coder = CodeAgent(model)

     while True:
@@ -193,7 +193,7 @@ def new(name):

     platform = questionary.select(
         "Which language model platform do you want to use?",
-        choices=['OpenAI', 'Ollama', 'Claude', 'MistralAI']
+        choices=['OpenAI', 'Ollama', 'Claude', 'MistralAI', 'DeepSeek']
     ).ask()

     api_key = None
@@ -214,6 +214,11 @@ def new(name):
         if not api_key:
             console.log("API key is required. Aborted.")
             return
+    elif platform == 'DeepSeek':
+        api_key = questionary.password("What is your DeepSeek API key?").ask()
+        if not api_key:
+            console.log("API key is required. Aborted.")
+            return

     search_api_key = questionary.password("What is your Tavily API key? (if no, the web search will be disabled)").ask()
     if search_api_key:
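For reference, here is a sketch of the project configuration this prompt flow would end up producing for the DeepSeek branch. Only the 'platform' and 'api_key' keys are confirmed by how load_model() reads the config in mle/model.py below; everything else about the file layout is an assumption:

```python
# Hypothetical shape of the config saved by `mle new` for DeepSeek.
# 'platform' and 'api_key' are the keys load_model() actually reads;
# the rest of the file layout is assumed here.
config = {
    "platform": "DeepSeek",                # compared against MODEL_DEEPSEEK
    "api_key": "<your DeepSeek API key>",  # collected via questionary.password()
}
```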
124 changes: 120 additions & 4 deletions mle/model.py
@@ -1,14 +1,17 @@
+import json
 import copy
 import importlib.util
-import json
 from abc import ABC, abstractmethod
+from typing import Optional

+from mle.function import SEARCH_FUNCTIONS, get_function, process_function_name
 from mle.utils import get_config
-from mle.function import get_function, process_function_name, SEARCH_FUNCTIONS

 MODEL_OLLAMA = 'Ollama'
 MODEL_OPENAI = 'OpenAI'
 MODEL_CLAUDE = 'Claude'
 MODEL_MISTRAL = 'MistralAI'
+MODEL_DEEPSEEK = 'DeepSeek'

 class Model(ABC):
@@ -353,7 +356,7 @@ def _convert_functions_to_tools(self, functions):
             }
             tools.append(tool)
         return tools

     def query(self, chat_history, **kwargs):
         """
         Query the LLM model.
@@ -420,7 +423,118 @@ def stream(self, chat_history, **kwargs):
             else:
                 yield chunk.choices[0].delta.content

-def load_model(project_dir: str, model_name: str):
+
+class DeepSeekModel(Model):
+    def __init__(self, api_key, model, temperature=0.7):
+        """
+        Initialize the DeepSeek model.
+        Args:
+            api_key (str): The DeepSeek API key.
+            model (str): The model with version.
+            temperature (float): The temperature value.
+        """
+        super().__init__()
+
+        dependency = "openai"
+        spec = importlib.util.find_spec(dependency)
+        if spec is not None:
+            self.openai = importlib.import_module(dependency).OpenAI
+        else:
+            raise ImportError(
+                "It seems you didn't install openai. In order to enable the OpenAI client related features, "
+                "please make sure openai Python package has been installed. "
+                "More information, please refer to: https://openai.com/product"
+            )
+        self.model = model if model else "deepseek-coder"
+        self.model_type = MODEL_DEEPSEEK
+        self.temperature = temperature
+        self.client = self.openai(
+            api_key=api_key, base_url="https://api.deepseek.com/beta"
+        )
+        self.func_call_history = []
+
+    def _convert_functions_to_tools(self, functions):
+        """
+        Convert OpenAI-style functions to DeepSeek-style tools.
+        """
+        tools = []
+        for func in functions:
+            tool = {
+                "type": "function",
+                "function": {
+                    "name": func["name"],
+                    "description": func.get("description", ""),
+                    "parameters": func["parameters"],
+                },
+            }
+            tools.append(tool)
+        return tools
+
+    def query(self, chat_history, **kwargs):
+        """
+        Query the LLM model.
+
+        Args:
+            chat_history: The context (chat history).
+        """
+        functions = kwargs.get("functions", None)
+        tools = self._convert_functions_to_tools(functions) if functions else None
+        parameters = kwargs
+        completion = self.client.chat.completions.create(
+            model=self.model,
+            messages=chat_history,
+            temperature=self.temperature,
+            stream=False,
+            tools=tools,
+            **parameters,
+        )
+
+        resp = completion.choices[0].message
+        if resp.tool_calls:
+            for tool_call in resp.tool_calls:
+                chat_history.append({"role": "assistant", "content": '', "tool_calls": [tool_call], "prefix": False})
+                function_name = process_function_name(tool_call.function.name)
+                arguments = json.loads(tool_call.function.arguments)
+                print("[MLE FUNC CALL]: ", function_name)
+                self.func_call_history.append({"name": function_name, "arguments": arguments})
+                # avoid repeated search function calls
+                search_attempts = [item for item in self.func_call_history if item['name'] in SEARCH_FUNCTIONS]
+                if len(search_attempts) > 3:
+                    parameters['tool_choice'] = "none"
+                result = get_function(function_name)(**arguments)
+                chat_history.append({"role": "tool", "content": result, "name": function_name, "tool_call_id": tool_call.id})
+                return self.query(chat_history, **parameters)
+        else:
+            return resp.content
+
+    def stream(self, chat_history, **kwargs):
+        """
+        Stream the output from the LLM model.
+        Args:
+            chat_history: The context (chat history).
+        """
+        arguments = ""
+        function_name = ""
+        for chunk in self.client.chat.completions.create(
+            model=self.model,
+            messages=chat_history,
+            temperature=self.temperature,
+            stream=True,
+            **kwargs,
+        ):
+            if chunk.choices[0].delta.tool_calls:
+                tool_call = chunk.choices[0].delta.tool_calls[0]
+                if tool_call.function.name:
+                    chat_history.append({"role": "assistant", "content": '', "tool_calls": [tool_call], "prefix": False})
+                    function_name = process_function_name(tool_call.function.name)
+                    arguments = json.loads(tool_call.function.arguments)
+                    result = get_function(function_name)(**arguments)
+                    chat_history.append({"role": "tool", "content": result, "name": function_name})
+                    yield from self.stream(chat_history, **kwargs)
+            else:
+                yield chunk.choices[0].delta.content
+
+def load_model(project_dir: str, model_name: Optional[str] = None):
     """
     load_model: load the model based on the configuration.
     Args:
@@ -436,4 +550,6 @@ def load_model(project_dir: str, model_name: str):
         return OllamaModel(model=model_name)
     if config['platform'] == MODEL_MISTRAL:
         return MistralModel(api_key=config['api_key'], model=model_name)
+    if config['platform'] == MODEL_DEEPSEEK:
+        return DeepSeekModel(api_key=config['api_key'], model=model_name)
     return None
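
As a quick sanity check of the new code path, a minimal usage sketch. It assumes a project config whose platform is 'DeepSeek'; the chat message and the web_search schema below are illustrative only:

```python
import os

from mle.model import load_model

# model_name now defaults to None, so DeepSeekModel falls back to "deepseek-coder".
model = load_model(os.getcwd())

# Plain completion, no tools involved.
print(model.query([{"role": "user", "content": "Say hello in one line of Python."}]))

# _convert_functions_to_tools wraps an OpenAI-style function schema
# (hypothetical example below) into DeepSeek's {"type": "function", ...} form:
fn = {
    "name": "web_search",
    "description": "Search the web for a query.",
    "parameters": {
        "type": "object",
        "properties": {"query": {"type": "string"}},
        "required": ["query"],
    },
}
tools = model._convert_functions_to_tools([fn])
# -> [{"type": "function", "function": {"name": "web_search", ...}}]
```

Passing functions=[fn] into query() would exercise the recursive tool-call loop above; once more than three search-function calls have accumulated in func_call_history, tool_choice is forced to "none".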