diff --git a/.changeset/mighty-maps-double.md b/.changeset/mighty-maps-double.md
new file mode 100644
index 0000000000000..352fd1ed11297
--- /dev/null
+++ b/.changeset/mighty-maps-double.md
@@ -0,0 +1,5 @@
+---
+"gradio": patch
+---
+
+fix:Fix multimodal chatinterface api bug
diff --git a/gradio/chat_interface.py b/gradio/chat_interface.py
index 967bebcd65fbd..c0b5d500b46ab 100644
--- a/gradio/chat_interface.py
+++ b/gradio/chat_interface.py
@@ -517,7 +517,7 @@ async def api_fn(message, history, *args, **kwargs):
         self.fake_api_btn.click(
             api_fn,
             [self.textbox, self.chatbot_state] + self.additional_inputs,
-            [self.textbox, self.chatbot_state],
+            [self.fake_response_textbox, self.chatbot_state],
             api_name="chat",
             concurrency_limit=cast(
                 Union[int, Literal["default"], None], self.concurrency_limit
@@ -697,65 +697,6 @@ async def _stream_fn(
                 history_with_input[-1] = response  # type: ignore
             yield history_with_input
 
-    async def _api_submit_fn(
-        self,
-        message: str,
-        history: TupleFormat | list[MessageDict],
-        request: Request,
-        *args,
-    ) -> tuple[str, TupleFormat | list[MessageDict]]:
-        inputs, _, _ = special_args(
-            self.fn, inputs=[message, history, *args], request=request
-        )
-
-        if self.is_async:
-            response = await self.fn(*inputs)
-        else:
-            response = await anyio.to_thread.run_sync(
-                self.fn, *inputs, limiter=self.limiter
-            )
-        if self.type == "tuples":
-            history.append([message, response])  # type: ignore
-        else:
-            new_response = self.response_as_dict(response)
-            history.extend([{"role": "user", "content": message}, new_response])  # type: ignore
-        return response, history
-
-    async def _api_stream_fn(
-        self, message: str, history: list[list[str | None]], request: Request, *args
-    ) -> AsyncGenerator:
-        inputs, _, _ = special_args(
-            self.fn, inputs=[message, history, *args], request=request
-        )
-        if self.is_async:
-            generator = self.fn(*inputs)
-        else:
-            generator = await anyio.to_thread.run_sync(
-                self.fn, *inputs, limiter=self.limiter
-            )
-            generator = SyncToAsyncIterator(generator, self.limiter)
-        try:
-            first_response = await async_iteration(generator)
-            if self.type == "tuples":
-                yield first_response, history + [[message, first_response]]
-            else:
-                first_response = self.response_as_dict(first_response)
-                yield (
-                    first_response,
-                    history + [{"role": "user", "content": message}, first_response],
-                )
-        except StopIteration:
-            yield None, history + [[message, None]]
-        async for response in generator:
-            if self.type == "tuples":
-                yield response, history + [[message, response]]
-            else:
-                new_response = self.response_as_dict(response)
-                yield (
-                    new_response,
-                    history + [{"role": "user", "content": message}, new_response],
-                )
-
     async def _examples_fn(
         self, message: str, *args
     ) -> TupleFormat | list[MessageDict]:
diff --git a/test/test_chat_interface.py b/test/test_chat_interface.py
index 589c3440fd26b..873b2b3cece02 100644
--- a/test/test_chat_interface.py
+++ b/test/test_chat_interface.py
@@ -301,3 +301,17 @@ def test_streaming_api_with_additional_inputs(self, type, connect):
             "robot ",
             "robot h",
         ]
+
+    @pytest.mark.parametrize("type", ["tuples", "messages"])
+    def test_multimodal_api(self, type, connect):
+        def double_multimodal(msg, history):
+            return msg["text"] + " " + msg["text"]
+
+        chatbot = gr.ChatInterface(
+            double_multimodal,
+            type=type,
+            multimodal=True,
+        )
+        with connect(chatbot) as client:
+            result = client.predict({"text": "hello", "files": []}, api_name="/chat")
+            assert result == "hello hello"
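
For reference, a minimal sketch of exercising the fixed `/chat` endpoint from `gradio_client`, mirroring the new test. The echo function, port, and launch flags here are illustrative assumptions, not part of the patch:

```python
# Sketch: call the /chat API of a multimodal ChatInterface from a client.
import gradio as gr
from gradio_client import Client

def echo_text(message, history):
    # With multimodal=True, `message` arrives as a dict with "text" and "files".
    return message["text"] + " " + message["text"]

demo = gr.ChatInterface(echo_text, type="messages", multimodal=True)
demo.launch(prevent_thread_lock=True)  # serve without blocking this script

client = Client("http://127.0.0.1:7860/")
result = client.predict({"text": "hello", "files": []}, api_name="/chat")
print(result)  # "hello hello" -- before this fix, the API output was routed
               # back into the MultimodalTextbox and this call failed
```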