Skip to content

Commit 12175f0

Browse files
authored
Merge pull request #52 from MooreThreads/vllm_musa
feat(vllm_musa): 添加vllm_musa 单机推理脚本
2 parents 0faf192 + 9f66019 commit 12175f0

File tree

3 files changed

+477
-0
lines changed

3 files changed

+477
-0
lines changed

vllm_musa/gradio_demo/app.py

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
import gradio as gr
2+
import requests
3+
import json
4+
import argparse
5+
import time
6+
import gradio_musa
7+
8+
9+
def parse_args(argv=None):
    """Parse the command-line options for the Gradio demo.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. ``None`` (the default) keeps the original
            behaviour of reading ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``ip`` (str), ``port`` (int) and
        ``model_name`` (str, or ``None`` when ``--model-name`` is omitted).
    """
    parser = argparse.ArgumentParser(description="Start the vLLM server.")

    parser.add_argument(
        "--ip",
        type=str,
        default="0.0.0.0",  # used when --ip is not given
        help="IP address to bind to (default: 0.0.0.0)",
    )
    parser.add_argument(
        "--port",
        # Parsed as an integer so a non-numeric value is rejected by argparse
        # up front instead of ending up inside a malformed URL.
        type=int,
        default=8000,  # used when --port is not given
        help="Port number to use (default: 8000)",
    )
    parser.add_argument(
        "--model-name",
        type=str,
        help="Model Name",
    )

    return parser.parse_args(argv)
36+
37+
# Parse the command line once at import time; the settings below derive from it.
args = parse_args()

# Model name placed in every request payload.
MODEL_NAME = args.model_name
# Chat-completions endpoint of the vLLM inference service.
VLLM_API_URL = "http://{}:{}/v1/chat/completions".format(args.ip, args.port)
41+
42+
43+
# ✅ 流式请求函数
44+
def _iter_content_deltas(response):
    """Yield each non-empty text fragment from a streamed chat-completions response.

    The response body is read line by line. Only lines of the form
    ``data: <json>`` are considered; the ``[DONE]`` sentinel ends iteration.
    """
    prefix = "data: "
    for raw_line in response.iter_lines():
        if not raw_line:
            continue  # blank separator line between events
        line = raw_line.decode("utf-8").strip()
        if not line.startswith(prefix):
            continue
        data = line[len(prefix):]
        if data == "[DONE]":
            break
        try:
            event = json.loads(data)
        except json.JSONDecodeError:
            # Deliberate best-effort: one malformed event must not abort the
            # whole answer, so it is skipped.
            continue
        choices = event.get("choices") or []
        if not choices:
            continue  # event without choices carries no text to display
        content = choices[0].get("delta", {}).get("content")
        if content:  # skips both a missing/None content and an empty string
            yield content


