Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Set root correctly for Gradio apps that are deployed behind reverse proxies #7411

Merged
merged 29 commits into from
Feb 14, 2024
Merged
Show file tree
Hide file tree
Changes from 24 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
b2e3ba1
testing
abidlabs Feb 13, 2024
0c29f32
add changeset
gradio-pr-bot Feb 13, 2024
d23f835
test
abidlabs Feb 14, 2024
ecf55d3
Merge branch 'root-path' of github.com:gradio-app/gradio into root-path
abidlabs Feb 14, 2024
e96ed9f
backend
abidlabs Feb 14, 2024
8d979f0
fix
abidlabs Feb 14, 2024
8660453
add unit tests
abidlabs Feb 14, 2024
29c575b
testing
abidlabs Feb 14, 2024
f0bbdc4
remove check
abidlabs Feb 14, 2024
3e1083d
add changeset
gradio-pr-bot Feb 14, 2024
71d5283
trying something
abidlabs Feb 14, 2024
0964ca7
Merge branch 'root-path' of github.com:gradio-app/gradio into root-path
abidlabs Feb 14, 2024
c28abee
add changeset
gradio-pr-bot Feb 14, 2024
b80d18e
override
abidlabs Feb 14, 2024
b082d13
Merge branch 'root-path' of github.com:gradio-app/gradio into root-path
abidlabs Feb 14, 2024
74e3f2b
add changeset
gradio-pr-bot Feb 14, 2024
4c8e7f7
fix
abidlabs Feb 14, 2024
70c28d0
Merge branch 'root-path' of github.com:gradio-app/gradio into root-path
abidlabs Feb 14, 2024
5f26128
fix
abidlabs Feb 14, 2024
3e507d1
clean
abidlabs Feb 14, 2024
1fcb373
lint
abidlabs Feb 14, 2024
a0cdefc
route utils
abidlabs Feb 14, 2024
bc3ef5b
Merge branch 'main' into root-path
abidlabs Feb 14, 2024
9b8810f
add changeset
gradio-pr-bot Feb 14, 2024
060dfd5
changes
abidlabs Feb 14, 2024
a7b8ea8
add changeset
gradio-pr-bot Feb 14, 2024
7382031
test
abidlabs Feb 14, 2024
57de921
Merge branch 'root-path' of github.com:gradio-app/gradio into root-path
abidlabs Feb 14, 2024
335daba
revert testing
abidlabs Feb 14, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions .changeset/tricky-coins-sniff.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
---
"@gradio/app": patch
"@gradio/client": patch
"gradio": patch
---

