Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes race condition in update_root_in_config #9306

Merged
merged 10 commits into from
Sep 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/pretty-hairs-rest.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"gradio": minor
---

feat:Fixes race condition in `update_root_in_config`
9 changes: 4 additions & 5 deletions gradio/route_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -687,11 +687,10 @@ def update_root_in_config(config: BlocksConfigDict, root: str) -> BlocksConfigDi
root url has changed, all of the urls in the config that correspond to component
file urls are updated to use the new root url.
"""
with config_lock:
previous_root = config.get("root")
if previous_root is None or previous_root != root:
config["root"] = root
config = processing_utils.add_root_url(config, root, previous_root) # type: ignore
previous_root = config.get("root")
if previous_root is None or previous_root != root:
config["root"] = root
config = processing_utils.add_root_url(config, root, previous_root) # type: ignore
return config


Expand Down
4 changes: 2 additions & 2 deletions gradio/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,7 @@ def main(request: fastapi.Request, user: str = Depends(get_current_user)):
request=request, route_path="/", root_path=app.root_path
)
if (app.auth is None and app.auth_dependency is None) or user is not None:
config = blocks.config
config = utils.safe_deepcopy(blocks.config)
config = route_utils.update_root_in_config(config, root)
config["username"] = user
elif app.auth_dependency:
Expand Down Expand Up @@ -450,7 +450,7 @@ def api_info(all_endpoints: bool = False):
@app.get("/config/", dependencies=[Depends(login_check)])
@app.get("/config", dependencies=[Depends(login_check)])
def get_config(request: fastapi.Request):
config = app.get_blocks().config
config = utils.safe_deepcopy(app.get_blocks().config)
root = route_utils.get_root_url(
request=request, route_path="/config", root_path=app.root_path
)
Expand Down
18 changes: 18 additions & 0 deletions gradio/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,24 @@ def get_default_args(func: Callable) -> list[Any]:
]


def safe_deepcopy(obj: Any) -> Any:
try:
return copy.deepcopy(obj)
except Exception:
if isinstance(obj, dict):
abidlabs marked this conversation as resolved.
Show resolved Hide resolved
return {
safe_deepcopy(key): safe_deepcopy(value) for key, value in obj.items()
}
elif isinstance(obj, list):
return [safe_deepcopy(item) for item in obj]
elif isinstance(obj, tuple):
return tuple(safe_deepcopy(item) for item in obj)
elif isinstance(obj, set):
return {safe_deepcopy(item) for item in obj}
else:
return copy.copy(obj)


def assert_configs_are_equivalent_besides_ids(
config1: dict, config2: dict, root_keys: tuple = ("mode",)
):
Expand Down
57 changes: 57 additions & 0 deletions test/test_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@
import time
from contextlib import asynccontextmanager, closing
from pathlib import Path
from threading import Thread
from unittest.mock import patch

import gradio_client as grc
import httpx
import numpy as np
import pandas as pd
import pytest
Expand Down Expand Up @@ -1548,3 +1550,58 @@ def test_bash_api_multiple_inputs_outputs():
assert response.status_code == 200
assert "event: complete\ndata:" in response.text
assert json.dumps([123, "abc"]) in response.text


def test_attacker_cannot_change_root_in_config(
abidlabs marked this conversation as resolved.
Show resolved Hide resolved
attacker_threads=1, victim_threads=10, max_attempts=30
):
def attacker(url):
"""Simulates the attacker sending a request with a malicious header."""
for _ in range(max_attempts):
httpx.get(url + "config", headers={"X-Forwarded-Host": "evil"})

def victim(url, results):
"""Simulates the victim making a normal request and checking the response."""
for _ in range(max_attempts):
res = httpx.get(url)
config = json.loads(
res.text.split("window.gradio_config =", 1)[1].split(";</script>", 1)[0]
)
if "evil" in config["root"]:
results.append(True)
return

results.append(False)

with gr.Blocks() as demo:
i1 = gr.Image("test/test_files/cheetah1.jpg")
t = gr.Textbox()
i2 = gr.Image()
t.change(lambda x: x, i1, i2)

_, url, _ = demo.launch(prevent_thread_lock=True)

threads = []
results = []

for _ in range(attacker_threads):
t_attacker = Thread(target=attacker, args=(url,))
threads.append(t_attacker)

for _ in range(victim_threads):
t_victim = Thread(
target=victim,
args=(
url,
results,
),
)
threads.append(t_victim)

for t in threads:
t.start()

for t in threads:
t.join()

assert not any(results), "attacker was able to modify a victim's config root url"
42 changes: 42 additions & 0 deletions test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
is_in_or_equal,
is_special_typed_parameter,
kaggle_check,
safe_deepcopy,
sagemaker_check,
sanitize_list_for_csv,
sanitize_value_for_csv,
Expand Down Expand Up @@ -661,3 +662,44 @@ def test_get_nonexistent(self):
d = UnhashableKeyDict()
with pytest.raises(KeyError):
d["nonexistent"]


class TestSafeDeepCopy:
def test_safe_deepcopy_dict(self):
original = {"key1": [1, 2, {"nested_key": "value"}], "key2": "simple_string"}
copied = safe_deepcopy(original)

assert copied == original
assert copied is not original
assert copied["key1"] is not original["key1"]
assert copied["key1"][2] is not original["key1"][2]

def test_safe_deepcopy_list(self):
original = [1, 2, [3, 4, {"key": "value"}]]
copied = safe_deepcopy(original)

assert copied == original
assert copied is not original
assert copied[2] is not original[2]
assert copied[2][2] is not original[2][2]

def test_safe_deepcopy_custom_object(self):
class CustomClass:
def __init__(self, value):
self.value = value

original = CustomClass(10)
copied = safe_deepcopy(original)

assert copied.value == original.value
assert copied is not original

def test_safe_deepcopy_handles_undeepcopyable(self):
class Uncopyable:
def __deepcopy__(self, memo):
raise TypeError("Can't deepcopy")

original = Uncopyable()
result = safe_deepcopy(original)
assert result is not original
assert type(result) is type(original)
Loading