def chat_with_model_streaming(user_input, history):
    """Stream the model's answer to ``user_input`` into the chat UI.

    This is a generator. Every yielded value is a 3-tuple matching the wired
    outputs ``[chatbot, msg_input, speed_display]``: the updated chat history,
    an empty string that clears the input box, and a status/speed text.

    Args:
        user_input: The text the user submitted.
        history: Existing chat history as a list of ``(user, bot)`` pairs, or
            a falsy value when the chat is empty.

    Yields:
        ``(history + [(user_input, partial_answer)], "", status_text)`` after
        each received fragment, then once more with the final speed summary.
        If the request fails, a single tuple carrying the error message is
        yielded instead.
    """
    history = history or []

    # NOTE(review): only the system prompt and the current turn are sent;
    # `history` is used for display only, so the model sees no earlier turns.
    payload = {
        "model": MODEL_NAME,
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": user_input},
        ],
        "stream": True,
    }

    bot_response = ""
    # Counts streamed content fragments; presumably one token per fragment,
    # but that depends on the server — confirm before treating as exact.
    token_count = 0
    first_token_time = None
    # perf_counter is monotonic, so the measured intervals cannot go negative
    # if the wall clock is adjusted mid-request.
    start_time = time.perf_counter()

    try:
        # Bound only the connect phase; the read side stays unbounded because
        # a slow model may legitimately pause between fragments.
        with requests.post(
            VLLM_API_URL, json=payload, stream=True, timeout=(10, None)
        ) as response:
            response.raise_for_status()
            for content in _iter_content_deltas(response):
                if first_token_time is None:
                    first_token_time = time.perf_counter()
                bot_response += content
                token_count += 1
                yield history + [(user_input, bot_response)], "", "推理中..."
    except Exception as e:
        # UI boundary: show the failure in the chat instead of raising. All
        # three outputs must be supplied, otherwise Gradio rejects the update.
        bot_response = f"❌ 推理失败: {str(e)}"
        yield history + [(user_input, bot_response)], "", "推理失败"
    else:
        finished_at = time.perf_counter()
        if first_token_time is None:
            # Nothing was generated: report zeros rather than subtracting None.
            first_token_latency = 0.0
            elapsed_time = 0.0
        else:
            first_token_latency = first_token_time - start_time
            # Generation time is measured from the first fragment onwards.
            elapsed_time = finished_at - first_token_time
        tps = token_count / elapsed_time if elapsed_time > 0 else 0
        speed_text = f"⏳ 首字延迟: {first_token_latency:.2f} 秒 | ⏱️ 耗时: {elapsed_time:.2f} 秒 | 🔢 Tokens: {token_count} | ⚡ 速度: {tps:.2f} TPS"
        yield history + [(user_input, bot_response)], "", speed_text
97+
98+
99+
100+
# ✅ 清除聊天记录 & 计时器
101+
def clear_chat():
    """Return reset values for the chat window, the input box and the speed readout."""
    empty_history = []
    empty_input = ""
    zeroed_speed = "⏳ 首字延迟: 0.00 秒 | ⏱️ 耗时: 0.00 秒 | 🔢 Tokens: 0 | ⚡ 速度: 0.00 TPS"
    return empty_history, empty_input, zeroed_speed
103+
104+
# 构建 Gradio 界面
105+
# Lay out the page inside the gradio_musa Blocks wrapper.
with gradio_musa.Blocks() as demo:
    chatbot = gr.Chatbot(label="Running on MTT S4000")
    msg_input = gr.Textbox(
        placeholder="请输入你的问题",
        label="输入...",
        lines=1,
        autofocus=True,
    )

    # Read-only field showing the speed summary of the last request.
    speed_display = gr.Textbox(
        label="推理速度",
        value="⏳ 首字延迟: 0.00 秒 | ⏱️ 耗时: 0.00 秒 | 🔢 Tokens: 0 | ⚡ 速度: 0.00 TPS",
        interactive=False,
    )

    with gr.Row():
        submit_btn = gr.Button(value="提交")
        clear_btn = gr.Button("清除历史")

    # Pressing Enter in the textbox and clicking the submit button both run
    # the same streaming handler with the same inputs and outputs.
    for register in (msg_input.submit, submit_btn.click):
        register(
            chat_with_model_streaming,
            inputs=[msg_input, chatbot],
            outputs=[chatbot, msg_input, speed_display],
        )

    # Reset the chat, the input box and the speed readout.
    clear_btn.click(
        clear_chat,
        inputs=[],
        outputs=[chatbot, msg_input, speed_display],
    )

# Enable the queue before launching; the handler above is a generator.
demo.queue()
demo.launch(server_name=args.ip)
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
import gradio as gr
2+
3+
4+
# Page title; inserted into both the browser-tab script and the <h1> heading.
TITLE = ""

# Top banner holding the logo image.
TOP = "\n".join(
    [
        '<div class="top">',
        '<div class="top-container">',
        '<img class="logo" width="140" height="37" src="https://kuae-playground.mthreads.com/image/logo@2x.png">',
        '</div>',
        '</div>',
    ]
)

