Skip to content

Commit

Permalink
Add additional input checks
Browse files Browse the repository at this point in the history
  • Loading branch information
jan-janssen committed Oct 29, 2024
1 parent ac98c42 commit 29d68fe
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 16 deletions.
5 changes: 5 additions & 0 deletions executorlib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
from executorlib.standalone.inputcheck import (
check_refresh_rate as _check_refresh_rate,
)
from executorlib.standalone.inputcheck import (
check_pysqa_config_directory as _check_pysqa_config_directory
)

__version__ = _get_versions()["version"]
__all__ = []
Expand Down Expand Up @@ -194,6 +197,7 @@ def __new__(
init_function=init_function,
)
elif not disable_dependencies:
_check_pysqa_config_directory(pysqa_config_directory=pysqa_config_directory)
return ExecutorWithDependencies(
max_workers=max_workers,
backend=backend,
Expand All @@ -210,6 +214,7 @@ def __new__(
plot_dependency_graph=plot_dependency_graph,
)
else:
_check_pysqa_config_directory(pysqa_config_directory=pysqa_config_directory)
_check_plot_dependency_graph(plot_dependency_graph=plot_dependency_graph)
_check_refresh_rate(refresh_rate=refresh_rate)
return create_executor(
Expand Down
22 changes: 6 additions & 16 deletions executorlib/cache/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
from executorlib.standalone.inputcheck import (
check_executor,
check_nested_flux_executor,
check_flux_executor_pmi_mode,
check_max_workers_and_cores,
check_hostname_localhost,
)
from executorlib.standalone.thread import RaisingThread

Expand Down Expand Up @@ -89,18 +92,6 @@ def create_file_executor(
):
if cache_directory is None:
cache_directory = "executorlib_cache"
if max_workers != 1:
raise ValueError(
"The number of workers cannot be controlled with the pysqa based backend."
)
if max_cores != 1:
raise ValueError(
"The number of cores cannot be controlled with the pysqa based backend."
)
if hostname_localhost is not None:
raise ValueError(
"The option to connect to hosts based on their hostname is not available with the pysqa based backend."
)
if block_allocation:
raise ValueError(
"The option block_allocation is not available with the pysqa based backend."
Expand All @@ -109,10 +100,9 @@ def create_file_executor(
raise ValueError(
"The option to specify an init_function is not available with the pysqa based backend."
)
if flux_executor_pmi_mode is not None:
raise ValueError(
"The option to specify the flux pmi mode is not available with the pysqa based backend."
)
check_flux_executor_pmi_mode(flux_executor_pmi_mode=flux_executor_pmi_mode)
check_max_workers_and_cores(max_cores=max_cores, max_workers=max_workers)
check_hostname_localhost(hostname_localhost=hostname_localhost)
check_executor(executor=flux_executor)
check_nested_flux_executor(nested_flux_executor=flux_executor_nesting)
return FileExecutor(
Expand Down
34 changes: 34 additions & 0 deletions executorlib/standalone/inputcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,40 @@ def check_init_function(block_allocation: bool, init_function: Callable) -> None
raise ValueError("")


def check_max_workers_and_cores(max_workers: int, max_cores: int) -> None:
if max_workers != 1:
raise ValueError(
"The number of workers cannot be controlled with the pysqa based backend."
)
if max_cores != 1:
raise ValueError(
"The number of cores cannot be controlled with the pysqa based backend."
)

def check_hostname_localhost(hostname_localhost: Optional[bool]) -> None:
if hostname_localhost is not None:
raise ValueError(
"The option to connect to hosts based on their hostname is not available with the pysqa based backend."
)


def check_flux_executor_pmi_mode(flux_executor_pmi_mode: Optional[str]) -> None:
if flux_executor_pmi_mode is not None:
raise ValueError(
"The option to specify the flux pmi mode is not available with the pysqa based backend."
)


def check_pysqa_config_directory(pysqa_config_directory: Optional[str]) -> None:
"""
Check if pysqa_config_directory is None and raise a ValueError if it is not.
"""
if pysqa_config_directory is not None:
raise ValueError(
"pysqa_config_directory parameter is only supported for pysqa backend."
)


def validate_number_of_cores(max_cores: int, max_workers: int) -> int:
"""
Validate the number of cores and return the appropriate value.
Expand Down
26 changes: 26 additions & 0 deletions tests/test_shared_input_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@
check_refresh_rate,
check_resource_dict,
check_resource_dict_is_empty,
check_flux_executor_pmi_mode,
check_max_workers_and_cores,
check_hostname_localhost,
check_pysqa_config_directory,
)


Expand Down Expand Up @@ -69,3 +73,25 @@ def test_check_nested_flux_executor(self):
def test_check_plot_dependency_graph(self):
with self.assertRaises(ValueError):
check_plot_dependency_graph(plot_dependency_graph=True)

def test_check_flux_executor_pmi_mode(self):
with self.assertRaises(ValueError):
check_flux_executor_pmi_mode(flux_executor_pmi_mode="test")

def test_check_max_workers_and_cores(self):
with self.assertRaises(ValueError):
check_max_workers_and_cores(max_workers=2, max_cores=1)
with self.assertRaises(ValueError):
check_max_workers_and_cores(max_workers=1, max_cores=2)
with self.assertRaises(ValueError):
check_max_workers_and_cores(max_workers=2, max_cores=2)

def test_check_hostname_localhost(self):
with self.assertRaises(ValueError):
check_hostname_localhost(hostname_localhost=True)
with self.assertRaises(ValueError):
check_hostname_localhost(hostname_localhost=True)

def test_check_pysqa_config_directory(self):
with self.assertRaises(ValueError):
check_pysqa_config_directory(pysqa_config_directory="path/to/config")

0 comments on commit 29d68fe

Please sign in to comment.