Skip to content

Commit

Permalink
Allow gr.Request to work with ZeroGPU (#9148)
Browse files Browse the repository at this point in the history
* pickleable

* add test

* routes

* add changeset

* one mroe test

* format

* rename test

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
  • Loading branch information
abidlabs and gradio-pr-bot authored Aug 19, 2024
1 parent 04b7d32 commit 8715f10
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 7 deletions.
5 changes: 5 additions & 0 deletions .changeset/olive-loops-look.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"gradio": patch
---

fix:Allow `gr.Request` to work with ZeroGPU
47 changes: 40 additions & 7 deletions gradio/route_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import hmac
import json
import os
import pickle
import re
import shutil
import sys
Expand All @@ -18,6 +19,7 @@
from tempfile import NamedTemporaryFile, _TemporaryFileWrapper
from typing import (
TYPE_CHECKING,
Any,
AsyncContextManager,
AsyncGenerator,
BinaryIO,
Expand Down Expand Up @@ -125,15 +127,18 @@ class Request:
query parameters and other information about the request from within the prediction
function. The class is a thin wrapper around the fastapi.Request class. Attributes
of this class include: `headers`, `client`, `query_params`, `session_hash`, and `path_params`. If
auth is enabled, the `username` attribute can be used to get the logged in user.
auth is enabled, the `username` attribute can be used to get the logged in user. In some environments,
the dict-like attributes (e.g. `requests.headers`, `requests.query_params`) of this class are automatically
converted to to dictionaries, so we recommend converting them to dictionaries before accessing
attributes for consistent behavior in different environments.
Example:
import gradio as gr
def echo(text, request: gr.Request):
if request:
print("Request headers dictionary:", request.headers)
print("IP address:", request.client.host)
print("Request headers dictionary:", dict(request.headers))
print("Query parameters:", dict(request.query_params))
print("Session hash:", request.session_hash)
print("IP address:", request.client.host)
print("Gradio session hash:", request.session_hash)
return text
io = gr.Interface(echo, "textbox", "textbox").launch()
Demos: request_ip_headers
Expand All @@ -156,16 +161,16 @@ def __init__(
"""
self.request = request
self.username = username
self.session_hash = session_hash
self.kwargs: dict = kwargs
self.session_hash: str | None = session_hash
self.kwargs: dict[str, Any] = kwargs

def dict_to_obj(self, d):
if isinstance(d, dict):
return json.loads(json.dumps(d), object_hook=Obj)
else:
return d

def __getattr__(self, name):
def __getattr__(self, name: str):
if self.request:
return self.dict_to_obj(getattr(self.request, name))
else:
Expand All @@ -177,6 +182,34 @@ def __getattr__(self, name):
) from ke
return self.dict_to_obj(obj)

def __getstate__(self) -> dict[str, Any]:
self.kwargs.update(
{
"headers": dict(getattr(self, "headers", {})),
"query_params": dict(getattr(self, "query_params", {})),
"cookies": dict(getattr(self, "cookies", {})),
"path_params": dict(getattr(self, "path_params", {})),
"client": {
"host": getattr(self, "client", {}) and self.client.host,
"port": getattr(self, "client", {}) and self.client.port,
},
"url": getattr(self, "url", ""),
}
)
if request_state := hasattr(self, "state"):
try:
pickle.dumps(request_state)
self.kwargs["request_state"] = request_state
except pickle.PicklingError:
pass
self.request = None
return self.__dict__

def __setstate__(self, state: dict[str, Any]):
if request_state := state.pop("request_state", None):
self.state = request_state
self.__dict__ = state


class FnIndexInferError(Exception):
pass
Expand Down
27 changes: 27 additions & 0 deletions test/test_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import functools
import os
import pickle
import tempfile
import time
from contextlib import asynccontextmanager, closing
Expand Down Expand Up @@ -818,6 +819,32 @@ def identity(name, request: gr.Request):
output = dict(response.json())
assert output["data"] == ["test"]

def test_request_is_pickleable(self):
"""
For ZeroGPU, we need to ensure that the gr.Request object is pickle-able.
"""

def identity(name, request: gr.Request):
pickled = pickle.dumps(request)
unpickled = pickle.loads(pickled)
assert request.client.host == unpickled.client.host
assert request.client.port == unpickled.client.port
assert dict(request.query_params) == dict(unpickled.query_params)
assert request.query_params["a"] == unpickled.query_params["a"]
assert dict(request.headers) == dict(unpickled.headers)
assert request.username == unpickled.username
return name

app, _, _ = gr.Interface(identity, "textbox", "textbox").launch(
prevent_thread_lock=True,
)
client = TestClient(app)

response = client.post("/api/predict?a=b", json={"data": ["test"]})
assert response.status_code == 200
output = dict(response.json())
assert output["data"] == ["test"]


def test_predict_route_is_blocked_if_api_open_false():
io = Interface(lambda x: x, "text", "text", examples=[["freddy"]]).queue(
Expand Down

0 comments on commit 8715f10

Please sign in to comment.