Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Internal download API: Add proper validated directory input #4981

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion model_filemanager/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
# model_manager/__init__.py
from .download_models import download_model, DownloadModelStatus, DownloadStatusType, create_model_path, check_file_exists, track_download_progress, validate_model_subdirectory, validate_filename
from .download_models import download_model, DownloadModelStatus, DownloadStatusType, create_model_path, check_file_exists, track_download_progress, validate_filename
145 changes: 69 additions & 76 deletions model_filemanager/download_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os
import traceback
import logging
from folder_paths import models_dir
from folder_paths import folder_names_and_paths, get_folder_paths
import re
from typing import Callable, Any, Optional, Awaitable, Dict
from enum import Enum
Expand All @@ -17,6 +17,7 @@ class DownloadStatusType(Enum):
COMPLETED = "completed"
ERROR = "error"


@dataclass
class DownloadModelStatus():
status: str
Expand All @@ -29,7 +30,7 @@ def __init__(self, status: DownloadStatusType, progress_percentage: float, messa
self.progress_percentage = progress_percentage
self.message = message
self.already_existed = already_existed

def to_dict(self) -> Dict[str, Any]:
return {
"status": self.status,
Expand All @@ -38,102 +39,112 @@ def to_dict(self) -> Dict[str, Any]:
"already_existed": self.already_existed
}


async def download_model(model_download_request: Callable[[str], Awaitable[aiohttp.ClientResponse]],
model_name: str,
model_url: str,
model_sub_directory: str,
model_name: str,
model_url: str,
model_directory: str,
folder_path: str,
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
progress_interval: float = 1.0) -> DownloadModelStatus:
"""
Download a model file from a given URL into the models directory.

Args:
model_download_request (Callable[[str], Awaitable[aiohttp.ClientResponse]]):
model_download_request (Callable[[str], Awaitable[aiohttp.ClientResponse]]):
A function that makes an HTTP request. This makes it easier to mock in unit tests.
model_name (str):
model_name (str):
The name of the model file to be downloaded. This will be the filename on disk.
model_url (str):
model_url (str):
The URL from which to download the model.
model_sub_directory (str):
The subdirectory within the main models directory where the model
model_directory (str):
The subdirectory within the main models directory where the model
should be saved (e.g., 'checkpoints', 'loras', etc.).
progress_callback (Callable[[str, DownloadModelStatus], Awaitable[Any]]):
progress_callback (Callable[[str, DownloadModelStatus], Awaitable[Any]]):
An asynchronous function to call with progress updates.
folder_path (str);
Path to which model folder should be used as the root.

Returns:
DownloadModelStatus: The result of the download operation.
"""
if not validate_model_subdirectory(model_sub_directory):
if not validate_filename(model_name):
return DownloadModelStatus(
DownloadStatusType.ERROR,
DownloadStatusType.ERROR,
0,
"Invalid model subdirectory",
"Invalid model name",
False
)

if not validate_filename(model_name):
if not model_directory in folder_names_and_paths:
return DownloadModelStatus(
DownloadStatusType.ERROR,
DownloadStatusType.ERROR,
0,
"Invalid model name",
"Invalid or unrecognized model directory. model_directory must be a known model type (eg 'checkpoints'). If you are seeing this error for a custom model type, ensure the relevant custom nodes are installed and working.",
False
)

file_path, relative_path = create_model_path(model_name, model_sub_directory, models_dir)
existing_file = await check_file_exists(file_path, model_name, progress_callback, relative_path)
if not folder_path in get_folder_paths(model_directory):
return DownloadModelStatus(
DownloadStatusType.ERROR,
0,
f"Invalid folder path '{folder_path}', does not match the list of known directories ({get_folder_paths(model_directory)}). If you're seeing this in the downloader UI, you may need to refresh the page.",
False
)

file_path = create_model_path(model_name, folder_path)
existing_file = await check_file_exists(file_path, model_name, progress_callback)
if existing_file:
return existing_file

try:
logging.info(f"Downloading {model_name} from {model_url}")
status = DownloadModelStatus(DownloadStatusType.PENDING, 0, f"Starting download of {model_name}", False)
await progress_callback(relative_path, status)
await progress_callback(model_name, status)

response = await model_download_request(model_url)
if response.status != 200:
error_message = f"Failed to download {model_name}. Status code: {response.status}"
logging.error(error_message)
status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
await progress_callback(relative_path, status)
await progress_callback(model_name, status)
return DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)

return await track_download_progress(response, file_path, model_name, progress_callback, relative_path, progress_interval)
return await track_download_progress(response, file_path, model_name, progress_callback, progress_interval)

except Exception as e:
logging.error(f"Error in downloading model: {e}")
return await handle_download_error(e, model_name, progress_callback, relative_path)

return await handle_download_error(e, model_name, progress_callback)

def create_model_path(model_name: str, model_directory: str, models_base_dir: str) -> tuple[str, str]:
full_model_dir = os.path.join(models_base_dir, model_directory)
os.makedirs(full_model_dir, exist_ok=True)
file_path = os.path.join(full_model_dir, model_name)

