From f3f0fef199c7779aac9aaef794dd4af1861ce50f Mon Sep 17 00:00:00 2001 From: Abubakar Abid Date: Tue, 10 Sep 2024 12:01:40 -0700 Subject: [PATCH] Fixes race condition in `update_root_in_config` (#9306) * test * lint * tests * add changeset * change * lint * reduce num attempts --------- Co-authored-by: gradio-pr-bot --- .changeset/pretty-hairs-rest.md | 5 +++ gradio/route_utils.py | 9 +++--- gradio/routes.py | 4 +-- gradio/utils.py | 18 +++++++++++ test/test_routes.py | 57 +++++++++++++++++++++++++++++++++ test/test_utils.py | 42 ++++++++++++++++++++++++ 6 files changed, 128 insertions(+), 7 deletions(-) create mode 100644 .changeset/pretty-hairs-rest.md diff --git a/.changeset/pretty-hairs-rest.md b/.changeset/pretty-hairs-rest.md new file mode 100644 index 0000000000000..45d7e62c28f57 --- /dev/null +++ b/.changeset/pretty-hairs-rest.md @@ -0,0 +1,5 @@ +--- +"gradio": minor +--- + +feat:Fixes race condition in `update_root_in_config` diff --git a/gradio/route_utils.py b/gradio/route_utils.py index 643129e369800..c4ff6ef16b494 100644 --- a/gradio/route_utils.py +++ b/gradio/route_utils.py @@ -687,11 +687,10 @@ def update_root_in_config(config: BlocksConfigDict, root: str) -> BlocksConfigDi root url has changed, all of the urls in the config that correspond to component file urls are updated to use the new root url. """ - with config_lock: - previous_root = config.get("root") - if previous_root is None or previous_root != root: - config["root"] = root - config = processing_utils.add_root_url(config, root, previous_root) # type: ignore + previous_root = config.get("root") + if previous_root is None or previous_root != root: + config["root"] = root + config = processing_utils.add_root_url(config, root, previous_root) # type: ignore return config diff --git a/gradio/routes.py b/gradio/routes.py index 03d5fe5a7fa37..299f5798ba817 100644 --- a/gradio/routes.py +++ b/gradio/routes.py @@ -396,7 +396,7 @@ def main(request: fastapi.Request, user: str = Depends(get_current_user)): request=request, route_path="/", root_path=app.root_path ) if (app.auth is None and app.auth_dependency is None) or user is not None: - config = blocks.config + config = utils.safe_deepcopy(blocks.config) config = route_utils.update_root_in_config(config, root) config["username"] = user elif app.auth_dependency: @@ -450,7 +450,7 @@ def api_info(all_endpoints: bool = False): @app.get("/config/", dependencies=[Depends(login_check)]) @app.get("/config", dependencies=[Depends(login_check)]) def get_config(request: fastapi.Request): - config = app.get_blocks().config + config = utils.safe_deepcopy(app.get_blocks().config) root = route_utils.get_root_url( request=request, route_path="/config", root_path=app.root_path ) diff --git a/gradio/utils.py b/gradio/utils.py index 20d9e2a1306bc..6dde7ad7a8957 100644 --- a/gradio/utils.py +++ b/gradio/utils.py @@ -466,6 +466,24 @@ def get_default_args(func: Callable) -> list[Any]: ] +def safe_deepcopy(obj: Any) -> Any: + try: + return copy.deepcopy(obj) + except Exception: + if isinstance(obj, dict): + return { + safe_deepcopy(key): safe_deepcopy(value) for key, value in obj.items() + } + elif isinstance(obj, list): + return [safe_deepcopy(item) for item in obj] + elif isinstance(obj, tuple): + return tuple(safe_deepcopy(item) for item in obj) + elif isinstance(obj, set): + return {safe_deepcopy(item) for item in obj} + else: + return copy.copy(obj) + + def assert_configs_are_equivalent_besides_ids( config1: dict, config2: dict, root_keys: tuple = ("mode",) ): diff --git a/test/test_routes.py b/test/test_routes.py index 79ed5a3136864..19162a4a40bcc 100644 --- a/test/test_routes.py +++ b/test/test_routes.py @@ -8,9 +8,11 @@ import time from contextlib import asynccontextmanager, closing from pathlib import Path +from threading import Thread from unittest.mock import patch import gradio_client as grc +import httpx import numpy as np import pandas as pd import pytest @@ -1548,3 +1550,58 @@ def test_bash_api_multiple_inputs_outputs(): assert response.status_code == 200 assert "event: complete\ndata:" in response.text assert json.dumps([123, "abc"]) in response.text + + +def test_attacker_cannot_change_root_in_config( + attacker_threads=1, victim_threads=10, max_attempts=30 +): + def attacker(url): + """Simulates the attacker sending a request with a malicious header.""" + for _ in range(max_attempts): + httpx.get(url + "config", headers={"X-Forwarded-Host": "evil"}) + + def victim(url, results): + """Simulates the victim making a normal request and checking the response.""" + for _ in range(max_attempts): + res = httpx.get(url) + config = json.loads( + res.text.split("window.gradio_config =", 1)[1].split(";", 1)[0] + ) + if "evil" in config["root"]: + results.append(True) + return + + results.append(False) + + with gr.Blocks() as demo: + i1 = gr.Image("test/test_files/cheetah1.jpg") + t = gr.Textbox() + i2 = gr.Image() + t.change(lambda x: x, i1, i2) + + _, url, _ = demo.launch(prevent_thread_lock=True) + + threads = [] + results = [] + + for _ in range(attacker_threads): + t_attacker = Thread(target=attacker, args=(url,)) + threads.append(t_attacker) + + for _ in range(victim_threads): + t_victim = Thread( + target=victim, + args=( + url, + results, + ), + ) + threads.append(t_victim) + + for t in threads: + t.start() + + for t in threads: + t.join() + + assert not any(results), "attacker was able to modify a victim's config root url" diff --git a/test/test_utils.py b/test/test_utils.py index 84c8c7c5966dc..bbfc80b67cfb6 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -36,6 +36,7 @@ is_in_or_equal, is_special_typed_parameter, kaggle_check, + safe_deepcopy, sagemaker_check, sanitize_list_for_csv, sanitize_value_for_csv, @@ -661,3 +662,44 @@ def test_get_nonexistent(self): d = UnhashableKeyDict() with pytest.raises(KeyError): d["nonexistent"] + + +class TestSafeDeepCopy: + def test_safe_deepcopy_dict(self): + original = {"key1": [1, 2, {"nested_key": "value"}], "key2": "simple_string"} + copied = safe_deepcopy(original) + + assert copied == original + assert copied is not original + assert copied["key1"] is not original["key1"] + assert copied["key1"][2] is not original["key1"][2] + + def test_safe_deepcopy_list(self): + original = [1, 2, [3, 4, {"key": "value"}]] + copied = safe_deepcopy(original) + + assert copied == original + assert copied is not original + assert copied[2] is not original[2] + assert copied[2][2] is not original[2][2] + + def test_safe_deepcopy_custom_object(self): + class CustomClass: + def __init__(self, value): + self.value = value + + original = CustomClass(10) + copied = safe_deepcopy(original) + + assert copied.value == original.value + assert copied is not original + + def test_safe_deepcopy_handles_undeepcopyable(self): + class Uncopyable: + def __deepcopy__(self, memo): + raise TypeError("Can't deepcopy") + + original = Uncopyable() + result = safe_deepcopy(original) + assert result is not original + assert type(result) is type(original)