Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use conversation template for api proxy, fix eventsource format #2383

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Use conversation template from fastchat for api proxy
Fix eventsource format
  • Loading branch information
zeyugao committed Jul 26, 2023
commit ea5a7fbc9532f4ae3f4f55bda0078ab9ce64729e
67 changes: 46 additions & 21 deletions examples/server/api_like_OAI.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
import argparse
from flask import Flask, jsonify, request, Response
from flask_cors import CORS
import urllib.parse
import requests
import time
import json
from fastchat import conversation


app = Flask(__name__)
CORS(app)

parser = argparse.ArgumentParser(description="An example of using server.cpp with a similar API to OAI. It must be used together with server.cpp.")
parser.add_argument("--chat-prompt-model", type=str, help="Set the model", default="")
parser.add_argument("--chat-prompt", type=str, help="the top prompt in chat completions(default: 'A chat between a curious user and an artificial intelligence assistant. The assistant follows the given rules no matter what.\\n')", default='A chat between a curious user and an artificial intelligence assistant. The assistant follows the given rules no matter what.\\n')
parser.add_argument("--user-name", type=str, help="USER name in chat completions(default: '\\nUSER: ')", default="\\nUSER: ")
parser.add_argument("--ai-name", type=str, help="ASSISTANT name in chat completions(default: '\\nASSISTANT: ')", default="\\nASSISTANT: ")
Expand All @@ -29,25 +33,46 @@ def is_present(json, key):
return True


use_conversation_template = args.chat_prompt_model != ""

#convert chat to prompt
def convert_chat(messages):
prompt = "" + args.chat_prompt.replace("\\n", "\n")

system_n = args.system_name.replace("\\n", "\n")
user_n = args.user_name.replace("\\n", "\n")
ai_n = args.ai_name.replace("\\n", "\n")
stop = args.stop.replace("\\n", "\n")
if use_conversation_template:
conv = conversation.get_conv_template(args.chat_prompt_model)
stop_token = conv.stop_str
else:
stop_token = args.stop


for line in messages:
if (line["role"] == "system"):
prompt += f"{system_n}{line['content']}"
if (line["role"] == "user"):
prompt += f"{user_n}{line['content']}"
if (line["role"] == "assistant"):
prompt += f"{ai_n}{line['content']}{stop}"
prompt += ai_n.rstrip()
# Flatten an OpenAI-style chat message list into a single prompt string.
def convert_chat(messages):
    if use_conversation_template:
        # Delegate prompt construction to fastchat's conversation template
        # selected by --chat-prompt-model.
        conv = conversation.get_conv_template(args.chat_prompt_model)
        for msg in messages:
            role = msg["role"]
            if role == "system":
                # NOTE(review): best-effort — some fastchat versions lack this
                # setter, in which case the system message is dropped silently.
                try:
                    conv.set_system_msg(msg["content"])
                except Exception:
                    pass
            elif role == "user":
                conv.append_message(conv.roles[0], msg["content"])
            elif role == "assistant":
                conv.append_message(conv.roles[1], msg["content"])
        # Trailing empty assistant turn makes the template end with the
        # assistant prefix, so generation continues from there.
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()
    else:
        # Manual assembly from the CLI-provided role prefixes; "\\n" in the
        # argument strings stands for a literal newline.
        system_n = args.system_name.replace("\\n", "\n")
        user_n = args.user_name.replace("\\n", "\n")
        ai_n = args.ai_name.replace("\\n", "\n")
        stop = stop_token.replace("\\n", "\n")

        pieces = [args.chat_prompt.replace("\\n", "\n")]
        for msg in messages:
            role = msg["role"]
            if role == "system":
                pieces.append(f"{system_n}{msg['content']}")
            if role == "user":
                pieces.append(f"{user_n}{msg['content']}")
            if role == "assistant":
                pieces.append(f"{ai_n}{msg['content']}{stop}")
        # End with the assistant prefix so the model answers next.
        pieces.append(ai_n.rstrip())
        prompt = "".join(pieces)

    return prompt

Expand All @@ -69,8 +94,8 @@ def make_postData(body, chat=False, stream=False):
if(is_present(body, "mirostat_eta")): postData["mirostat_eta"] = body["mirostat_eta"]
if(is_present(body, "seed")): postData["seed"] = body["seed"]
if(is_present(body, "logit_bias")): postData["logit_bias"] = [[int(token), body["logit_bias"][token]] for token in body["logit_bias"].keys()]
if (args.stop != ""):
postData["stop"] = [args.stop]
if stop_token: # "" or None
postData["stop"] = [stop_token]
else:
postData["stop"] = []
if(is_present(body, "stop")): postData["stop"] += body["stop"]
Expand Down Expand Up @@ -173,12 +198,12 @@ def generate():
data = requests.request("POST", urllib.parse.urljoin(args.llama_api, "/completion"), data=json.dumps(postData), stream=True)
time_now = int(time.time())
resData = make_resData_stream({}, chat=True, time_now=time_now, start=True)
yield 'data: {}\n'.format(json.dumps(resData))
yield 'data: {}\n\n'.format(json.dumps(resData))
for line in data.iter_lines():
if line:
decoded_line = line.decode('utf-8')
resData = make_resData_stream(json.loads(decoded_line[6:]), chat=True, time_now=time_now)
yield 'data: {}\n'.format(json.dumps(resData))
yield 'data: {}\n\n'.format(json.dumps(resData))
return Response(generate(), mimetype='text/event-stream')


Expand Down Expand Up @@ -212,7 +237,7 @@ def generate():
if line:
decoded_line = line.decode('utf-8')
resData = make_resData_stream(json.loads(decoded_line[6:]), chat=False, time_now=time_now)
yield 'data: {}\n'.format(json.dumps(resData))
yield 'data: {}\n\n'.format(json.dumps(resData))
return Response(generate(), mimetype='text/event-stream')

if __name__ == '__main__':
Expand Down