From 9fa8faa44a0226eb9a0980006cf62ca2c1c4973a Mon Sep 17 00:00:00 2001 From: Robin Huang Date: Mon, 9 Sep 2024 21:33:44 -0700 Subject: [PATCH] Expand user directory for basepath in extra_models_paths.yaml (#4857) * Expand user path. * Add test. * Add unit test for expanding base path. * Simplify unit test. * Remove comment. * Remove comment. * Checkpoints. * Refactor. --- main.py | 27 ++--------- tests-unit/utils/extra_config_test.py | 69 +++++++++++++++++++++++++++ utils/__init__.py | 0 utils/extra_config.py | 25 ++++++++++ 4 files changed, 97 insertions(+), 24 deletions(-) create mode 100644 tests-unit/utils/extra_config_test.py create mode 100644 utils/__init__.py create mode 100644 utils/extra_config.py diff --git a/main.py b/main.py index d791a169cd7..a1db97cd046 100644 --- a/main.py +++ b/main.py @@ -63,6 +63,7 @@ def execute_script(script_path): import gc import logging +from utils import extra_config if os.name == "nt": logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage()) @@ -85,7 +86,6 @@ def execute_script(script_path): pass import comfy.utils -import yaml import execution import server @@ -180,27 +180,6 @@ def cleanup_temp(): shutil.rmtree(temp_dir, ignore_errors=True) -def load_extra_path_config(yaml_path): - with open(yaml_path, 'r') as stream: - config = yaml.safe_load(stream) - for c in config: - conf = config[c] - if conf is None: - continue - base_path = None - if "base_path" in conf: - base_path = conf.pop("base_path") - for x in conf: - for y in conf[x].split("\n"): - if len(y) == 0: - continue - full_path = y - if base_path is not None: - full_path = os.path.join(base_path, full_path) - logging.info("Adding extra search path {} {}".format(x, full_path)) - folder_paths.add_model_folder_path(x, full_path) - - if __name__ == "__main__": if args.temp_directory: temp_dir = os.path.join(os.path.abspath(args.temp_directory), "temp") @@ -222,11 +201,11 @@ def load_extra_path_config(yaml_path): extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml") if os.path.isfile(extra_model_paths_config_path): - load_extra_path_config(extra_model_paths_config_path) + extra_config.load_extra_path_config(extra_model_paths_config_path) if args.extra_model_paths_config: for config_path in itertools.chain(*args.extra_model_paths_config): - load_extra_path_config(config_path) + extra_config.load_extra_path_config(config_path) nodes.init_extra_nodes(init_custom_nodes=not args.disable_all_custom_nodes) diff --git a/tests-unit/utils/extra_config_test.py b/tests-unit/utils/extra_config_test.py new file mode 100644 index 00000000000..f56dd3e2ef1 --- /dev/null +++ b/tests-unit/utils/extra_config_test.py @@ -0,0 +1,69 @@ +import pytest +import yaml +import os +from unittest.mock import Mock, patch, mock_open + +from utils.extra_config import load_extra_path_config +import folder_paths + +@pytest.fixture +def mock_yaml_content(): + return { + 'test_config': { + 'base_path': '~/App/', + 'checkpoints': 'subfolder1', + } + } + +@pytest.fixture +def mock_expanded_home(): + return '/home/user' + +@pytest.fixture +def mock_add_model_folder_path(): + return Mock() + +@pytest.fixture +def mock_expanduser(mock_expanded_home): + def _expanduser(path): + if path.startswith('~/'): + return os.path.join(mock_expanded_home, path[2:]) + return path + return _expanduser + +@pytest.fixture +def mock_yaml_safe_load(mock_yaml_content): + return Mock(return_value=mock_yaml_content) + +@patch('builtins.open', new_callable=mock_open, read_data="dummy file content") +def test_load_extra_model_paths_expands_userpath( + mock_file, + monkeypatch, + mock_add_model_folder_path, + mock_expanduser, + mock_yaml_safe_load, + mock_expanded_home +): + # Attach mocks used by load_extra_path_config + monkeypatch.setattr(folder_paths, 'add_model_folder_path', mock_add_model_folder_path) + monkeypatch.setattr(os.path, 'expanduser', mock_expanduser) + monkeypatch.setattr(yaml, 'safe_load', mock_yaml_safe_load) + + dummy_yaml_file_name = 'dummy_path.yaml' + load_extra_path_config(dummy_yaml_file_name) + + expected_calls = [ + ('checkpoints', os.path.join(mock_expanded_home, 'App', 'subfolder1')), + ] + + assert mock_add_model_folder_path.call_count == len(expected_calls) + + # Check if add_model_folder_path was called with the correct arguments + for actual_call, expected_call in zip(mock_add_model_folder_path.call_args_list, expected_calls): + assert actual_call.args == expected_call + + # Check if yaml.safe_load was called + mock_yaml_safe_load.assert_called_once() + + # Check if open was called with the correct file path + mock_file.assert_called_once_with(dummy_yaml_file_name, 'r') diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/utils/extra_config.py b/utils/extra_config.py new file mode 100644 index 00000000000..23c2d791c7a --- /dev/null +++ b/utils/extra_config.py @@ -0,0 +1,25 @@ +import os +import yaml +import folder_paths +import logging + +def load_extra_path_config(yaml_path): + with open(yaml_path, 'r') as stream: + config = yaml.safe_load(stream) + for c in config: + conf = config[c] + if conf is None: + continue + base_path = None + if "base_path" in conf: + base_path = conf.pop("base_path") + base_path = os.path.expanduser(base_path) + for x in conf: + for y in conf[x].split("\n"): + if len(y) == 0: + continue + full_path = y + if base_path is not None: + full_path = os.path.join(base_path, full_path) + logging.info("Adding extra search path {} {}".format(x, full_path)) + folder_paths.add_model_folder_path(x, full_path)