Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Convert sse calls in client from async to sync #8182

Merged
merged 10 commits into from
May 2, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
more sync
  • Loading branch information
abidlabs committed May 1, 2024
commit 70d1d6ac89407e500747ca1cf3a90957140c7bd1
21 changes: 9 additions & 12 deletions client/python/gradio_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,13 +277,12 @@ def stream_messages(
raise e

def send_data(self, data, hash_data, protocol):
with httpx.Client(verify=self.ssl_verify) as client:
req = client.post(
self.sse_data_url,
json={**data, **hash_data},
headers=self.headers,
cookies=self.cookies,
)
req = httpx.post(
self.sse_data_url,
json={**data, **hash_data},
headers=self.headers,
cookies=self.cookies,
)
if req.status_code == 503:
raise QueueError("Queue is full! Please try again.")
req.raise_for_status()
Expand Down Expand Up @@ -1124,9 +1123,7 @@ def _predict(*data) -> tuple:
event_id = self.client.send_data(data, hash_data, self.protocol)
self.client.pending_event_ids.add(event_id)
self.client.pending_messages_per_event[event_id] = []
result = utils.synchronize_async(
self._sse_fn_v1plus, helper, event_id, self.protocol
)
result = self._sse_fn_v1plus(helper, event_id, self.protocol)
else:
raise ValueError(f"Unsupported protocol: {self.protocol}")

Expand Down Expand Up @@ -1302,13 +1299,13 @@ def _sse_fn_v0(self, data: dict, hash_data: dict, helper: Communicator):
self.client.ssl_verify,
)

async def _sse_fn_v1plus(
def _sse_fn_v1plus(
self,
helper: Communicator,
event_id: str,
protocol: Literal["sse_v1", "sse_v2", "sse_v2.1", "sse_v3"],
):
return await utils.get_pred_from_sse_v1plus(
return utils.get_pred_from_sse_v1plus(
helper,
self.client.headers,
self.client.cookies,
Expand Down
109 changes: 65 additions & 44 deletions client/python/gradio_client/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import asyncio
import base64
import concurrent.futures
import copy
import json
import mimetypes
Expand All @@ -10,8 +11,8 @@
import secrets
import shutil
import tempfile
import time
import warnings
from concurrent.futures import CancelledError
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
Expand Down Expand Up @@ -266,7 +267,7 @@ def probe_url(possible_url: str) -> bool:
headers = {"User-Agent": "gradio (https://gradio.app/; gradio-team@huggingface.co)"}
try:
with httpx.Client() as client:
head_request = client.head(possible_url, headers=headers)
head_request = httpx.head(possible_url, headers=headers)
if head_request.status_code == 405:
return client.get(possible_url, headers=headers).is_success
return head_request.is_success
Expand Down Expand Up @@ -311,7 +312,7 @@ async def get_pred_from_ws(
# otherwise will get nasty warning in console
task.cancel()
await asyncio.gather(task, reset, return_exceptions=True)
raise CancelledError()
raise concurrent.futures.CancelledError()
abidlabs marked this conversation as resolved.
Show resolved Hide resolved
# Need to suspend this coroutine so that task actually runs
await asyncio.sleep(0.01)
msg = task.result()
Expand Down Expand Up @@ -359,8 +360,12 @@ def get_pred_from_sse_v0(
cookies: dict[str, str] | None,
ssl_verify: bool,
) -> dict[str, Any] | None:
try:
result = stream_sse_v0(
with concurrent.futures.ThreadPoolExecutor() as executor:
future_cancel = executor.submit(
check_for_cancel, helper, headers, cookies, ssl_verify
)
future_sse = executor.submit(
stream_sse_v0,
client,
data,
hash_data,
Expand All @@ -370,15 +375,28 @@ def get_pred_from_sse_v0(
headers,
cookies,
)
except Exception as e:
if check_for_cancel(helper, headers, cookies, ssl_verify):
return None
else:
raise e
return result
done, pending = concurrent.futures.wait(
[future_cancel, future_sse], return_when=concurrent.futures.FIRST_COMPLETED
)

for future in pending:
future.cancel()

concurrent.futures.wait(pending)

for future in pending:
try:
future.result()
except concurrent.futures.CancelledError:
pass

async def get_pred_from_sse_v1plus(
if len(done) != 1:
raise ValueError(f"Did not expect {len(done)} tasks to be done.")
for future in done:
return future.result()


def get_pred_from_sse_v1plus(
helper: Communicator,
headers: dict[str, str],
cookies: dict[str, str] | None,
Expand All @@ -387,54 +405,57 @@ async def get_pred_from_sse_v1plus(
protocol: Literal["sse_v1", "sse_v2", "sse_v2.1"],
ssl_verify: bool,
) -> dict[str, Any] | None:
done, pending = await asyncio.wait(
[
asyncio.create_task(check_for_cancel(helper, headers, cookies, ssl_verify)),
asyncio.create_task(
stream_sse_v1plus(
helper, pending_messages_per_event, event_id, protocol
)
),
],
return_when=asyncio.FIRST_COMPLETED,
)
with concurrent.futures.ThreadPoolExecutor() as executor:
abidlabs marked this conversation as resolved.
Show resolved Hide resolved
future_cancel = executor.submit(
check_for_cancel, helper, headers, cookies, ssl_verify
)
future_sse = executor.submit(
stream_sse_v1plus, helper, pending_messages_per_event, event_id, protocol
)
done, pending = concurrent.futures.wait(
[future_cancel, future_sse], return_when=concurrent.futures.FIRST_COMPLETED
)

for task in pending:
task.cancel()
for future in pending:
future.cancel()

concurrent.futures.wait(pending)

for future in pending:
try:
await task
except asyncio.CancelledError:
future.result()
except concurrent.futures.CancelledError:
pass

if len(done) != 1:
raise ValueError(f"Did not expect {len(done)} tasks to be done.")
for task in done:
exception = task.exception()
for future in done:
exception = future.exception()
if exception:
raise exception
return task.result()
return future.result()


async def check_for_cancel(
def check_for_cancel(
helper: Communicator,
headers: dict[str, str],
cookies: dict[str, str] | None,
ssl_verify: bool,
):
while True:
await asyncio.sleep(0.05)
time.sleep(0.05)
with helper.lock:
if helper.should_cancel:
break
if helper.event_id:
async with httpx.AsyncClient(ssl_verify=ssl_verify) as http:
await http.post(
helper.reset_url,
json={"event_id": helper.event_id},
headers=headers,
cookies=cookies,
)
raise CancelledError()
httpx.post(
helper.reset_url,
json={"event_id": helper.event_id},
headers=headers,
cookies=cookies,
verify=ssl_verify,
)
raise concurrent.futures.CancelledError()


def stream_sse_v0(
Expand Down Expand Up @@ -506,7 +527,7 @@ def stream_sse_v0(
raise


async def stream_sse_v1plus(
def stream_sse_v1plus(
helper: Communicator,
pending_messages_per_event: dict[str, list[Message | None]],
event_id: str,
Expand All @@ -520,11 +541,11 @@ async def stream_sse_v1plus(
if len(pending_messages) > 0:
msg = pending_messages.pop(0)
else:
await asyncio.sleep(0.05)
time.sleep(0.05)
continue

if msg is None:
raise CancelledError()
raise concurrent.futures.CancelledError()

with helper.lock:
log_message = None
Expand Down Expand Up @@ -573,7 +594,7 @@ async def stream_sse_v1plus(
elif msg["msg"] == ServerMessage.server_stopped:
raise ValueError("Server stopped.")

except asyncio.CancelledError:
except concurrent.futures.CancelledError:
raise


Expand Down