Skip to content

improved coverage of unit tests. #31

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 92 additions & 34 deletions floatcsep/infrastructure/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,38 +2,49 @@


class Task:
"""
Represents a unit of work to be executed later as part of a task graph.

A Task wraps an object instance, a method, and its arguments to allow for deferred
execution. This is useful in workflows where tasks need to be executed in a specific order,
often dictated by dependencies on other tasks.

For instance, can wrap a floatcsep.model.Model, its method 'create_forecast' and the
argument 'time_window', which can be executed later with Task.call() when, for example,
task dependencies (parent nodes) have been completed.
"""

def __init__(self, instance, method, **kwargs):
"""
Base node of the workload distribution. Wraps lazily objects, methods and their
arguments for them to be executed later. For instance, can wrap a floatcsep.Model, its
method 'create_forecast' and the argument 'time_window', which can be executed later
with Task.call() when, for example, task dependencies (parent nodes) have been completed.

Args:
instance: can be floatcsep.Experiment, floatcsep.Model, floatcsep.Evaluation
method: the instance's method to be lazily created
**kwargs: keyword arguments passed to method.
instance: The object instance whose method will be executed later.
method (str): The method of the instance that will be called.
**kwargs: Arguments to pass to the method when it is invoked.

"""

self.obj = instance
self.method = method
self.kwargs = kwargs

self.store = None # Bool for nested tasks. DEPRECATED
self.store = None # Bool for nested tasks.

def sign_match(self, obj=None, met=None, kw_arg=None):
"""
Checks if the Task matches a given signature for simplicity.
Checks whether the task matches a given function signature.

This method is used to verify if a task belongs to a given object, method, or if it
uses a specific keyword argument. Useful for identifying tasks in a graph based on
partial matches of their attributes.

Purpose is to check from the outside if the Task is from a given object
(Model, Experiment, etc.), matching either name or object or description
Args:
obj: Instance or instance's name str. Instance is preferred
met: Name of the method
kw_arg: Only the value (not key) of the kwargs dictionary
obj: The object instance or its name (str) to match against.
met: The method name to match against.
kw_arg: A specific keyword argument value to match against in the task's arguments.

Returns:
bool: True if the task matches the provided signature, False otherwise.
"""

if self.obj == obj or obj == getattr(self.obj, "name", None):
Expand All @@ -43,6 +54,13 @@ def sign_match(self, obj=None, met=None, kw_arg=None):
return False

def __str__(self):
"""
Returns a string representation of the task, including the instance name, method, and
arguments. Useful for debugging purposes.

Returns:
str: A formatted string describing the task.
"""
task_str = f"{self.__class__}\n\t" f"Instance: {self.obj.__class__.__name__}\n"
a = getattr(self.obj, "name", None)
if a:
Expand All @@ -54,6 +72,16 @@ def __str__(self):
return task_str[:-2]

def run(self):
"""
Executes the task by calling the method on the object instance with the stored
arguments. If the instance has a `store` attribute, it will use that instead of the
instance itself. Once executed, the result is stored in the `store` attribute if any
output is produced.

Returns:
The output of the method execution, or None if the method does not return anything.
"""

if hasattr(self.obj, "store"):
self.obj = self.obj.store
output = getattr(self.obj, self.method)(**self.kwargs)
Expand All @@ -65,6 +93,12 @@ def run(self):
return output

def __call__(self, *args, **kwargs):
"""
A callable alias for the `run` method. Allows the task to be invoked directly.

Returns:
The result of the `run` method.
"""
return self.run()

def check_exist(self):
Expand All @@ -73,21 +107,35 @@ def check_exist(self):

