
Commit 0b5ae0d

Merge pull request xusenlinzy#232 from lzhfe/master
Fix vllm stream function call
2 parents d2a9e6c + ccddd76

File tree

1 file changed (+45, -20 lines)

api/vllm_routes/chat.py

Lines changed: 45 additions & 20 deletions
@@ -19,7 +19,11 @@
 )
 from openai.types.chat.chat_completion import Choice
 from openai.types.chat.chat_completion_chunk import Choice as ChunkChoice
-from openai.types.chat.chat_completion_chunk import ChoiceDelta
+from openai.types.chat.chat_completion_chunk import (
+    ChoiceDelta,
+    ChoiceDeltaFunctionCall,
+    ChoiceDeltaToolCall
+)
 from openai.types.chat.chat_completion_message import FunctionCall
 from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall
 from openai.types.completion_usage import CompletionUsage
@@ -63,7 +67,7 @@ async def create_chat_completion(
     generator = engine.generate(params, request_id)

     if request.stream:
-        iterator = create_chat_completion_stream(generator, params, request_id)
+        iterator = create_chat_completion_stream(generator, params, request_id, engine)
         send_chan, recv_chan = anyio.create_memory_object_stream(10)
         return EventSourceResponse(
             recv_chan,
@@ -146,7 +150,7 @@ async def create_chat_completion(
     )


-async def create_chat_completion_stream(generator: AsyncIterator, params: Dict[str, Any], request_id: str) -> AsyncIterator:
+async def create_chat_completion_stream(generator: AsyncIterator, params: Dict[str, Any], request_id: str, engine: VllmEngine) -> AsyncIterator:
     n = params.get("n", 1)
     for i in range(n):
         # First chunk with role
@@ -164,6 +168,9 @@ async def create_chat_completion_stream(generator: AsyncIterator, params: Dict[s
             object="chat.completion.chunk",
         )

+    functions = params.get("functions", None)
+    tools = params.get("tools", None)
+
     previous_texts = [""] * n
     previous_num_tokens = [0] * n
     async for res in generator:
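
For orientation, here is a sketch of the params dict this function might receive when a client sends tools. The dict is assembled upstream by the route handler, so the model name and tool schema below are purely illustrative:

# Hypothetical params as seen by the stream handler; "tools" follows the
# standard OpenAI tool schema, the model name is made up for this example.
params = {
    "model": "qwen-7b-chat",
    "n": 1,
    "tools": [
        {
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Look up the current weather for a city",
                "parameters": {
                    "type": "object",
                    "properties": {"city": {"type": "string"}},
                    "required": ["city"],
                },
            },
        }
    ],
}
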
@@ -176,10 +183,43 @@ async def create_chat_completion_stream(generator: AsyncIterator, params: Dict[s
             previous_texts[i] = output.text
             previous_num_tokens[i] = len(output.token_ids)

+            finish_reason = output.finish_reason
+            delta = None
+
+            if finish_reason is None:
+                delta = ChoiceDelta(content=delta_text)
+            elif functions or tools:
+                call_info = None
+                try:
+                    res, call_info = engine.prompt_adapter.parse_assistant_response(
+                        output.text, functions, tools,
+                    )
+                except Exception as e:
+                    traceback.print_exc()
+                    logger.warning("Failed to parse tool call")
+
+                if isinstance(call_info, dict) and "arguments" in call_info:
+                    finish_reason = "function_call"
+                    function_call = ChoiceDeltaFunctionCall(**call_info)
+                    delta = ChoiceDelta(
+                        role="assistant",
+                        content=delta_text,
+                        function_call=function_call
+                    )
+                elif isinstance(call_info, dict) and "function" in call_info:
+                    finish_reason = "tool_calls"
+                    call_info["index"] = 0
+                    tool_calls = [model_parse(ChoiceDeltaToolCall, call_info)]
+                    delta = ChoiceDelta(
+                        role="assistant",
+                        content=delta_text,
+                        tool_calls=tool_calls,
+                    )
+
             choice = ChunkChoice(
                 index=i,
-                delta=ChoiceDelta(content=delta_text),
-                finish_reason=output.finish_reason,
+                delta=delta or ChoiceDelta(),
+                finish_reason=finish_reason,
                 logprobs=None,
             )
             yield ChatCompletionChunk(
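
The two branches key off the shape of call_info. Inferred from the keys the diff checks (the real prompt_adapter output may differ per model), the expected shapes are roughly:

# Assumed call_info shapes, inferred from the keys checked above;
# the actual prompt_adapter return values may differ.

# Legacy functions API: top-level "arguments" -> finish_reason = "function_call"
call_info = {"name": "get_weather", "arguments": '{"city": "Beijing"}'}

# Tools API: nested "function" -> finish_reason = "tool_calls";
# the handler injects index=0 before parsing into ChoiceDeltaToolCall
call_info = {
    "id": "call_abc123",
    "type": "function",
    "function": {"name": "get_weather", "arguments": '{"city": "Beijing"}'},
}
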
@@ -189,18 +229,3 @@ async def create_chat_completion_stream(generator: AsyncIterator, params: Dict[s
                 model=params.get("model", "llm"),
                 object="chat.completion.chunk",
             )
-
-            if output.finish_reason is not None:
-                choice = ChunkChoice(
-                    index=i,
-                    delta=ChoiceDelta(),
-                    finish_reason="stop",
-                    logprobs=None,
-                )
-                yield ChatCompletionChunk(
-                    id=request_id,
-                    choices=[choice],
-                    created=int(time.time()),
-                    model=params.get("model", "llm"),
-                    object="chat.completion.chunk",
-                )
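
With this change, a streaming client can pick the tool call out of the final chunk. A minimal consumption sketch (openai>=1.0), assuming the server exposes its OpenAI-compatible API at http://localhost:8000/v1 and serves a model named qwen-7b-chat:

# Minimal client sketch; base_url, api_key, and model name are assumptions.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="none")

stream = client.chat.completions.create(
    model="qwen-7b-chat",
    messages=[{"role": "user", "content": "What's the weather in Beijing?"}],
    tools=[{
        "type": "function",
        "function": {
            "name": "get_weather",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }],
    stream=True,
)

name, args = None, ""
for chunk in stream:
    choice = chunk.choices[0]
    if choice.delta.tool_calls:                # populated only on the tool-call chunk
        call = choice.delta.tool_calls[0]
        if call.function.name:
            name = call.function.name
        if call.function.arguments:
            args += call.function.arguments
    if choice.finish_reason == "tool_calls":   # set by the patched handler
        break

print(name, args)

Note that unlike OpenAI's own streaming, which emits arguments incrementally, this handler only parses the tool call once generation finishes, so the whole call arrives in a single chunk; the accumulation loop above still works either way.
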
