Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Yi Handler #543

Merged
merged 15 commits into from
Jul 26, 2024
2 changes: 2 additions & 0 deletions berkeley-function-call-leaderboard/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ export FIRE_WORKS_API_KEY=XXXXXX
export ANTHROPIC_API_KEY=XXXXXX
export COHERE_API_KEY=XXXXXX
export NVIDIA_API_KEY=nvapi-XXXXXX
export YI_API_KEY=XXXXXX
```

If decided to run OSS model, the generation script uses vllm and therefore requires GPU for hosting and inferencing. If you have questions or concerns about evaluating OSS models, please reach out to us in our [discord channel](https://discord.gg/grXXvj9Whz).
Expand Down Expand Up @@ -116,6 +117,7 @@ Below is *a table of models we support* to run our leaderboard evaluation agains
|nvidia/nemotron-4-340b-instruct| Prompt|
|THUDM/glm-4-9b-chat 💻| Function Calling|
|ibm-granite/granite-20b-functioncalling 💻| Function Calling|
|yi-large-fc | Function Calling|

Here {MODEL} 💻 means the model needs to be hosted locally and called by vllm, {MODEL} means the models that are called API calls. For models with a trailing `-FC`, it means that the model supports function-calling feature. You can check out the table summarizing feature supports among different models [here](https://gorilla.cs.berkeley.edu/blogs/8_berkeley_function_calling_leaderboard.html#prompt).

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,12 @@
"THUDM",
"glm-4",
],
"yi-large-fc": [
"yi-large (FC)",
"https://platform.01.ai/",
HuanzhiMao marked this conversation as resolved.
Show resolved Hide resolved
"01.AI",
"Proprietary",
],
}

INPUT_PRICE_PER_MILLION_TOKEN = {
Expand Down Expand Up @@ -437,6 +443,7 @@
"command-r-plus": 3,
"command-r-plus-FC-optimized": 3,
"command-r-plus-optimized": 3,
"yi-large-fc": 3,
}

OUTPUT_PRICE_PER_MILLION_TOKEN = {
Expand Down Expand Up @@ -478,6 +485,7 @@
"command-r-plus": 15,
"command-r-plus-FC-optimized": 15,
"command-r-plus-optimized": 15,
"yi-large-fc": 3,
}

# The latency of the open-source models are hardcoded here.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@
"command-r-plus-FC-optimized",
"THUDM/glm-4-9b-chat",
"ibm-granite/granite-20b-functioncalling",
"yi-large-fc",
]

TEST_CATEGORIES = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
from model_handler.granite_handler import GraniteHandler
from model_handler.nvidia_handler import NvidiaHandler
from model_handler.glm_handler import GLMHandler
from model_handler.yi_handler import YiHandler


handler_map = {
"gorilla-openfunctions-v0": GorillaHandler,
Expand Down Expand Up @@ -80,5 +82,6 @@
"snowflake/arctic": ArcticHandler,
"ibm-granite/granite-20b-functioncalling": GraniteHandler,
"nvidia/nemotron-4-340b-instruct": NvidiaHandler,
"THUDM/glm-4-9b-chat": GLMHandler
"THUDM/glm-4-9b-chat": GLMHandler,
"yi-large-fc": YiHandler,
}
81 changes: 81 additions & 0 deletions berkeley-function-call-leaderboard/model_handler/yi_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
from model_handler.handler import BaseHandler
from model_handler.model_style import ModelStyle
from model_handler.utils import (
convert_to_tool,
convert_to_function_call,
augment_prompt_by_languge,
language_specific_pre_processing,
)
from model_handler.constant import GORILLA_TO_OPENAPI
from openai import OpenAI
import os, time, json


class YiHandler(BaseHandler):
def __init__(self, model_name, temperature=0.0, top_p=1, max_tokens=1000) -> None:
super().__init__(model_name, temperature, top_p, max_tokens)
self.model_style = ModelStyle.OpenAI
self.base_url = "https://api.lingyiwanwu.com/v1"
fantasist marked this conversation as resolved.
Show resolved Hide resolved
self.client = OpenAI(base_url=self.base_url, api_key=os.getenv("YI_API_KEY"))

def inference(self, prompt, functions, test_category):
prompt = augment_prompt_by_languge(prompt, test_category)
functions = language_specific_pre_processing(functions, test_category)
if type(functions) is not list:
functions = [functions]

message = [{"role": "user", "content": "Questions:" + prompt}]
oai_tool = convert_to_tool(
functions, GORILLA_TO_OPENAPI, self.model_style, test_category
)
start_time = time.time()
if len(oai_tool) > 0:
response = self.client.chat.completions.create(
messages=message,
model=self.model_name,
temperature=self.temperature,
max_tokens=self.max_tokens,
top_p=self.top_p,
tools=oai_tool,
)
else:
response = self.client.chat.completions.create(
messages=message,
model=self.model_name,
temperature=self.temperature,
max_tokens=self.max_tokens,
top_p=self.top_p,
)
latency = time.time() - start_time
try:
result = [
{func_call.function.name: func_call.function.arguments}
for func_call in response.choices[0].message.tool_calls
]
except Exception as e:
result = response.choices[0].message.content

metadata = {}
metadata["input_tokens"] = response.usage.prompt_tokens
metadata["output_tokens"] = response.usage.completion_tokens
metadata["latency"] = latency
return result, metadata

def decode_ast(self,result,language="Python"):
decoded_output = []
for invoked_function in result:
name = list(invoked_function.keys())[0]
params = json.loads(invoked_function[name])
if language == "Python":
pass
else:
# all values of the json are casted to string for java and javascript
for key in params:
params[key] = str(params[key])
decoded_output.append({name: params})

return decoded_output

def decode_execute(self,result):
function_call = convert_to_function_call(result)
return function_call