class TaskGraph:
"""
Context manager of floatcsep workload distribution.

Assign tasks to a node and defines their dependencies (parent nodes).
Contains a 'tasks' dictionary whose dict_keys are the Task to be
executed with dict_values as the Task's dependencies.
Context manager of floatcsep workload distribution. A TaskGraph is responsible for adding
tasks, managing dependencies between tasks, and executing tasks in the correct order.
Tasks in the graph can depend on one another, and the graph ensures that each task is run
after all of its dependencies have been satisfied. Contains a 'tasks' dictionary whose
dict_keys are the Task to be executed with dict_values as the Task's dependencies.

Attributes:
tasks (OrderedDict): A dictionary where the keys are Task objects and the values are
lists of dependent Task objects.
_ntasks (int): The current number of tasks in the graph.
name (str): A name identifier for the task graph.
"""

def __init__(self):

"""
Initializes the TaskGraph with an empty task dictionary and task count.
"""
self.tasks = OrderedDict()
self._ntasks = 0
self.name = "floatcsep.utils.TaskGraph"
self.name = "floatcsep.infrastructure.engine.TaskGraph"

@property
def ntasks(self):
"""
Returns the number of tasks currently in the graph.

Returns:
int: The total number of tasks in the graph.
"""
return self._ntasks

@ntasks.setter
Expand All @@ -96,31 +144,32 @@ def ntasks(self, n):

def add(self, task):
"""
Simply adds a defined task to the graph.
Adds a new task to the task graph.

Args:
task: floatcsep.utils.Task
The task is added to the dictionary of tasks with no dependencies by default.

Returns:
Args:
task (Task): The task to be added to the graph.
"""
self.tasks[task] = []
self.ntasks += 1

def add_dependency(self, task, dinst=None, dmeth=None, dkw=None):
"""
Adds a dependency to a task already inserted to the TaskGraph.
Adds a dependency to a task already in the graph.

Searchs
within the pre-added tasks a signature match by their name/instance,
method and keyword_args.
Searches for other tasks within the graph whose signature matches the provided
object instance, method name, or keyword argument. Any matches are added as
dependencies to the provided task.

Args:
task: Task to which a dependency will be asigned
dinst: object/name of the dependency
dmeth: method of the dependency
dkw: keyword argument of the dependency
task (Task): The task to which dependencies will be added.
dinst: The object instance or name of the dependency.
dmeth: The method name of the dependency.
dkw: A specific keyword argument value of the dependency.

Returns:
None
"""
deps = []
for i, other_tasks in enumerate(self.tasks.keys()):
Expand All @@ -131,15 +180,24 @@ def add_dependency(self, task, dinst=None, dmeth=None, dkw=None):

def run(self):
"""
Iterates through all the graph tasks and runs them.
Executes all tasks in the task graph in the correct order based on dependencies.

Iterates over each task in the graph and runs it after its dependencies have been
resolved.

Returns:
None
"""
for task, deps in self.tasks.items():
task.run()

def __call__(self, *args, **kwargs):
"""
A callable alias for the `run` method. Allows the task graph to be invoked directly.

Returns:
None
"""
return self.run()

def check_exist(self):
Expand Down
109 changes: 109 additions & 0 deletions tests/unit/test_commands.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
import unittest
from unittest.mock import patch, MagicMock
import floatcsep.commands.main as main_module


class TestMainModule(unittest.TestCase):

@patch('floatcsep.commands.main.Experiment')
@patch('floatcsep.commands.main.plot_catalogs')
@patch('floatcsep.commands.main.plot_forecasts')
@patch('floatcsep.commands.main.plot_results')
@patch('floatcsep.commands.main.plot_custom')
@patch('floatcsep.commands.main.generate_report')
def test_run(self, mock_generate_report, mock_plot_custom, mock_plot_results,
mock_plot_forecasts, mock_plot_catalogs, mock_experiment):
# Mock Experiment instance and its methods
mock_exp_instance = MagicMock()
mock_experiment.from_yml.return_value = mock_exp_instance

# Call the function
main_module.run(config='dummy_config')

# Verify the calls to the Experiment class methods
mock_experiment.from_yml.assert_called_once_with(config_yml='dummy_config')
mock_exp_instance.stage_models.assert_called_once()
mock_exp_instance.set_tasks.assert_called_once()
mock_exp_instance.run.assert_called_once()