fix:Set `root` correctly for Gradio apps that are deployed behind reverse proxies
3 changes: 3 additions & 0 deletions client/js/src/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,9 @@ export function api_factory(

async function config_success(_config: Config): Promise<client_return> {
config = _config;
if (window.location.protocol === "https:") {
config.root = config.root.replace("http://", "https://");
}
api_map = map_names_to_ids(_config?.dependencies || []);
if (config.auth_required) {
return {
Expand Down
16 changes: 11 additions & 5 deletions gradio/route_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,18 +261,24 @@ async def call_process_api(
return output


def get_root_url(request: fastapi.Request) -> str:
def get_root_url(
request: fastapi.Request, route_path: str, root_path: str | None
) -> str:
"""
Gets the root url of the request, stripping off any query parameters and trailing slashes.
Also ensures that the root url is https if the request is https.
Gets the root url of the request, stripping off any query parameters, the route_path, and trailing slashes.
Also ensures that the root url is https if the request is https. If root_path is provided, it is appended to the root url.
The final root url will not have a trailing slash.
"""
root_url = str(request.url)
root_url = httpx.URL(root_url)
root_url = root_url.copy_with(query=None)
root_url = str(root_url)
root_url = str(root_url).rstrip("/")
if request.headers.get("x-forwarded-proto") == "https":
root_url = root_url.replace("http://", "https://")
return root_url.rstrip("/")
route_path = route_path.rstrip("/")
if len(route_path) > 0:
root_url = root_url[: -len(route_path)]
return (root_url.rstrip("/") + (root_path or "")).rstrip("/")


def _user_safe_decode(src: bytes, codec: str) -> str:
Expand Down
16 changes: 12 additions & 4 deletions gradio/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,9 @@ def login(form_data: OAuth2PasswordRequestForm = Depends()):
def main(request: fastapi.Request, user: str = Depends(get_current_user)):
mimetypes.add_type("application/javascript", ".js")
blocks = app.get_blocks()
root_path = route_utils.get_root_url(request)
root_path = route_utils.get_root_url(
request=request, route_path="/", root_path=app.root_path
)
if app.auth is None or user is not None:
config = copy.deepcopy(app.get_blocks().config)
config["root"] = root_path
Expand Down Expand Up @@ -353,7 +355,9 @@ def api_info():
@app.get("/config", dependencies=[Depends(login_check)])
def get_config(request: fastapi.Request):
config = copy.deepcopy(app.get_blocks().config)
root_path = route_utils.get_root_url(request)[: -len("/config")]
root_path = route_utils.get_root_url(
request=request, route_path="/config", root_path=app.root_path
)
config["root"] = root_path
config = add_root_url(config, root_path)
return config
Expand Down Expand Up @@ -570,7 +574,9 @@ async def predict(
content={"error": str(error) if show_error else None},
status_code=500,
)
root_path = route_utils.get_root_url(request)[: -len(f"/api/{api_name}")]
root_path = route_utils.get_root_url(
request=request, route_path=f"/api/{api_name}", root_path=app.root_path
)
output = add_root_url(output, root_path)
return output

Expand All @@ -580,7 +586,9 @@ async def queue_data(
session_hash: str,
):
blocks = app.get_blocks()
root_path = route_utils.get_root_url(request)[: -len("/queue/data")]
root_path = route_utils.get_root_url(
request=request, route_path="/queue/data", root_path=app.root_path
)

async def sse_stream(request: fastapi.Request):
try:
Expand Down
3 changes: 3 additions & 0 deletions js/app/src/Index.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,9 @@
status_callback: handle_status
});
config = app.config;
if (window.location.protocol === "https:") {
abidlabs marked this conversation as resolved.
Show resolved Hide resolved
config.root = config.root.replace("http://", "https://");
}
window.__gradio_space__ = config.space_id;

status = {
Expand Down
51 changes: 49 additions & 2 deletions test/test_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import pandas as pd
import pytest
import starlette.routing
from fastapi import FastAPI
from fastapi import FastAPI, Request
from fastapi.testclient import TestClient
from gradio_client import media_data

Expand All @@ -25,7 +25,7 @@
routes,
wasm_utils,
)
from gradio.route_utils import FnIndexInferError
from gradio.route_utils import FnIndexInferError, get_root_url


@pytest.fixture()
Expand Down Expand Up @@ -862,3 +862,50 @@ def test_component_server_endpoints(connect):
},
)
assert fail_req.status_code == 404


@pytest.mark.parametrize(
abidlabs marked this conversation as resolved.
Show resolved Hide resolved
"request_url, route_path, root_path, expected_root_url",
[
("http://localhost:7860/", "/", None, "http://localhost:7860"),
abidlabs marked this conversation as resolved.
Show resolved Hide resolved
(
"http://localhost:7860/demo/test",
"/demo/test",
None,
"http://localhost:7860",
),
(
"http://localhost:7860/demo/test/",
"/demo/test",
None,
"http://localhost:7860",
),
(
"http://localhost:7860/demo/test?query=1",
"/demo/test",
None,
"http://localhost:7860",
),
(
"http://localhost:7860/demo/test?query=1",
"/demo/test",
"/gradio",
"http://localhost:7860/gradio",
),
(
"http://localhost:7860/demo/test?query=1",
"/demo/test",
"/gradio/",
"http://localhost:7860/gradio",
),
(
"https://localhost:7860/demo/test?query=1",
"/demo/test",
"/gradio/",
"https://localhost:7860/gradio",
),
],
)
def test_get_root_url(request_url, route_path, root_path, expected_root_url):
request = Request({"path": request_url, "type": "http", "headers": {}})
assert get_root_url(request, route_path, root_path) == expected_root_url