Skip to content

Commit

Permalink
Simplify unit test.
Browse files Browse the repository at this point in the history
  • Loading branch information
robinjhuang committed Sep 9, 2024
1 parent a3beceb commit 6ca34b1
Showing 1 changed file with 49 additions and 47 deletions.
96 changes: 49 additions & 47 deletions tests-unit/main_test.py
Original file line number Diff line number Diff line change
@@ -1,67 +1,69 @@
import pytest
import yaml
import os
import logging
from unittest.mock import patch
from unittest.mock import Mock, patch, mock_open

# Import the function we're testing
from main import load_extra_path_config
import folder_paths

@pytest.fixture
def mock_yaml_file():
yaml_content = """
test_config:
base_path: ~/App/
checkpoint:
subfolder1
subfolder2
lora: otherfolder
"""
return yaml_content
def mock_yaml_content():
return {
'test_config': {
'base_path': '~/App/',
'model1': 'subfolder1',
}
}

@pytest.fixture
def mock_expanded_home():
return '/home/user'

@patch('os.path.expanduser')
@patch('folder_paths.add_model_folder_path')
def test_load_extra_path_config(mock_add_model_folder_path, mock_expanduser, mock_yaml_file, mock_expanded_home, tmp_path):
# Setup
mock_expanduser.return_value = os.path.join(mock_expanded_home, 'App')
yaml_path = tmp_path / "test_config.yaml"
with open(yaml_path, 'w') as f:
f.write(mock_yaml_file)
@pytest.fixture
def mock_add_model_folder_path():
return Mock()

# Call the function
load_extra_path_config(yaml_path)
@pytest.fixture
def mock_expanduser(mock_expanded_home):
def _expanduser(path):
if path.startswith('~/'):
return os.path.join(mock_expanded_home, path[2:])
return path
return _expanduser

# Assertions
expected_calls = [
('checkpoint', os.path.join(mock_expanded_home, 'App', 'subfolder1 subfolder2')),
('lora', os.path.join(mock_expanded_home, 'App', 'otherfolder'))
]
@pytest.fixture
def mock_yaml_safe_load(mock_yaml_content):
return Mock(return_value=mock_yaml_content)

assert mock_add_model_folder_path.call_count == len(expected_calls)
for call in mock_add_model_folder_path.call_args_list:
assert call.args in expected_calls
@patch('builtins.open', new_callable=mock_open, read_data="dummy file content")
def test_load_extra_model_paths_expands_userpath(
mock_file,
monkeypatch,
mock_add_model_folder_path,
mock_expanduser,
mock_yaml_safe_load,
mock_expanded_home
):
# Attach mocks used by load_extra_path_config
monkeypatch.setattr(folder_paths, 'add_model_folder_path', mock_add_model_folder_path)
monkeypatch.setattr(os.path, 'expanduser', mock_expanduser)
monkeypatch.setattr(yaml, 'safe_load', mock_yaml_safe_load)

# Check if expanduser was called with the correct path
mock_expanduser.assert_called_once_with('~/App/')
load_extra_path_config('dummy_path.yaml')

@pytest.fixture
def caplog(caplog):
caplog.set_level(logging.INFO)
return caplog
expected_calls = [
('model1', os.path.join(mock_expanded_home, 'App', 'subfolder1')),
]

def test_load_extra_path_config_logging(mock_yaml_file, tmp_path, caplog):
# Setup
yaml_path = tmp_path / "test_config.yaml"
with open(yaml_path, 'w') as f:
f.write(mock_yaml_file)
assert mock_add_model_folder_path.call_count == len(expected_calls)

# Check if add_model_folder_path was called with the correct arguments
for actual_call, expected_call in zip(mock_add_model_folder_path.call_args_list, expected_calls):
assert actual_call.args == expected_call

# Call the function
with patch('folder_paths.add_model_folder_path'):
load_extra_path_config(yaml_path)
# Check if yaml.safe_load was called
mock_yaml_safe_load.assert_called_once()

# Check logged messages
assert "Adding extra search path checkpoint " in caplog.text
assert "Adding extra search path lora " in caplog.text
# Check if open was called with the correct file path
mock_file.assert_called_once_with('dummy_path.yaml', 'r')

0 comments on commit 6ca34b1

Please sign in to comment.