diff --git a/.changeset/olive-loops-look.md b/.changeset/olive-loops-look.md new file mode 100644 index 0000000000000..c38de2373ec78 --- /dev/null +++ b/.changeset/olive-loops-look.md @@ -0,0 +1,5 @@ +--- +"gradio": patch +--- + +fix:Allow `gr.Request` to work with ZeroGPU diff --git a/gradio/route_utils.py b/gradio/route_utils.py index 4aa70f6491c20..f5da757b2a36a 100644 --- a/gradio/route_utils.py +++ b/gradio/route_utils.py @@ -6,6 +6,7 @@ import hmac import json import os +import pickle import re import shutil import sys @@ -18,6 +19,7 @@ from tempfile import NamedTemporaryFile, _TemporaryFileWrapper from typing import ( TYPE_CHECKING, + Any, AsyncContextManager, AsyncGenerator, BinaryIO, @@ -125,15 +127,18 @@ class Request: query parameters and other information about the request from within the prediction function. The class is a thin wrapper around the fastapi.Request class. Attributes of this class include: `headers`, `client`, `query_params`, `session_hash`, and `path_params`. If - auth is enabled, the `username` attribute can be used to get the logged in user. + auth is enabled, the `username` attribute can be used to get the logged in user. In some environments, + the dict-like attributes (e.g. `requests.headers`, `requests.query_params`) of this class are automatically + converted to to dictionaries, so we recommend converting them to dictionaries before accessing + attributes for consistent behavior in different environments. Example: import gradio as gr def echo(text, request: gr.Request): if request: - print("Request headers dictionary:", request.headers) - print("IP address:", request.client.host) + print("Request headers dictionary:", dict(request.headers)) print("Query parameters:", dict(request.query_params)) - print("Session hash:", request.session_hash) + print("IP address:", request.client.host) + print("Gradio session hash:", request.session_hash) return text io = gr.Interface(echo, "textbox", "textbox").launch() Demos: request_ip_headers @@ -156,8 +161,8 @@ def __init__( """ self.request = request self.username = username - self.session_hash = session_hash - self.kwargs: dict = kwargs + self.session_hash: str | None = session_hash + self.kwargs: dict[str, Any] = kwargs def dict_to_obj(self, d): if isinstance(d, dict): @@ -165,7 +170,7 @@ def dict_to_obj(self, d): else: return d - def __getattr__(self, name): + def __getattr__(self, name: str): if self.request: return self.dict_to_obj(getattr(self.request, name)) else: @@ -177,6 +182,34 @@ def __getattr__(self, name): ) from ke return self.dict_to_obj(obj) + def __getstate__(self) -> dict[str, Any]: + self.kwargs.update( + { + "headers": dict(getattr(self, "headers", {})), + "query_params": dict(getattr(self, "query_params", {})), + "cookies": dict(getattr(self, "cookies", {})), + "path_params": dict(getattr(self, "path_params", {})), + "client": { + "host": getattr(self, "client", {}) and self.client.host, + "port": getattr(self, "client", {}) and self.client.port, + }, + "url": getattr(self, "url", ""), + } + ) + if request_state := hasattr(self, "state"): + try: + pickle.dumps(request_state) + self.kwargs["request_state"] = request_state + except pickle.PicklingError: + pass + self.request = None + return self.__dict__ + + def __setstate__(self, state: dict[str, Any]): + if request_state := state.pop("request_state", None): + self.state = request_state + self.__dict__ = state + class FnIndexInferError(Exception): pass diff --git a/test/test_routes.py b/test/test_routes.py index 8de8f7d38995e..0a01e37d5c8ad 100644 --- a/test/test_routes.py +++ b/test/test_routes.py @@ -2,6 +2,7 @@ import functools import os +import pickle import tempfile import time from contextlib import asynccontextmanager, closing @@ -818,6 +819,32 @@ def identity(name, request: gr.Request): output = dict(response.json()) assert output["data"] == ["test"] + def test_request_is_pickleable(self): + """ + For ZeroGPU, we need to ensure that the gr.Request object is pickle-able. + """ + + def identity(name, request: gr.Request): + pickled = pickle.dumps(request) + unpickled = pickle.loads(pickled) + assert request.client.host == unpickled.client.host + assert request.client.port == unpickled.client.port + assert dict(request.query_params) == dict(unpickled.query_params) + assert request.query_params["a"] == unpickled.query_params["a"] + assert dict(request.headers) == dict(unpickled.headers) + assert request.username == unpickled.username + return name + + app, _, _ = gr.Interface(identity, "textbox", "textbox").launch( + prevent_thread_lock=True, + ) + client = TestClient(app) + + response = client.post("/api/predict?a=b", json={"data": ["test"]}) + assert response.status_code == 200 + output = dict(response.json()) + assert output["data"] == ["test"] + def test_predict_route_is_blocked_if_api_open_false(): io = Interface(lambda x: x, "text", "text", examples=[["freddy"]]).queue(