From 6ca34b1bb724564a7edb1c882020d8b3cba4c461 Mon Sep 17 00:00:00 2001 From: Robin Huang Date: Mon, 9 Sep 2024 17:24:43 +0900 Subject: [PATCH] Simplify unit test. --- tests-unit/main_test.py | 96 +++++++++++++++++++++-------------------- 1 file changed, 49 insertions(+), 47 deletions(-) diff --git a/tests-unit/main_test.py b/tests-unit/main_test.py index 8384266a9388..d0dca8252b8b 100644 --- a/tests-unit/main_test.py +++ b/tests-unit/main_test.py @@ -1,67 +1,69 @@ import pytest +import yaml import os -import logging -from unittest.mock import patch +from unittest.mock import Mock, patch, mock_open # Import the function we're testing from main import load_extra_path_config +import folder_paths @pytest.fixture -def mock_yaml_file(): - yaml_content = """ - test_config: - base_path: ~/App/ - checkpoint: - subfolder1 - subfolder2 - lora: otherfolder - """ - return yaml_content +def mock_yaml_content(): + return { + 'test_config': { + 'base_path': '~/App/', + 'model1': 'subfolder1', + } + } @pytest.fixture def mock_expanded_home(): return '/home/user' -@patch('os.path.expanduser') -@patch('folder_paths.add_model_folder_path') -def test_load_extra_path_config(mock_add_model_folder_path, mock_expanduser, mock_yaml_file, mock_expanded_home, tmp_path): - # Setup - mock_expanduser.return_value = os.path.join(mock_expanded_home, 'App') - yaml_path = tmp_path / "test_config.yaml" - with open(yaml_path, 'w') as f: - f.write(mock_yaml_file) +@pytest.fixture +def mock_add_model_folder_path(): + return Mock() - # Call the function - load_extra_path_config(yaml_path) +@pytest.fixture +def mock_expanduser(mock_expanded_home): + def _expanduser(path): + if path.startswith('~/'): + return os.path.join(mock_expanded_home, path[2:]) + return path + return _expanduser - # Assertions - expected_calls = [ - ('checkpoint', os.path.join(mock_expanded_home, 'App', 'subfolder1 subfolder2')), - ('lora', os.path.join(mock_expanded_home, 'App', 'otherfolder')) - ] +@pytest.fixture +def mock_yaml_safe_load(mock_yaml_content): + return Mock(return_value=mock_yaml_content) - assert mock_add_model_folder_path.call_count == len(expected_calls) - for call in mock_add_model_folder_path.call_args_list: - assert call.args in expected_calls +@patch('builtins.open', new_callable=mock_open, read_data="dummy file content") +def test_load_extra_model_paths_expands_userpath( + mock_file, + monkeypatch, + mock_add_model_folder_path, + mock_expanduser, + mock_yaml_safe_load, + mock_expanded_home +): + # Attach mocks used by load_extra_path_config + monkeypatch.setattr(folder_paths, 'add_model_folder_path', mock_add_model_folder_path) + monkeypatch.setattr(os.path, 'expanduser', mock_expanduser) + monkeypatch.setattr(yaml, 'safe_load', mock_yaml_safe_load) - # Check if expanduser was called with the correct path - mock_expanduser.assert_called_once_with('~/App/') + load_extra_path_config('dummy_path.yaml') -@pytest.fixture -def caplog(caplog): - caplog.set_level(logging.INFO) - return caplog + expected_calls = [ + ('model1', os.path.join(mock_expanded_home, 'App', 'subfolder1')), + ] -def test_load_extra_path_config_logging(mock_yaml_file, tmp_path, caplog): - # Setup - yaml_path = tmp_path / "test_config.yaml" - with open(yaml_path, 'w') as f: - f.write(mock_yaml_file) + assert mock_add_model_folder_path.call_count == len(expected_calls) + + # Check if add_model_folder_path was called with the correct arguments + for actual_call, expected_call in zip(mock_add_model_folder_path.call_args_list, expected_calls): + assert actual_call.args == expected_call - # Call the function - with patch('folder_paths.add_model_folder_path'): - load_extra_path_config(yaml_path) + # Check if yaml.safe_load was called + mock_yaml_safe_load.assert_called_once() - # Check logged messages - assert "Adding extra search path checkpoint " in caplog.text - assert "Adding extra search path lora " in caplog.text + # Check if open was called with the correct file path + mock_file.assert_called_once_with('dummy_path.yaml', 'r')