diff --git a/fastchat/serve/gradio_web_server.py b/fastchat/serve/gradio_web_server.py index 7b61ac69f..30ff94572 100644 --- a/fastchat/serve/gradio_web_server.py +++ b/fastchat/serve/gradio_web_server.py @@ -61,6 +61,15 @@ ip_expiration_dict = defaultdict(lambda: 0) +# Information about custom OpenAI compatible API models. +# JSON file format: +# { +# "vicuna-7b": { +# "model_name": "vicuna-7b-v1.5", +# "api_base": "http://8.8.8.55:5555/v1", +# "api_key": "password" +# } +# } openai_compatible_models_info = {} @@ -394,11 +403,11 @@ def bot_response(state, temperature, top_p, max_new_tokens, request: gr.Request) yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5 try: - for data in stream_iter: + for i, data in enumerate(stream_iter): if data["error_code"] == 0: + if i % 5 != 0: # reduce gradio's overhead + continue output = data["text"].strip() - if "vicuna" in model_name: - output = post_process_code(output) conv.update_last_message(output + "▌") yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5 else: @@ -412,6 +421,11 @@ def bot_response(state, temperature, top_p, max_new_tokens, request: gr.Request) enable_btn, ) return + output = data["text"].strip() + if "vicuna" in model_name: + output = post_process_code(output) + conv.update_last_message(output) + yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5 except requests.exceptions.RequestException as e: conv.update_last_message( f"{SERVER_ERROR_MSG}\n\n" @@ -439,10 +453,6 @@ def bot_response(state, temperature, top_p, max_new_tokens, request: gr.Request) ) return - # Delete "▌" - conv.update_last_message(conv.messages[-1][-1][:-1]) - yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5 - finish_tstamp = time.time() logger.info(f"{output}")