Fix multimodal chatinterface api bug (gradio-app#9054)
* fix

* add changeset

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
freddyaboulton and gradio-pr-bot authored Aug 8, 2024
1 parent f29aef4 commit 9fa635a
Showing 3 changed files with 20 additions and 60 deletions.
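
For context on what the change exercises: gr.ChatInterface(multimodal=True) swaps the plain textbox for a MultimodalTextbox, so user messages arrive as dicts with "text" and "files" keys rather than plain strings. A minimal standalone demo in the spirit of the new test below (the doubling function is taken from the test; launching locally is illustrative, not part of the commit):

import gradio as gr

def double_multimodal(msg, history):
    # Multimodal messages are dicts: {"text": str, "files": list of file paths}
    return msg["text"] + " " + msg["text"]

demo = gr.ChatInterface(double_multimodal, multimodal=True)

if __name__ == "__main__":
    demo.launch()  # serves the UI and the "/chat" API endpoint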
5 changes: 5 additions & 0 deletions .changeset/mighty-maps-double.md
@@ -0,0 +1,5 @@
+---
+"gradio": patch
+---
+
+fix:Fix multimodal chatinterface api bug
61 changes: 1 addition & 60 deletions gradio/chat_interface.py
@@ -517,7 +517,7 @@ async def api_fn(message, history, *args, **kwargs):
         self.fake_api_btn.click(
             api_fn,
             [self.textbox, self.chatbot_state] + self.additional_inputs,
-            [self.textbox, self.chatbot_state],
+            [self.fake_response_textbox, self.chatbot_state],
             api_name="chat",
             concurrency_limit=cast(
                 Union[int, Literal["default"], None], self.concurrency_limit
@@ -697,65 +697,6 @@ async def _stream_fn(
             history_with_input[-1] = response  # type: ignore
             yield history_with_input
 
-    async def _api_submit_fn(
-        self,
-        message: str,
-        history: TupleFormat | list[MessageDict],
-        request: Request,
-        *args,
-    ) -> tuple[str, TupleFormat | list[MessageDict]]:
-        inputs, _, _ = special_args(
-            self.fn, inputs=[message, history, *args], request=request
-        )
-
-        if self.is_async:
-            response = await self.fn(*inputs)
-        else:
-            response = await anyio.to_thread.run_sync(
-                self.fn, *inputs, limiter=self.limiter
-            )
-        if self.type == "tuples":
-            history.append([message, response])  # type: ignore
-        else:
-            new_response = self.response_as_dict(response)
-            history.extend([{"role": "user", "content": message}, new_response])  # type: ignore
-        return response, history
-
-    async def _api_stream_fn(
-        self, message: str, history: list[list[str | None]], request: Request, *args
-    ) -> AsyncGenerator:
-        inputs, _, _ = special_args(
-            self.fn, inputs=[message, history, *args], request=request
-        )
-        if self.is_async:
-            generator = self.fn(*inputs)
-        else:
-            generator = await anyio.to_thread.run_sync(
-                self.fn, *inputs, limiter=self.limiter
-            )
-            generator = SyncToAsyncIterator(generator, self.limiter)
-        try:
-            first_response = await async_iteration(generator)
-            if self.type == "tuples":
-                yield first_response, history + [[message, first_response]]
-            else:
-                first_response = self.response_as_dict(first_response)
-                yield (
-                    first_response,
-                    history + [{"role": "user", "content": message}, first_response],
-                )
-        except StopIteration:
-            yield None, history + [[message, None]]
-        async for response in generator:
-            if self.type == "tuples":
-                yield response, history + [[message, response]]
-            else:
-                new_response = self.response_as_dict(response)
-                yield (
-                    new_response,
-                    history + [{"role": "user", "content": message}, new_response],
-                )
-
     async def _examples_fn(
         self, message: str, *args
     ) -> TupleFormat | list[MessageDict]:
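The functional change is the one-line rewiring above: the "/chat" endpoint's first output is now the hidden fake_response_textbox rather than the input textbox. Judging from the diff, routing the reply back through the multimodal input textbox is what broke the endpoint's return value, and the superseded _api_submit_fn and _api_stream_fn helpers are deleted since api_fn now serves the API route. A sketch of calling the fixed endpoint with gradio_client, assuming the demo above runs at the default local URL (URL is illustrative):

from gradio_client import Client

client = Client("http://127.0.0.1:7860")
result = client.predict(
    {"text": "hello", "files": []},  # multimodal payload, as in the new test
    api_name="/chat",
)
print(result)  # with this fix: "hello hello"
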
14 changes: 14 additions & 0 deletions test/test_chat_interface.py
@@ -301,3 +301,17 @@ def test_streaming_api_with_additional_inputs(self, type, connect):
"robot ",
"robot h",
]

@pytest.mark.parametrize("type", ["tuples", "messages"])
def test_multimodal_api(self, type, connect):
def double_multimodal(msg, history):
return msg["text"] + " " + msg["text"]

chatbot = gr.ChatInterface(
double_multimodal,
type=type,
multimodal=True,
)
with connect(chatbot) as client:
result = client.predict({"text": "hello", "files": []}, api_name="/chat")
assert result == "hello hello"
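
The parametrization covers both of gradio's chat history formats, which the deleted helpers above also branched on. Roughly, the two representations look like this (values are illustrative):

# type="tuples": each exchange is a [user_message, bot_response] pair
history_tuples = [["hello", "hello hello"]]

# type="messages": a flat list of role/content dicts (OpenAI-style)
history_messages = [
    {"role": "user", "content": "hello"},
    {"role": "assistant", "content": "hello hello"},
]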
