From 5a54e97ae6210bedeaadb6c1af19b17143a2393a Mon Sep 17 00:00:00 2001 From: Chenlei Hu Date: Wed, 25 Sep 2024 11:19:17 +0900 Subject: [PATCH] Revert "Internal download API: Add proper validated directory input (#4981)" This reverts commit 08c8968482f96c36dbf1f6af36a228d13e5a432b. --- model_filemanager/__init__.py | 2 +- model_filemanager/download_models.py | 145 ++++++------ server.py | 5 +- .../download_models_test.py | 206 ++++++++---------- 4 files changed, 174 insertions(+), 184 deletions(-) diff --git a/model_filemanager/__init__.py b/model_filemanager/__init__.py index b7ac16256ac1..e318351c0512 100644 --- a/model_filemanager/__init__.py +++ b/model_filemanager/__init__.py @@ -1,2 +1,2 @@ # model_manager/__init__.py -from .download_models import download_model, DownloadModelStatus, DownloadStatusType, create_model_path, check_file_exists, track_download_progress, validate_filename +from .download_models import download_model, DownloadModelStatus, DownloadStatusType, create_model_path, check_file_exists, track_download_progress, validate_model_subdirectory, validate_filename diff --git a/model_filemanager/download_models.py b/model_filemanager/download_models.py index 5ffec395e2d3..712d59328f63 100644 --- a/model_filemanager/download_models.py +++ b/model_filemanager/download_models.py @@ -3,7 +3,7 @@ import os import traceback import logging -from folder_paths import folder_names_and_paths, get_folder_paths +from folder_paths import models_dir import re from typing import Callable, Any, Optional, Awaitable, Dict from enum import Enum @@ -17,7 +17,6 @@ class DownloadStatusType(Enum): COMPLETED = "completed" ERROR = "error" - @dataclass class DownloadModelStatus(): status: str @@ -30,7 +29,7 @@ def __init__(self, status: DownloadStatusType, progress_percentage: float, messa self.progress_percentage = progress_percentage self.message = message self.already_existed = already_existed - + def to_dict(self) -> Dict[str, Any]: return { "status": self.status, @@ -39,112 +38,102 @@ def to_dict(self) -> Dict[str, Any]: "already_existed": self.already_existed } - async def download_model(model_download_request: Callable[[str], Awaitable[aiohttp.ClientResponse]], - model_name: str, - model_url: str, - model_directory: str, - folder_path: str, + model_name: str, + model_url: str, + model_sub_directory: str, progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]], progress_interval: float = 1.0) -> DownloadModelStatus: """ Download a model file from a given URL into the models directory. Args: - model_download_request (Callable[[str], Awaitable[aiohttp.ClientResponse]]): + model_download_request (Callable[[str], Awaitable[aiohttp.ClientResponse]]): A function that makes an HTTP request. This makes it easier to mock in unit tests. - model_name (str): + model_name (str): The name of the model file to be downloaded. This will be the filename on disk. - model_url (str): + model_url (str): The URL from which to download the model. - model_directory (str): - The subdirectory within the main models directory where the model + model_sub_directory (str): + The subdirectory within the main models directory where the model should be saved (e.g., 'checkpoints', 'loras', etc.). - progress_callback (Callable[[str, DownloadModelStatus], Awaitable[Any]]): + progress_callback (Callable[[str, DownloadModelStatus], Awaitable[Any]]): An asynchronous function to call with progress updates. - folder_path (str); - Path to which model folder should be used as the root. Returns: DownloadModelStatus: The result of the download operation. """ - if not validate_filename(model_name): + if not validate_model_subdirectory(model_sub_directory): return DownloadModelStatus( - DownloadStatusType.ERROR, + DownloadStatusType.ERROR, 0, - "Invalid model name", + "Invalid model subdirectory", False ) - if not model_directory in folder_names_and_paths: - return DownloadModelStatus( - DownloadStatusType.ERROR, - 0, - "Invalid or unrecognized model directory. model_directory must be a known model type (eg 'checkpoints'). If you are seeing this error for a custom model type, ensure the relevant custom nodes are installed and working.", - False - ) - - if not folder_path in get_folder_paths(model_directory): + if not validate_filename(model_name): return DownloadModelStatus( - DownloadStatusType.ERROR, + DownloadStatusType.ERROR, 0, - f"Invalid folder path '{folder_path}', does not match the list of known directories ({get_folder_paths(model_directory)}). If you're seeing this in the downloader UI, you may need to refresh the page.", + "Invalid model name", False ) - file_path = create_model_path(model_name, folder_path) - existing_file = await check_file_exists(file_path, model_name, progress_callback) + file_path, relative_path = create_model_path(model_name, model_sub_directory, models_dir) + existing_file = await check_file_exists(file_path, model_name, progress_callback, relative_path) if existing_file: return existing_file try: - logging.info(f"Downloading {model_name} from {model_url}") status = DownloadModelStatus(DownloadStatusType.PENDING, 0, f"Starting download of {model_name}", False) - await progress_callback(model_name, status) + await progress_callback(relative_path, status) response = await model_download_request(model_url) if response.status != 200: error_message = f"Failed to download {model_name}. Status code: {response.status}" logging.error(error_message) status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False) - await progress_callback(model_name, status) + await progress_callback(relative_path, status) return DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False) - return await track_download_progress(response, file_path, model_name, progress_callback, progress_interval) + return await track_download_progress(response, file_path, model_name, progress_callback, relative_path, progress_interval) except Exception as e: logging.error(f"Error in downloading model: {e}") - return await handle_download_error(e, model_name, progress_callback) + return await handle_download_error(e, model_name, progress_callback, relative_path) + - -def create_model_path(model_name: str, folder_path: str) -> tuple[str, str]: - os.makedirs(folder_path, exist_ok=True) - file_path = os.path.join(folder_path, model_name) +def create_model_path(model_name: str, model_directory: str, models_base_dir: str) -> tuple[str, str]: + full_model_dir = os.path.join(models_base_dir, model_directory) + os.makedirs(full_model_dir, exist_ok=True) + file_path = os.path.join(full_model_dir, model_name) # Ensure the resulting path is still within the base directory abs_file_path = os.path.abspath(file_path) - abs_base_dir = os.path.abspath(folder_path) + abs_base_dir = os.path.abspath(str(models_base_dir)) if os.path.commonprefix([abs_file_path, abs_base_dir]) != abs_base_dir: - raise Exception(f"Invalid model directory: {folder_path}/{model_name}") + raise Exception(f"Invalid model directory: {model_directory}/{model_name}") - return file_path + relative_path = '/'.join([model_directory, model_name]) + return file_path, relative_path -async def check_file_exists(file_path: str, - model_name: str, - progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]] - ) -> Optional[DownloadModelStatus]: +async def check_file_exists(file_path: str, + model_name: str, + progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]], + relative_path: str) -> Optional[DownloadModelStatus]: if os.path.exists(file_path): status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"{model_name} already exists", True) - await progress_callback(model_name, status) + await progress_callback(relative_path, status) return status return None -async def track_download_progress(response: aiohttp.ClientResponse, - file_path: str, - model_name: str, - progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]], +async def track_download_progress(response: aiohttp.ClientResponse, + file_path: str, + model_name: str, + progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]], + relative_path: str, interval: float = 1.0) -> DownloadModelStatus: try: total_size = int(response.headers.get('Content-Length', 0)) @@ -155,11 +144,10 @@ async def update_progress(): nonlocal last_update_time progress = (downloaded / total_size) * 100 if total_size > 0 else 0 status = DownloadModelStatus(DownloadStatusType.IN_PROGRESS, progress, f"Downloading {model_name}", False) - await progress_callback(model_name, status) + await progress_callback(relative_path, status) last_update_time = time.time() - temp_file_path = file_path + '.tmp' - with open(temp_file_path, 'wb') as f: + with open(file_path, 'wb') as f: chunk_iterator = response.content.iter_chunked(8192) while True: try: @@ -168,39 +156,58 @@ async def update_progress(): break f.write(chunk) downloaded += len(chunk) - + if time.time() - last_update_time >= interval: await update_progress() - os.rename(temp_file_path, file_path) - await update_progress() - + logging.info(f"Successfully downloaded {model_name}. Total downloaded: {downloaded}") status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"Successfully downloaded {model_name}", False) - await progress_callback(model_name, status) + await progress_callback(relative_path, status) return status except Exception as e: logging.error(f"Error in track_download_progress: {e}") logging.error(traceback.format_exc()) - return await handle_download_error(e, model_name, progress_callback) - + return await handle_download_error(e, model_name, progress_callback, relative_path) -async def handle_download_error(e: Exception, - model_name: str, - progress_callback: Callable[[str, DownloadModelStatus], Any] - ) -> DownloadModelStatus: +async def handle_download_error(e: Exception, + model_name: str, + progress_callback: Callable[[str, DownloadModelStatus], Any], + relative_path: str) -> DownloadModelStatus: error_message = f"Error downloading {model_name}: {str(e)}" status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False) - await progress_callback(model_name, status) + await progress_callback(relative_path, status) return status +def validate_model_subdirectory(model_subdirectory: str) -> bool: + """ + Validate that the model subdirectory is safe to install into. + Must not contain relative paths, nested paths or special characters + other than underscores and hyphens. + + Args: + model_subdirectory (str): The subdirectory for the specific model type. + + Returns: + bool: True if the subdirectory is safe, False otherwise. + """ + if len(model_subdirectory) > 50: + return False + + if '..' in model_subdirectory or '/' in model_subdirectory: + return False + + if not re.match(r'^[a-zA-Z0-9_-]+$', model_subdirectory): + return False + + return True def validate_filename(filename: str)-> bool: """ Validate a filename to ensure it's safe and doesn't contain any path traversal attempts. - + Args: filename (str): The filename to validate diff --git a/server.py b/server.py index f1971f2d2b5b..ea923e85ac38 100644 --- a/server.py +++ b/server.py @@ -689,11 +689,10 @@ async def report_progress(filename: str, status: DownloadModelStatus): data = await request.json() url = data.get('url') model_directory = data.get('model_directory') - folder_path = data.get('folder_path') model_filename = data.get('model_filename') progress_interval = data.get('progress_interval', 1.0) # In seconds, how often to report download progress. - if not url or not model_directory or not model_filename or not folder_path: + if not url or not model_directory or not model_filename: return web.json_response({"status": "error", "message": "Missing URL or folder path or filename"}, status=400) session = self.client_session @@ -701,7 +700,7 @@ async def report_progress(filename: str, status: DownloadModelStatus): logging.error("Client session is not initialized") return web.Response(status=500) - task = asyncio.create_task(download_model(lambda url: session.get(url), model_filename, url, model_directory, folder_path, report_progress, progress_interval)) + task = asyncio.create_task(download_model(lambda url: session.get(url), model_filename, url, model_directory, report_progress, progress_interval)) await task return web.json_response(task.result().to_dict()) diff --git a/tests-unit/prompt_server_test/download_models_test.py b/tests-unit/prompt_server_test/download_models_test.py index 128dfeb9a11e..66150a4682fd 100644 --- a/tests-unit/prompt_server_test/download_models_test.py +++ b/tests-unit/prompt_server_test/download_models_test.py @@ -1,17 +1,10 @@ import pytest -import tempfile import aiohttp from aiohttp import ClientResponse import itertools -import os +import os from unittest.mock import AsyncMock, patch, MagicMock -from model_filemanager import download_model, track_download_progress, create_model_path, check_file_exists, DownloadStatusType, DownloadModelStatus, validate_filename -import folder_paths - -@pytest.fixture -def temp_dir(): - with tempfile.TemporaryDirectory() as tmpdirname: - yield tmpdirname +from model_filemanager import download_model, validate_model_subdirectory, track_download_progress, create_model_path, check_file_exists, DownloadStatusType, DownloadModelStatus, validate_filename class AsyncIteratorMock: """ @@ -49,7 +42,7 @@ def iter_chunked(self, chunk_size): return AsyncIteratorMock(self.chunks) @pytest.mark.asyncio -async def test_download_model_success(temp_dir): +async def test_download_model_success(): mock_response = AsyncMock(spec=aiohttp.ClientResponse) mock_response.status = 200 mock_response.headers = {'Content-Length': '1000'} @@ -60,13 +53,15 @@ async def test_download_model_success(temp_dir): mock_make_request = AsyncMock(return_value=mock_response) mock_progress_callback = AsyncMock() + # Mock file operations + mock_open = MagicMock() + mock_file = MagicMock() + mock_open.return_value.__enter__.return_value = mock_file time_values = itertools.count(0, 0.1) - fake_paths = {'checkpoints': ([temp_dir], folder_paths.supported_pt_extensions)} - - with patch('model_filemanager.create_model_path', return_value=('models/checkpoints/model.sft', 'model.sft')), \ + with patch('model_filemanager.create_model_path', return_value=('models/checkpoints/model.sft', 'checkpoints/model.sft')), \ patch('model_filemanager.check_file_exists', return_value=None), \ - patch('folder_paths.folder_names_and_paths', fake_paths), \ + patch('builtins.open', mock_open), \ patch('time.time', side_effect=time_values): # Simulate time passing result = await download_model( @@ -74,7 +69,6 @@ async def test_download_model_success(temp_dir): 'model.sft', 'http://example.com/model.sft', 'checkpoints', - temp_dir, mock_progress_callback ) @@ -89,48 +83,44 @@ async def test_download_model_success(temp_dir): # Check initial call mock_progress_callback.assert_any_call( - 'model.sft', + 'checkpoints/model.sft', DownloadModelStatus(DownloadStatusType.PENDING, 0, "Starting download of model.sft", False) ) # Check final call mock_progress_callback.assert_any_call( - 'model.sft', + 'checkpoints/model.sft', DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "Successfully downloaded model.sft", False) ) - mock_file_path = os.path.join(temp_dir, 'model.sft') - assert os.path.exists(mock_file_path) - with open(mock_file_path, 'rb') as mock_file: - assert mock_file.read() == b''.join(chunks) - os.remove(mock_file_path) + # Verify file writing + mock_file.write.assert_any_call(b'a' * 500) + mock_file.write.assert_any_call(b'b' * 300) + mock_file.write.assert_any_call(b'c' * 200) # Verify request was made mock_make_request.assert_called_once_with('http://example.com/model.sft') @pytest.mark.asyncio -async def test_download_model_url_request_failure(temp_dir): +async def test_download_model_url_request_failure(): # Mock dependencies mock_response = AsyncMock(spec=ClientResponse) mock_response.status = 404 # Simulate a "Not Found" error mock_get = AsyncMock(return_value=mock_response) mock_progress_callback = AsyncMock() - - fake_paths = {'checkpoints': ([temp_dir], folder_paths.supported_pt_extensions)} # Mock the create_model_path function - with patch('model_filemanager.create_model_path', return_value='/mock/path/model.safetensors'), \ - patch('model_filemanager.check_file_exists', return_value=None), \ - patch('folder_paths.folder_names_and_paths', fake_paths): - # Call the function - result = await download_model( - mock_get, - 'model.safetensors', - 'http://example.com/model.safetensors', - 'checkpoints', - temp_dir, - mock_progress_callback - ) + with patch('model_filemanager.create_model_path', return_value=('/mock/path/model.safetensors', 'mock/path/model.safetensors')): + # Mock the check_file_exists function to return None (file doesn't exist) + with patch('model_filemanager.check_file_exists', return_value=None): + # Call the function + result = await download_model( + mock_get, + 'model.safetensors', + 'http://example.com/model.safetensors', + 'mock_directory', + mock_progress_callback + ) # Assert the expected behavior assert isinstance(result, DownloadModelStatus) @@ -140,7 +130,7 @@ async def test_download_model_url_request_failure(temp_dir): # Check that progress_callback was called with the correct arguments mock_progress_callback.assert_any_call( - 'model.safetensors', + 'mock_directory/model.safetensors', DownloadModelStatus( status=DownloadStatusType.PENDING, progress_percentage=0, @@ -149,7 +139,7 @@ async def test_download_model_url_request_failure(temp_dir): ) ) mock_progress_callback.assert_called_with( - 'model.safetensors', + 'mock_directory/model.safetensors', DownloadModelStatus( status=DownloadStatusType.ERROR, progress_percentage=0, @@ -163,125 +153,98 @@ async def test_download_model_url_request_failure(temp_dir): @pytest.mark.asyncio async def test_download_model_invalid_model_subdirectory(): + mock_make_request = AsyncMock() mock_progress_callback = AsyncMock() + result = await download_model( mock_make_request, 'model.sft', 'http://example.com/model.sft', '../bad_path', - '../bad_path', mock_progress_callback ) # Assert the result assert isinstance(result, DownloadModelStatus) - assert result.message.startswith('Invalid or unrecognized model directory') + assert result.message == 'Invalid model subdirectory' assert result.status == 'error' assert result.already_existed is False -@pytest.mark.asyncio -async def test_download_model_invalid_folder_path(): - mock_make_request = AsyncMock() - mock_progress_callback = AsyncMock() - - result = await download_model( - mock_make_request, - 'model.sft', - 'http://example.com/model.sft', - 'checkpoints', - 'invalid_path', - mock_progress_callback - ) - - # Assert the result - assert isinstance(result, DownloadModelStatus) - assert result.message.startswith("Invalid folder path") - assert result.status == 'error' - assert result.already_existed is False +# For create_model_path function def test_create_model_path(tmp_path, monkeypatch): - model_name = "model.safetensors" - folder_path = os.path.join(tmp_path, "mock_dir") - - file_path = create_model_path(model_name, folder_path) - - assert file_path == os.path.join(folder_path, "model.safetensors") + mock_models_dir = tmp_path / "models" + monkeypatch.setattr('folder_paths.models_dir', str(mock_models_dir)) + + model_name = "test_model.sft" + model_directory = "test_dir" + + file_path, relative_path = create_model_path(model_name, model_directory, mock_models_dir) + + assert file_path == str(mock_models_dir / model_directory / model_name) + assert relative_path == f"{model_directory}/{model_name}" assert os.path.exists(os.path.dirname(file_path)) - with pytest.raises(Exception, match="Invalid model directory"): - create_model_path("../path_traversal.safetensors", folder_path) - - with pytest.raises(Exception, match="Invalid model directory"): - create_model_path("/etc/some_root_path", folder_path) - @pytest.mark.asyncio async def test_check_file_exists_when_file_exists(tmp_path): file_path = tmp_path / "existing_model.sft" file_path.touch() # Create an empty file - + mock_callback = AsyncMock() - - result = await check_file_exists(str(file_path), "existing_model.sft", mock_callback) - + + result = await check_file_exists(str(file_path), "existing_model.sft", mock_callback, "test/existing_model.sft") + assert result is not None assert result.status == "completed" assert result.message == "existing_model.sft already exists" assert result.already_existed is True - + mock_callback.assert_called_once_with( - "existing_model.sft", + "test/existing_model.sft", DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "existing_model.sft already exists", already_existed=True) ) @pytest.mark.asyncio async def test_check_file_exists_when_file_does_not_exist(tmp_path): file_path = tmp_path / "non_existing_model.sft" - + mock_callback = AsyncMock() - - result = await check_file_exists(str(file_path), "non_existing_model.sft", mock_callback) - + + result = await check_file_exists(str(file_path), "non_existing_model.sft", mock_callback, "test/non_existing_model.sft") + assert result is None mock_callback.assert_not_called() @pytest.mark.asyncio -async def test_track_download_progress_no_content_length(temp_dir): +async def test_track_download_progress_no_content_length(): mock_response = AsyncMock(spec=aiohttp.ClientResponse) mock_response.headers = {} # No Content-Length header - chunks = [b'a' * 500, b'b' * 500] - mock_response.content.iter_chunked.return_value = AsyncIteratorMock(chunks) + mock_response.content.iter_chunked.return_value = AsyncIteratorMock([b'a' * 500, b'b' * 500]) mock_callback = AsyncMock() + mock_open = MagicMock(return_value=MagicMock()) - full_path = os.path.join(temp_dir, 'model.sft') - - result = await track_download_progress( - mock_response, full_path, 'model.sft', - mock_callback, interval=0.1 - ) + with patch('builtins.open', mock_open): + result = await track_download_progress( + mock_response, '/mock/path/model.sft', 'model.sft', + mock_callback, 'models/model.sft', interval=0.1 + ) assert result.status == "completed" - - assert os.path.exists(full_path) - with open(full_path, 'rb') as f: - assert f.read() == b''.join(chunks) - os.remove(full_path) - # Check that progress was reported even without knowing the total size mock_callback.assert_any_call( - 'model.sft', + 'models/model.sft', DownloadModelStatus(DownloadStatusType.IN_PROGRESS, 0, "Downloading model.sft", already_existed=False) ) @pytest.mark.asyncio -async def test_track_download_progress_interval(temp_dir): +async def test_track_download_progress_interval(): mock_response = AsyncMock(spec=aiohttp.ClientResponse) mock_response.headers = {'Content-Length': '1000'} - chunks = [b'a' * 100] * 10 - mock_response.content.iter_chunked.return_value = AsyncIteratorMock(chunks) + mock_response.content.iter_chunked.return_value = AsyncIteratorMock([b'a' * 100] * 10) mock_callback = AsyncMock() mock_open = MagicMock(return_value=MagicMock()) @@ -290,18 +253,18 @@ async def test_track_download_progress_interval(temp_dir): mock_time = MagicMock() mock_time.side_effect = [i * 0.5 for i in range(30)] # This should be enough for 10 chunks - full_path = os.path.join(temp_dir, 'model.sft') - - with patch('time.time', mock_time): + with patch('builtins.open', mock_open), \ + patch('time.time', mock_time): await track_download_progress( - mock_response, full_path, 'model.sft', - mock_callback, interval=1.0 + mock_response, '/mock/path/model.sft', 'model.sft', + mock_callback, 'models/model.sft', interval=1.0 ) - - assert os.path.exists(full_path) - with open(full_path, 'rb') as f: - assert f.read() == b''.join(chunks) - os.remove(full_path) + + # Print out the actual call count and the arguments of each call for debugging + print(f"mock_callback was called {mock_callback.call_count} times") + for i, call in enumerate(mock_callback.call_args_list): + args, kwargs = call + print(f"Call {i + 1}: {args[1].status}, Progress: {args[1].progress_percentage:.2f}%") # Assert that progress was updated at least 3 times (start, at least one interval, and end) assert mock_callback.call_count >= 3, f"Expected at least 3 calls, but got {mock_callback.call_count}" @@ -316,6 +279,27 @@ async def test_track_download_progress_interval(temp_dir): assert last_call[0][1].status == "completed" assert last_call[0][1].progress_percentage == 100 +def test_valid_subdirectory(): + assert validate_model_subdirectory("valid-model123") is True + +def test_subdirectory_too_long(): + assert validate_model_subdirectory("a" * 51) is False + +def test_subdirectory_with_double_dots(): + assert validate_model_subdirectory("model/../unsafe") is False + +def test_subdirectory_with_slash(): + assert validate_model_subdirectory("model/unsafe") is False + +def test_subdirectory_with_special_characters(): + assert validate_model_subdirectory("model@unsafe") is False + +def test_subdirectory_with_underscore_and_dash(): + assert validate_model_subdirectory("valid_model-name") is True + +def test_empty_subdirectory(): + assert validate_model_subdirectory("") is False + @pytest.mark.parametrize("filename, expected", [ ("valid_model.safetensors", True), ("valid_model.sft", True),