def create_model_path(model_name: str, folder_path: str) -> tuple[str, str]:
os.makedirs(folder_path, exist_ok=True)
file_path = os.path.join(folder_path, model_name)

# Ensure the resulting path is still within the base directory
abs_file_path = os.path.abspath(file_path)
abs_base_dir = os.path.abspath(str(models_base_dir))
abs_base_dir = os.path.abspath(folder_path)
if os.path.commonprefix([abs_file_path, abs_base_dir]) != abs_base_dir:
raise Exception(f"Invalid model directory: {model_directory}/{model_name}")
raise Exception(f"Invalid model directory: {folder_path}/{model_name}")

return file_path

relative_path = '/'.join([model_directory, model_name])
return file_path, relative_path

async def check_file_exists(file_path: str,
model_name: str,
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
relative_path: str) -> Optional[DownloadModelStatus]:
async def check_file_exists(file_path: str,
model_name: str,
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]]
) -> Optional[DownloadModelStatus]:
if os.path.exists(file_path):
status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"{model_name} already exists", True)
await progress_callback(relative_path, status)
await progress_callback(model_name, status)
return status
return None


async def track_download_progress(response: aiohttp.ClientResponse,
file_path: str,
model_name: str,
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
relative_path: str,
async def track_download_progress(response: aiohttp.ClientResponse,
file_path: str,
model_name: str,
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
interval: float = 1.0) -> DownloadModelStatus:
try:
total_size = int(response.headers.get('Content-Length', 0))
Expand All @@ -144,10 +155,11 @@ async def update_progress():
nonlocal last_update_time
progress = (downloaded / total_size) * 100 if total_size > 0 else 0
status = DownloadModelStatus(DownloadStatusType.IN_PROGRESS, progress, f"Downloading {model_name}", False)
await progress_callback(relative_path, status)
await progress_callback(model_name, status)
last_update_time = time.time()

with open(file_path, 'wb') as f:
temp_file_path = file_path + '.tmp'
with open(temp_file_path, 'wb') as f:
chunk_iterator = response.content.iter_chunked(8192)
while True:
try:
Expand All @@ -156,58 +168,39 @@ async def update_progress():
break
f.write(chunk)
downloaded += len(chunk)

if time.time() - last_update_time >= interval:
await update_progress()

os.rename(temp_file_path, file_path)

await update_progress()

logging.info(f"Successfully downloaded {model_name}. Total downloaded: {downloaded}")
status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"Successfully downloaded {model_name}", False)
await progress_callback(relative_path, status)
await progress_callback(model_name, status)

return status
except Exception as e:
logging.error(f"Error in track_download_progress: {e}")
logging.error(traceback.format_exc())
return await handle_download_error(e, model_name, progress_callback, relative_path)
return await handle_download_error(e, model_name, progress_callback)


async def handle_download_error(e: Exception,
model_name: str,
progress_callback: Callable[[str, DownloadModelStatus], Any],
relative_path: str) -> DownloadModelStatus:
async def handle_download_error(e: Exception,
model_name: str,
progress_callback: Callable[[str, DownloadModelStatus], Any]
) -> DownloadModelStatus:
error_message = f"Error downloading {model_name}: {str(e)}"
status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
await progress_callback(relative_path, status)
await progress_callback(model_name, status)
return status

def validate_model_subdirectory(model_subdirectory: str) -> bool:
"""
Validate that the model subdirectory is safe to install into.
Must not contain relative paths, nested paths or special characters
other than underscores and hyphens.

Args:
model_subdirectory (str): The subdirectory for the specific model type.

Returns:
bool: True if the subdirectory is safe, False otherwise.
"""
if len(model_subdirectory) > 50:
return False

if '..' in model_subdirectory or '/' in model_subdirectory:
return False

if not re.match(r'^[a-zA-Z0-9_-]+$', model_subdirectory):
return False

return True

def validate_filename(filename: str)-> bool:
"""
Validate a filename to ensure it's safe and doesn't contain any path traversal attempts.

Args:
filename (str): The filename to validate

Expand Down
5 changes: 3 additions & 2 deletions server.py
Original file line number Diff line number Diff line change
Expand Up @@ -684,18 +684,19 @@ async def report_progress(filename: str, status: DownloadModelStatus):
data = await request.json()
url = data.get('url')
model_directory = data.get('model_directory')
folder_path = data.get('folder_path')
model_filename = data.get('model_filename')
progress_interval = data.get('progress_interval', 1.0) # In seconds, how often to report download progress.

if not url or not model_directory or not model_filename:
if not url or not model_directory or not model_filename or not folder_path:
return web.json_response({"status": "error", "message": "Missing URL or folder path or filename"}, status=400)

session = self.client_session
if session is None:
logging.error("Client session is not initialized")
return web.Response(status=500)

task = asyncio.create_task(download_model(lambda url: session.get(url), model_filename, url, model_directory, report_progress, progress_interval))
task = asyncio.create_task(download_model(lambda url: session.get(url), model_filename, url, model_directory, folder_path, report_progress, progress_interval))
await task

return web.json_response(task.result().to_dict())
Expand Down
Loading
Loading