Skip to content

Commit

Permalink
Fixes race condition in update_root_in_config (#9306)
Browse files Browse the repository at this point in the history
* test

* lint

* tests

* add changeset

* change

* lint

* reduce num attempts

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
  • Loading branch information
abidlabs and gradio-pr-bot authored Sep 10, 2024
1 parent af4f70a commit f3f0fef
Show file tree
Hide file tree
Showing 6 changed files with 128 additions and 7 deletions.
5 changes: 5 additions & 0 deletions .changeset/pretty-hairs-rest.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"gradio": minor
---

feat:Fixes race condition in `update_root_in_config`
9 changes: 4 additions & 5 deletions gradio/route_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -687,11 +687,10 @@ def update_root_in_config(config: BlocksConfigDict, root: str) -> BlocksConfigDi
root url has changed, all of the urls in the config that correspond to component
file urls are updated to use the new root url.
"""
with config_lock:
previous_root = config.get("root")
if previous_root is None or previous_root != root:
config["root"] = root
config = processing_utils.add_root_url(config, root, previous_root) # type: ignore
previous_root = config.get("root")
if previous_root is None or previous_root != root:
config["root"] = root
config = processing_utils.add_root_url(config, root, previous_root) # type: ignore
return config


Expand Down
4 changes: 2 additions & 2 deletions gradio/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,7 @@ def main(request: fastapi.Request, user: str = Depends(get_current_user)):
request=request, route_path="/", root_path=app.root_path
)
if (app.auth is None and app.auth_dependency is None) or user is not None:
config = blocks.config
config = utils.safe_deepcopy(blocks.config)
config = route_utils.update_root_in_config(config, root)
config["username"] = user
elif app.auth_dependency:
Expand Down Expand Up @@ -450,7 +450,7 @@ def api_info(all_endpoints: bool = False):
@app.get("/config/", dependencies=[Depends(login_check)])
@app.get("/config", dependencies=[Depends(login_check)])
def get_config(request: fastapi.Request):
config = app.get_blocks().config
config = utils.safe_deepcopy(app.get_blocks().config)
root = route_utils.get_root_url(
request=request, route_path="/config", root_path=app.root_path
)
Expand Down
18 changes: 18 additions & 0 deletions gradio/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,24 @@ def get_default_args(func: Callable) -> list[Any]:
]


def safe_deepcopy(obj: Any) -> Any:
try:
return copy.deepcopy(obj)
except Exception:
if isinstance(obj, dict):
return {
safe_deepcopy(key): safe_deepcopy(value) for key, value in obj.items()
}
elif isinstance(obj, list):
return [safe_deepcopy(item) for item in obj]
elif isinstance(obj, tuple):
return tuple(safe_deepcopy(item) for item in obj)
elif isinstance(obj, set):
return {safe_deepcopy(item) for item in obj}
else:
return copy.copy(obj)


def assert_configs_are_equivalent_besides_ids(
config1: dict, config2: dict, root_keys: tuple = ("mode",)
):
Expand Down
57 changes: 57 additions & 0 deletions test/test_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@
import time
from contextlib import asynccontextmanager, closing
from pathlib import Path
from threading import Thread
from unittest.mock import patch

import gradio_client as grc
import httpx
import numpy as np
import pandas as pd
import pytest
Expand Down Expand Up @@ -1548,3 +1550,58 @@ def test_bash_api_multiple_inputs_outputs():
assert response.status_code == 200
assert "event: complete\ndata:" in response.text
assert json.dumps([123, "abc"]) in response.text


def test_attacker_cannot_change_root_in_config(
attacker_threads=1, victim_threads=10, max_attempts=30
):
def attacker(url):
"""Simulates the attacker sending a request with a malicious header."""
for _ in range(max_attempts):
httpx.get(url + "config", headers={"X-Forwarded-Host": "evil"})

def victim(url, results):
"""Simulates the victim making a normal request and checking the response."""
for _ in range(max_attempts):
res = httpx.get(url)
config = json.loads(
res.text.split("window.gradio_config =", 1)[1].split(";</script>", 1)[0]
)
if "evil" in config["root"]:
results.append(True)
return

results.append(False)

with gr.Blocks() as demo:
i1 = gr.Image("test/test_files/cheetah1.jpg")
t = gr.Textbox()
i2 = gr.Image()
t.change(lambda x: x, i1, i2)

_, url, _ = demo.launch(prevent_thread_lock=True)

threads = []
results = []

for _ in range(attacker_threads):
t_attacker = Thread(target=attacker, args=(url,))
threads.append(t_attacker)

for _ in range(victim_threads):
t_victim = Thread(
target=victim,
args=(
url,
results,
),
)
threads.append(t_victim)

for t in threads:
t.start()

for t in threads:
t.join()

assert not any(results), "attacker was able to modify a victim's config root url"
42 changes: 42 additions & 0 deletions test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
is_in_or_equal,
is_special_typed_parameter,
kaggle_check,
safe_deepcopy,
sagemaker_check,
sanitize_list_for_csv,
sanitize_value_for_csv,
Expand Down Expand Up @@ -661,3 +662,44 @@ def test_get_nonexistent(self):
d = UnhashableKeyDict()
with pytest.raises(KeyError):
d["nonexistent"]


class TestSafeDeepCopy:
def test_safe_deepcopy_dict(self):
original = {"key1": [1, 2, {"nested_key": "value"}], "key2": "simple_string"}
copied = safe_deepcopy(original)

assert copied == original
assert copied is not original
assert copied["key1"] is not original["key1"]
assert copied["key1"][2] is not original["key1"][2]

def test_safe_deepcopy_list(self):
original = [1, 2, [3, 4, {"key": "value"}]]
copied = safe_deepcopy(original)

assert copied == original
assert copied is not original
assert copied[2] is not original[2]
assert copied[2][2] is not original[2][2]

def test_safe_deepcopy_custom_object(self):
class CustomClass:
def __init__(self, value):
self.value = value

original = CustomClass(10)
copied = safe_deepcopy(original)

assert copied.value == original.value
assert copied is not original

def test_safe_deepcopy_handles_undeepcopyable(self):
class Uncopyable:
def __deepcopy__(self, memo):
raise TypeError("Can't deepcopy")

original = Uncopyable()
result = safe_deepcopy(original)
assert result is not original
assert type(result) is type(original)

0 comments on commit f3f0fef

Please sign in to comment.