# Copyright footer markup.
FOOTER = "\n".join(
    [
        '<div class="footer">',
        '<span>Copyright © 2024-2025 摩尔线程版权所有 京公网安备 11010802035174号 京ICP证2020041674号-2</span>',
        '</div>',
    ]
)

# Script that sets the document title to TITLE once the window has loaded.
js_change_title = "\n".join(
    [
        "window.onload = function() {",
        'document.title = "' + TITLE + '";',
        "}",
    ]
)

# Banner followed by the page heading.
HEADER = f"{TOP}<h1>{TITLE}</h1><p>"
22+
23+
24+
25+
# Stylesheet for the page. The text inside the literal, including its CSS
# comments, is sent to the browser as-is.
# NOTE(review): the `footer` rule originally held the bare token `visible;`,
# which is not a valid CSS declaration. It is written out as
# `visibility: visible;` on the assumption that the footer is meant to stay
# shown — confirm this was not meant to hide it.
CSS = '''body {
    margin: 0;
    background: #F8F8F8;
    font-size: 22px;
    color: #666666;
}
p {
    font-size: 16px;
}
.top {
    left: 0;
    top: 0;
    height: 3.83%;
    opacity: 1;
    justify-content: center;
    background: white;
}
.top-container {
    /* 原样式 */
    margin: 0 auto;
    max-width: 1500px;
    padding: 10px;
    display: flex;
    padding: 16px 0;
    overflow: hidden;
}
.logo {
    margin: 0;
    padding: 0 20px;
}
h2 {
    position: relative;
    margin: 0;
    font-size: 21px;
    font-weight: normal;
    line-height: 20px;
    letter-spacing: 0;
    padding: 5px 20px;
    color: #666666;
}
.top-container>h2:before {
    background: #dcdfe6;
    /* 设置背景颜色为浅灰色 */
    content: "";
    /* 伪元素的内容,这里为空,意味着不会显示任何文本 */
    height: 16px;
    /* 设置伪元素的高度为 16px */
    left: 0;
    /* 设置伪元素的左边距为 0,即与其定位的父元素(这里是 h2)的左边对齐 */
    position: absolute;
    /* 设置位置为绝对定位,从而允许我们根据父元素进行准确放置 */
    top: 50%;
    /* 设置伪元素的上边距为 50%,这样伪元素的顶部将对齐到父元素的中间 */
    transform: translateY(-50%);
    /* 将伪元素向上移动自身高度的一半,即使其完全居中于父元素 */
    width: 1px;
    /* 设置伪元素的宽度为 1px,表现为一个细线 */
}
h1 {
    text-align: center;
    display: block;
}

.footer {
    padding: 20px;
    text-align: center;
    font-size: 16px;
}

.footer .logo {
    display: inline-block;
    /* 内联块元素使其与文本对齐 */
    margin-right: 10px;
    /* 右边距 */
}

.footer a {
    color: #666666;
    text-decoration: none;
}
footer {
    visibility: visible;
}'''
108+
109+
110+
class Blocks(gr.Blocks):
    """``gr.Blocks`` subclass that adds the page header and footer around its content.

    Entering the context inserts ``HEADER`` as the first component; leaving
    it appends ``FOOTER`` as the last one.
    """

    def __init__(self, **kwargs):
        """Create the container with the module's stylesheet and title script.

        Args:
            **kwargs: Forwarded to ``gr.Blocks``. ``css`` and ``js`` default
                to the module-level ``CSS`` and ``js_change_title`` when the
                caller does not supply them. Previously every keyword
                argument was accepted but silently dropped.
        """
        kwargs.setdefault("css", CSS)
        kwargs.setdefault("js", js_change_title)
        super().__init__(**kwargs)

    def __enter__(self):
        """Open the Blocks context and emit the header before any user component."""
        blocks = super().__enter__()
        gr.HTML(HEADER)
        return blocks

    def __exit__(self, exc_type, exc_value, traceback):
        """Emit the footer after the user's components, then close the context."""
        gr.HTML(FOOTER)
        return super().__exit__(exc_type, exc_value, traceback)

0 commit comments

Comments
 (0)