# Verify that plotting and report generation functions were called
mock_plot_catalogs.assert_called_once_with(experiment=mock_exp_instance)
mock_plot_forecasts.assert_called_once_with(experiment=mock_exp_instance)
mock_plot_results.assert_called_once_with(experiment=mock_exp_instance)
mock_plot_custom.assert_called_once_with(experiment=mock_exp_instance)
mock_generate_report.assert_called_once_with(experiment=mock_exp_instance)

@patch('floatcsep.commands.main.Experiment')
def test_stage(self, mock_experiment):
# Mock Experiment instance and its methods
mock_exp_instance = MagicMock()
mock_experiment.from_yml.return_value = mock_exp_instance

# Call the function
main_module.stage(config='dummy_config')

# Verify the calls to the Experiment class methods
mock_experiment.from_yml.assert_called_once_with(config_yml='dummy_config')
mock_exp_instance.stage_models.assert_called_once()

@patch('floatcsep.commands.main.Experiment')
@patch('floatcsep.commands.main.plot_catalogs')
@patch('floatcsep.commands.main.plot_forecasts')
@patch('floatcsep.commands.main.plot_results')
@patch('floatcsep.commands.main.plot_custom')
@patch('floatcsep.commands.main.generate_report')
def test_plot(self, mock_generate_report, mock_plot_custom, mock_plot_results,
mock_plot_forecasts, mock_plot_catalogs, mock_experiment):
# Mock Experiment instance and its methods
mock_exp_instance = MagicMock()
mock_experiment.from_yml.return_value = mock_exp_instance

# Call the function
main_module.plot(config='dummy_config')

# Verify the calls to the Experiment class methods
mock_experiment.from_yml.assert_called_once_with(config_yml='dummy_config')
mock_exp_instance.stage_models.assert_called_once()
mock_exp_instance.set_tasks.assert_called_once()

# Verify that plotting and report generation functions were called
mock_plot_catalogs.assert_called_once_with(experiment=mock_exp_instance)
mock_plot_forecasts.assert_called_once_with(experiment=mock_exp_instance)
mock_plot_results.assert_called_once_with(experiment=mock_exp_instance)
mock_plot_custom.assert_called_once_with(experiment=mock_exp_instance)
mock_generate_report.assert_called_once_with(experiment=mock_exp_instance)

@patch('floatcsep.commands.main.Experiment')
@patch('floatcsep.commands.main.ExperimentComparison')
@patch('floatcsep.commands.main.reproducibility_report')
def test_reproduce(self, mock_reproducibility_report, mock_exp_comparison, mock_experiment):
# Mock Experiment instances and methods
mock_reproduced_exp = MagicMock()
mock_original_exp = MagicMock()
mock_experiment.from_yml.side_effect = [mock_reproduced_exp, mock_original_exp]

mock_comp_instance = MagicMock()
mock_exp_comparison.return_value = mock_comp_instance

# Call the function
main_module.reproduce(config='dummy_config')

# Verify the calls to the Experiment class methods
mock_experiment.from_yml.assert_any_call('dummy_config', repr_dir="reproduced")
mock_reproduced_exp.stage_models.assert_called_once()
mock_reproduced_exp.set_tasks.assert_called_once()
mock_reproduced_exp.run.assert_called_once()

mock_experiment.from_yml.assert_any_call(mock_reproduced_exp.original_config,
rundir=mock_reproduced_exp.original_run_dir)
mock_original_exp.stage_models.assert_called_once()
mock_original_exp.set_tasks.assert_called_once()

# Verify comparison and reproducibility report calls
mock_exp_comparison.assert_called_once_with(mock_original_exp, mock_reproduced_exp)
mock_comp_instance.compare_results.assert_called_once()
mock_reproducibility_report.assert_called_once_with(exp_comparison=mock_comp_instance)


if __name__ == '__main__':
unittest.main()
Loading
Loading