Skip to content

Commit

Permalink
fix stream=false doesn't work issue (opea-project#124)
Browse files Browse the repository at this point in the history
Signed-off-by: letonghan <letong.han@intel.com>
Signed-off-by: XuhuiRen <xuhui.ren@intel.com>
  • Loading branch information
letonghan authored and XuhuiRen committed May 31, 2024
1 parent 1f4c6c9 commit d5c954d
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions comps/cores/mega/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ def __init__(self, megaservice, host="0.0.0.0", port=8888):

async def handle_request(self, request: Request):
data = await request.json()
stream_opt = data.get("stream", True)
chat_request = ChatCompletionRequest.parse_obj(data)
prompt = self._handle_message(chat_request.messages)
parameters = LLMParams(
Expand All @@ -126,7 +127,7 @@ async def handle_request(self, request: Request):
top_p=chat_request.top_p if chat_request.top_p else 0.95,
temperature=chat_request.temperature if chat_request.temperature else 0.01,
repetition_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 1.03,
streaming=chat_request.stream if chat_request.stream else True,
streaming=stream_opt,
)
await self.megaservice.schedule(initial_inputs={"text": prompt}, llm_parameters=parameters)
for node, response in self.megaservice.result_dict.items():
Expand Down Expand Up @@ -159,6 +160,7 @@ def __init__(self, megaservice, host="0.0.0.0", port=8888):

async def handle_request(self, request: Request):
data = await request.json()
stream_opt = data.get("stream", True)
chat_request = ChatCompletionRequest.parse_obj(data)
prompt = self._handle_message(chat_request.messages)
parameters = LLMParams(
Expand All @@ -167,7 +169,7 @@ async def handle_request(self, request: Request):
top_p=chat_request.top_p if chat_request.top_p else 0.95,
temperature=chat_request.temperature if chat_request.temperature else 0.01,
repetition_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 1.03,
streaming=chat_request.stream if chat_request.stream else True,
streaming=stream_opt,
)
await self.megaservice.schedule(initial_inputs={"query": prompt}, llm_parameters=parameters)
for node, response in self.megaservice.result_dict.items():
Expand Down Expand Up @@ -247,6 +249,7 @@ def __init__(self, megaservice, host="0.0.0.0", port=8888):

async def handle_request(self, request: Request):
data = await request.json()
stream_opt = data.get("stream", True)
chat_request = ChatCompletionRequest.parse_obj(data)
prompt = self._handle_message(chat_request.messages)
parameters = LLMParams(
Expand All @@ -255,7 +258,7 @@ async def handle_request(self, request: Request):
top_p=chat_request.top_p if chat_request.top_p else 0.95,
temperature=chat_request.temperature if chat_request.temperature else 0.01,
repetition_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 1.03,
streaming=chat_request.stream if chat_request.stream else True,
streaming=stream_opt,
)
await self.megaservice.schedule(initial_inputs={"query": prompt}, llm_parameters=parameters)
for node, response in self.megaservice.result_dict.items():
Expand Down

0 comments on commit d5c954d

Please sign in to comment.