From 39791eb186d3a4ce82c8c27979a28311c37a4067 Mon Sep 17 00:00:00 2001
From: Abubakar Abid
Date: Wed, 1 May 2024 21:48:49 -0700
Subject: [PATCH] Convert sse calls in client from async to sync (#8182)

* convert sse calls in client from async to sync

* add changeset

* more sync

* lint

* more sync

* fix threadpool

* fix timeouts

* reuse executor

* lint

---------

Co-authored-by: gradio-pr-bot
---
 .changeset/great-poets-visit.md       |   6 ++
 client/python/gradio_client/client.py |  49 ++++-----
 client/python/gradio_client/utils.py  | 144 +++++++++++++-------
 client/python/test/test_client.py     |   2 +-
 4 files changed, 99 insertions(+), 102 deletions(-)
 create mode 100644 .changeset/great-poets-visit.md

diff --git a/.changeset/great-poets-visit.md b/.changeset/great-poets-visit.md
new file mode 100644
index 0000000000000..d07709562b8c7
--- /dev/null
+++ b/.changeset/great-poets-visit.md
@@ -0,0 +1,6 @@
+---
+"gradio": patch
+"gradio_client": patch
+---
+
+fix:Convert sse calls in client from async to sync
diff --git a/client/python/gradio_client/client.py b/client/python/gradio_client/client.py
index 7581f7042b9fd..2f9843565f8f6 100644
--- a/client/python/gradio_client/client.py
+++ b/client/python/gradio_client/client.py
@@ -225,21 +225,21 @@ def _stream_heartbeat(self):
         except httpx.TransportError:
             return
 
-    async def stream_messages(
+    def stream_messages(
         self, protocol: Literal["sse_v1", "sse_v2", "sse_v2.1", "sse_v3"]
     ) -> None:
         try:
-            async with httpx.AsyncClient(
+            with httpx.Client(
                 timeout=httpx.Timeout(timeout=None), verify=self.ssl_verify
             ) as client:
-                async with client.stream(
+                with client.stream(
                     "GET",
                     self.sse_url,
                     params={"session_hash": self.session_hash},
                     headers=self.headers,
                     cookies=self.cookies,
                 ) as response:
-                    async for line in response.aiter_lines():
+                    for line in response.iter_lines():
                         line = line.rstrip("\n")
                         if not len(line):
                             continue
@@ -276,14 +276,13 @@ async def stream_messages(
             traceback.print_exc()
             raise e
 
-    async def send_data(self, data, hash_data, protocol):
-        async with httpx.AsyncClient(verify=self.ssl_verify) as client:
-            req = await client.post(
-                self.sse_data_url,
-                json={**data, **hash_data},
-                headers=self.headers,
-                cookies=self.cookies,
-            )
+    def send_data(self, data, hash_data, protocol):
+        req = httpx.post(
+            self.sse_data_url,
+            json={**data, **hash_data},
+            headers=self.headers,
+            cookies=self.cookies,
+        )
         if req.status_code == 503:
             raise QueueError("Queue is full! Please try again.")
         req.raise_for_status()
@@ -294,7 +293,7 @@ async def send_data(self, data, hash_data, protocol):
         self.stream_open = True
 
         def open_stream():
-            return utils.synchronize_async(self.stream_messages, protocol)
+            return self.stream_messages(protocol)
 
         def close_stream(_):
             self.stream_open = False
@@ -1119,18 +1118,12 @@ def _predict(*data) -> tuple:
             }
 
             if self.protocol == "sse":
-                result = utils.synchronize_async(
-                    self._sse_fn_v0, data, hash_data, helper
-                )
+                result = self._sse_fn_v0(data, hash_data, helper)  # type: ignore
             elif self.protocol in ("sse_v1", "sse_v2", "sse_v2.1", "sse_v3"):
-                event_id = utils.synchronize_async(
-                    self.client.send_data, data, hash_data, self.protocol
-                )
+                event_id = self.client.send_data(data, hash_data, self.protocol)
                 self.client.pending_event_ids.add(event_id)
                 self.client.pending_messages_per_event[event_id] = []
-                result = utils.synchronize_async(
-                    self._sse_fn_v1plus, helper, event_id, self.protocol
-                )
+                result = self._sse_fn_v1plus(helper, event_id, self.protocol)
             else:
                 raise ValueError(f"Unsupported protocol: {self.protocol}")
 
@@ -1290,11 +1283,11 @@ def _download_file(self, x: dict) -> str:
             shutil.move(temp_dir / Path(url_path).name, dest)
             return str(dest.resolve())
 
-    async def _sse_fn_v0(self, data: dict, hash_data: dict, helper: Communicator):
-        async with httpx.AsyncClient(
+    def _sse_fn_v0(self, data: dict, hash_data: dict, helper: Communicator):
+        with httpx.Client(
             timeout=httpx.Timeout(timeout=None), verify=self.client.ssl_verify
         ) as client:
-            return await utils.get_pred_from_sse_v0(
+            return utils.get_pred_from_sse_v0(
                 client,
                 data,
                 hash_data,
@@ -1304,15 +1297,16 @@ async def _sse_fn_v0(self, data: dict, hash_data: dict, helper: Communicator):
                 self.client.headers,
                 self.client.cookies,
                 self.client.ssl_verify,
+                self.client.executor,
             )
 
-    async def _sse_fn_v1plus(
+    def _sse_fn_v1plus(
         self,
         helper: Communicator,
         event_id: str,
         protocol: Literal["sse_v1", "sse_v2", "sse_v2.1", "sse_v3"],
     ):
-        return await utils.get_pred_from_sse_v1plus(
+        return utils.get_pred_from_sse_v1plus(
             helper,
             self.client.headers,
             self.client.cookies,
@@ -1320,6 +1314,7 @@ async def _sse_fn_v1plus(
             event_id,
             protocol,
             self.client.ssl_verify,
+            self.client.executor,
         )
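The client-side change above is a mechanical async-to-sync translation: `httpx.AsyncClient` becomes `httpx.Client`, `aiter_lines()` becomes `iter_lines()`, and the `await`s disappear. A minimal, self-contained sketch of the resulting streaming pattern (the URL and session hash are placeholders, not Gradio's real endpoints):

    import json

    import httpx

    def read_sse(url: str, session_hash: str) -> None:
        # timeout=None keeps the long-lived SSE connection from being closed early
        with httpx.Client(timeout=httpx.Timeout(timeout=None)) as client:
            with client.stream(
                "GET", url, params={"session_hash": session_hash}
            ) as response:
                for line in response.iter_lines():
                    line = line.rstrip("\n")
                    if not line:
                        continue  # SSE events are separated by blank lines
                    if line.startswith("data:"):
                        print(json.loads(line[len("data:"):]))
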
diff --git a/client/python/gradio_client/utils.py b/client/python/gradio_client/utils.py
index 83eb0e9f2031c..260cafd113152 100644
--- a/client/python/gradio_client/utils.py
+++ b/client/python/gradio_client/utils.py
@@ -2,6 +2,7 @@
 
 import asyncio
 import base64
+import concurrent.futures
 import copy
 import json
 import mimetypes
@@ -10,8 +11,8 @@
 import secrets
 import shutil
 import tempfile
+import time
 import warnings
-from concurrent.futures import CancelledError
 from dataclasses import dataclass, field
 from datetime import datetime
 from enum import Enum
@@ -243,6 +244,7 @@ class Communicator:
     reset_url: str
     should_cancel: bool = False
     event_id: str | None = None
+    thread_complete: bool = False
 
 
 ########################
@@ -266,7 +268,7 @@ def probe_url(possible_url: str) -> bool:
    headers = {"User-Agent": "gradio (https://gradio.app/; gradio-team@huggingface.co)"}
     try:
         with httpx.Client() as client:
-            head_request = client.head(possible_url, headers=headers)
+            head_request = httpx.head(possible_url, headers=headers)
             if head_request.status_code == 405:
                 return client.get(possible_url, headers=headers).is_success
             return head_request.is_success
@@ -311,7 +313,7 @@ async def get_pred_from_ws(
                 # otherwise will get nasty warning in console
                 task.cancel()
                 await asyncio.gather(task, reset, return_exceptions=True)
-                raise CancelledError()
+                raise concurrent.futures.CancelledError()
             # Need to suspend this coroutine so that task actually runs
             await asyncio.sleep(0.01)
         msg = task.result()
@@ -348,8 +350,8 @@ async def get_pred_from_ws(
     return resp["output"]
 
 
-async def get_pred_from_sse_v0(
-    client: httpx.AsyncClient,
+def get_pred_from_sse_v0(
+    client: httpx.Client,
     data: dict,
     hash_data: dict,
     helper: Communicator,
@@ -358,40 +360,36 @@
     headers: dict[str, str],
     cookies: dict[str, str] | None,
     ssl_verify: bool,
+    executor: concurrent.futures.ThreadPoolExecutor,
 ) -> dict[str, Any] | None:
-    done, pending = await asyncio.wait(
-        [
-            asyncio.create_task(check_for_cancel(helper, headers, cookies, ssl_verify)),
-            asyncio.create_task(
-                stream_sse_v0(
-                    client,
-                    data,
-                    hash_data,
-                    helper,
-                    sse_url,
-                    sse_data_url,
-                    headers,
-                    cookies,
-                )
-            ),
-        ],
-        return_when=asyncio.FIRST_COMPLETED,
+    helper.thread_complete = False
+    future_cancel = executor.submit(
+        check_for_cancel, helper, headers, cookies, ssl_verify
     )
-
-    for task in pending:
-        task.cancel()
-        try:
-            await task
-        except asyncio.CancelledError:
-            pass
+    future_sse = executor.submit(
+        stream_sse_v0,
+        client,
+        data,
+        hash_data,
+        helper,
+        sse_url,
+        sse_data_url,
+        headers,
+        cookies,
+    )
+    done, _ = concurrent.futures.wait(
+        [future_cancel, future_sse],  # type: ignore
+        return_when=concurrent.futures.FIRST_COMPLETED,
+    )
+    helper.thread_complete = True
 
     if len(done) != 1:
         raise ValueError(f"Did not expect {len(done)} tasks to be done.")
 
-    for task in done:
-        return task.result()
+    for future in done:
+        return future.result()
 
 
-async def get_pred_from_sse_v1plus(
+def get_pred_from_sse_v1plus(
     helper: Communicator,
     headers: dict[str, str],
     cookies: dict[str, str] | None,
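With no event loop available, `asyncio.wait` is replaced by `executor.submit` plus `concurrent.futures.wait`: a cancellation watcher and the SSE reader race, and whichever future finishes first wins. The skeleton of that pattern, with illustrative names and timings rather than the client's real code:

    import concurrent.futures
    import time

    def watch_for_cancel(state: dict) -> None:
        # Poll until the worker signals completion, then unwind this thread.
        while not state.get("thread_complete"):
            time.sleep(0.05)
        raise concurrent.futures.CancelledError()

    def stream_result(state: dict) -> str:
        time.sleep(0.2)  # stand-in for reading the SSE stream
        return "prediction"

    state = {"thread_complete": False}
    executor = concurrent.futures.ThreadPoolExecutor()
    futures = [executor.submit(watch_for_cancel, state), executor.submit(stream_result, state)]
    done, _ = concurrent.futures.wait(futures, return_when=concurrent.futures.FIRST_COMPLETED)
    state["thread_complete"] = True  # tell the losing thread to exit its loop
    print(next(iter(done)).result())  # -> "prediction"
    executor.shutdown()

Unlike an asyncio task, the pending thread cannot be cancelled from outside, which is why the `thread_complete` flag (set once the wait returns) is needed so the losing thread can unwind on its own.
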
@@ -399,59 +397,56 @@
     event_id: str,
     protocol: Literal["sse_v1", "sse_v2", "sse_v2.1"],
     ssl_verify: bool,
+    executor: concurrent.futures.ThreadPoolExecutor,
 ) -> dict[str, Any] | None:
-    done, pending = await asyncio.wait(
-        [
-            asyncio.create_task(check_for_cancel(helper, headers, cookies, ssl_verify)),
-            asyncio.create_task(
-                stream_sse_v1plus(
-                    helper, pending_messages_per_event, event_id, protocol
-                )
-            ),
-        ],
-        return_when=asyncio.FIRST_COMPLETED,
+    helper.thread_complete = False
+    future_cancel = executor.submit(
+        check_for_cancel, helper, headers, cookies, ssl_verify
     )
-
-    for task in pending:
-        task.cancel()
-        try:
-            await task
-        except asyncio.CancelledError:
-            pass
+    future_sse = executor.submit(
+        stream_sse_v1plus, helper, pending_messages_per_event, event_id, protocol
+    )
+    done, _ = concurrent.futures.wait(
+        [future_cancel, future_sse],  # type: ignore
+        return_when=concurrent.futures.FIRST_COMPLETED,
+    )
+    helper.thread_complete = True
 
     if len(done) != 1:
         raise ValueError(f"Did not expect {len(done)} tasks to be done.")
 
-    for task in done:
-        exception = task.exception()
+    for future in done:
+        exception = future.exception()
         if exception:
             raise exception
-        return task.result()
+        return future.result()
 
 
-async def check_for_cancel(
+def check_for_cancel(
     helper: Communicator,
     headers: dict[str, str],
     cookies: dict[str, str] | None,
     ssl_verify: bool,
 ):
     while True:
-        await asyncio.sleep(0.05)
+        time.sleep(0.05)
         with helper.lock:
             if helper.should_cancel:
                 break
+        if helper.thread_complete:
+            raise concurrent.futures.CancelledError()
     if helper.event_id:
-        async with httpx.AsyncClient(ssl_verify=ssl_verify) as http:
-            await http.post(
-                helper.reset_url,
-                json={"event_id": helper.event_id},
-                headers=headers,
-                cookies=cookies,
-            )
-    raise CancelledError()
+        httpx.post(
+            helper.reset_url,
+            json={"event_id": helper.event_id},
+            headers=headers,
+            cookies=cookies,
+            verify=ssl_verify,
+        )
+    raise concurrent.futures.CancelledError()
 
 
-async def stream_sse_v0(
-    client: httpx.AsyncClient,
+def stream_sse_v0(
+    client: httpx.Client,
     data: dict,
     hash_data: dict,
     helper: Communicator,
@@ -461,14 +456,14 @@ async def stream_sse_v0(
     cookies: dict[str, str] | None,
 ) -> dict[str, Any]:
     try:
-        async with client.stream(
+        with client.stream(
             "GET",
             sse_url,
             params=hash_data,
             headers=headers,
             cookies=cookies,
         ) as response:
-            async for line in response.aiter_lines():
+            for line in response.iter_lines():
                 line = line.rstrip("\n")
                 if len(line) == 0:
                     continue
@@ -497,13 +492,14 @@
                                 result = [e]
                             helper.job.outputs.append(result)
                         helper.job.latest_status = status_update
-
+                    if helper.thread_complete:
+                        raise concurrent.futures.CancelledError()
                     if resp["msg"] == "queue_full":
                         raise QueueError("Queue is full! Please try again.")
                     elif resp["msg"] == "send_data":
                         event_id = resp["event_id"]
                         helper.event_id = event_id
-                        req = await client.post(
+                        req = client.post(
                             sse_data_url,
                             json={"event_id": event_id, **data, **hash_data},
                             headers=headers,
@@ -515,11 +511,11 @@
                 else:
                     raise ValueError(f"Unexpected message: {line}")
         raise ValueError("Did not receive process_completed message.")
-    except asyncio.CancelledError:
+    except concurrent.futures.CancelledError:
         raise
 
 
-async def stream_sse_v1plus(
+def stream_sse_v1plus(
     helper: Communicator,
     pending_messages_per_event: dict[str, list[Message | None]],
     event_id: str,
@@ -533,11 +529,11 @@
             if len(pending_messages) > 0:
                 msg = pending_messages.pop(0)
             else:
-                await asyncio.sleep(0.05)
+                time.sleep(0.05)
                 continue
 
-            if msg is None:
-                raise CancelledError()
+            if msg is None or helper.thread_complete:
+                raise concurrent.futures.CancelledError()
 
             with helper.lock:
                 log_message = None
@@ -586,7 +582,7 @@
         elif msg["msg"] == ServerMessage.server_stopped:
             raise ValueError("Server stopped.")
 
-    except asyncio.CancelledError:
+    except concurrent.futures.CancelledError:
         raise
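The same flag shows up inside the readers: `check_for_cancel` and the stream loops each poll shared state and raise `CancelledError` themselves once cancellation is requested or the other thread has finished, with `None` doubling as an end-of-stream sentinel in the v1+ message pump. That pump, reduced to a runnable sketch (not the shipped code):

    import time

    def pump(pending_messages: list) -> list:
        outputs = []
        while True:
            if len(pending_messages) > 0:
                msg = pending_messages.pop(0)
            else:
                time.sleep(0.05)  # nothing queued yet; poll again
                continue
            if msg is None:  # sentinel: the producer closed the stream
                break
            outputs.append(msg)
        return outputs

    print(pump(["generating", "process_completed", None]))
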
diff --git a/client/python/test/test_client.py b/client/python/test/test_client.py
index a5af511699f66..e344290576fb0 100644
--- a/client/python/test/test_client.py
+++ b/client/python/test/test_client.py
@@ -1240,7 +1240,7 @@ def test_download_private_file(self, gradio_temp_dir):
             src="gradio/zip_files",
         )
         url_path = "https://gradio-tests-not-actually-private-spacev4-sse.hf.space/file=lion.jpg"
-        file = client.endpoints[0]._upload_file(url_path)  # type: ignore
+        file = client.endpoints[0]._upload_file(url_path, 0)  # type: ignore
         assert file["path"].endswith(".jpg")
 
     @pytest.mark.flaky
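With the `utils.synchronize_async` wrappers gone, the SSE path runs on plain threads end to end, so predictions compose directly with ordinary thread pools instead of nesting event loops. Hypothetical usage (the Space name is a placeholder):

    from concurrent.futures import ThreadPoolExecutor

    from gradio_client import Client

    client = Client("abidlabs/en2fr")  # placeholder Space
    with ThreadPoolExecutor(max_workers=2) as pool:
        futures = [
            pool.submit(client.predict, text, api_name="/predict")
            for text in ["hello", "goodbye"]
        ]
        print([f.result() for f in futures])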