diff --git a/docs/index.md b/docs/index.md index 6f9af75..1c7f700 100644 --- a/docs/index.md +++ b/docs/index.md @@ -25,3 +25,22 @@ The package can be installed from [PyPi](https://pypi.org/project/parameterspace ```bash pip install parameterspace ``` + +## ConfigSpace Compatibility + +In case you are used to working with +[ConfigSpace](https://github.com/automl/ConfigSpace/) or for other reasons have space +definitions in the `ConfigSpace` format around, you can convert them into +`ParameterSpace` instances with ease. +Just note that any colons `:` in the `ConfigSpace` parameter names will be converted to +underscores `_`. + +```python +import json +from parameterspace.configspace_utils import parameterspace_from_configspace_dict + +with open("config_space.json", "r") as fh: + cs = json.load(fh) + +ps = parameterspace_from_configspace_dict(cs) +``` diff --git a/parameterspace/configspace_utils.py b/parameterspace/configspace_utils.py new file mode 100644 index 0000000..67ad032 --- /dev/null +++ b/parameterspace/configspace_utils.py @@ -0,0 +1,206 @@ +"""Initialize a `ParameterSpace` from a `ConfigSpace` JSON dictionary.""" + +from typing import List, Optional, Tuple + +import numpy as np + +import parameterspace as ps +from parameterspace.condition import Condition +from parameterspace.utils import verify_lambda + + +def _escape_parameter_name(name: str) -> str: + """Replace colons with underscores. + + Colons are incompatible as ParameterSpace parameter names. + """ + return name.replace(":", "_") + + +def _get_condition( + conditions: List[dict], configspace_parameter_name: str +) -> Optional[Condition]: + """Construct a lambda function that can be used as a ParameterSpace condition from a + ConfigSpace conditions list given a specific target parameter name. + + NOTE: The `configspace_parameter_name` here needs to match the original name in + `ConfigSpace`, not the one transformed with `_escape_parameter_name`. + """ + condition = Condition() + + varnames = [] + function_texts = [] + for cond in conditions: + if cond["child"] == configspace_parameter_name: + parent = _escape_parameter_name(cond["parent"]) + varnames.append(parent) + # The representation is used because it quotes strings. + if cond["type"] == "IN": + function_texts.append(f"{parent} in {tuple(cond['values'])}") + elif cond["type"] == "EQ": + function_texts.append(f"{parent} == {repr(cond['value'])}") + elif cond["type"] == "NEQ": + function_texts.append(f"{parent} != {repr(cond['value'])}") + elif cond["type"] == "GT": + function_texts.append(f"{parent} > {repr(cond['value'])}") + elif cond["type"] == "LT": + function_texts.append(f"{parent} < {repr(cond['value'])}") + else: + raise NotImplementedError(f"Unsupported condition type {cond['type']}") + + if not varnames: + return condition + + function_text = " and ".join(function_texts) + verify_lambda(variables=varnames, body=function_text) + # pylint: disable=eval-used + condition_function = eval(f"lambda {', '.join(varnames)}: {function_text}") + # pylint: enable=eval-used + + condition.function_texts.append(function_text) + condition.varnames.append(varnames) + condition.all_varnames |= set(varnames) + condition.functions.append(condition_function) + + return condition + + +def _convert_for_normal_parameter( + log: bool, lower: Optional[float], upper: Optional[float], mu: float, sigma: float +) -> Tuple[float, float, float, float]: + """Convert bounds and prior mean/std from `ConfigSpace` parameter dictionary with + normal prior to `ParameterSpace` compatible values. + + Args: + log: Are we on a log scale? + lower: Optional lower bound in the original space (required when `log=True`) + upper: Optional upper bound in the original space (required when `log=True`) + mu: Mean of the `ConfigSpace` normal distribution + sigma: Standard deviation of the `ConfigSpace` normal distribution + + Returns: + Transformed lower bound, upper bound, mean and standard deviation + + Raises: + Value error when log is True but bounds are missing. + """ + if lower is None or upper is None: + if log: + raise ValueError( + "Please provide bounds, when using a log transform with a normal prior." + ) + lower = mu - 4 * sigma + upper = mu + 4 * sigma + + if log: + log_upper, log_lower = np.log(upper), np.log(lower) + log_interval_size = log_upper - log_lower + mean = (mu - log_lower) / log_interval_size + std = sigma / log_interval_size + else: + interval_size = upper - lower + mean = (mu - lower) / interval_size + std = sigma / interval_size + + return lower, upper, mean, std + + +def parameterspace_from_configspace_dict(configspace_dict: dict) -> ps.ParameterSpace: + """Create `ParameterSpace` instance from a `ConfigSpace` JSON dictionary. + + Note, that `ParameterSpace` does not support regular, non-truncated normal priors + and will thus translate an unbounded normal prior to a normal truncated at +/- 4 + sigma. Also, constant parameters are represented as categoricals with a single value + that are fixed to said value. + + Args: + configspace_dict: The dictionary based on a `ConfigSpace` JSON representation. + + Returns: + A `ParameterSpace` instance. + + Raises: + NotImplementedError in case a given parameter type or configuration is not + supported. + """ + space = ps.ParameterSpace() + + for param_dict in configspace_dict["hyperparameters"]: + param_name = _escape_parameter_name(param_dict["name"]) + condition = _get_condition(configspace_dict["conditions"], param_dict["name"]) + if param_dict["type"] == "uniform_int": + space._parameters[param_name] = { + "parameter": ps.IntegerParameter( + name=param_name, + bounds=(param_dict["lower"], param_dict["upper"]), + transformation="log" if param_dict["log"] else None, + ), + "condition": condition, + } + + elif param_dict["type"] == "categorical": + space._parameters[param_name] = { + "parameter": ps.CategoricalParameter( + name=param_name, + values=param_dict["choices"], + prior=param_dict.get("weights", None), + ), + "condition": condition, + } + + elif param_dict["type"] in ["constant", "unparametrized"]: + space._parameters[param_name] = { + "parameter": ps.CategoricalParameter( + name=param_name, + values=[param_dict["value"]], + ), + "condition": condition, + } + space.fix(**{param_name: param_dict["value"]}) + + elif param_dict["type"] in ["normal_float", "normal_int"]: + parameter_class = ( + ps.ContinuousParameter + if param_dict["type"] == "normal_float" + else ps.IntegerParameter + ) + lower_bound, upper_bound, mean, std = _convert_for_normal_parameter( + log=param_dict["log"], + lower=param_dict.get("lower", None), + upper=param_dict.get("upper", None), + mu=param_dict["mu"], + sigma=param_dict["sigma"], + ) + space._parameters[param_name] = { + "parameter": parameter_class( + name=param_name, + bounds=(lower_bound, upper_bound), + prior=ps.priors.TruncatedNormal(mean=mean, std=std), + transformation="log" if param_dict["log"] else None, + ), + "condition": condition, + } + + elif param_dict["type"] == "uniform_float": + space._parameters[param_name] = { + "parameter": ps.ContinuousParameter( + name=param_name, + bounds=(param_dict["lower"], param_dict["upper"]), + transformation="log" if param_dict["log"] else None, + ), + "condition": condition, + } + + elif param_dict["type"] == "ordinal": + space._parameters[param_name] = { + "parameter": ps.OrdinalParameter( + name=param_name, + values=param_dict["sequence"], + ), + "condition": condition, + } + + else: + raise NotImplementedError(f"Unsupported type {param_dict['type']}") + + return space diff --git a/poetry.lock b/poetry.lock index bc9881b..447f05d 100644 --- a/poetry.lock +++ b/poetry.lock @@ -29,6 +29,7 @@ python-versions = "*" [package.dependencies] six = ">=1.6.1,<2.0" +wheel = ">=0.23.0,<1.0" [package.source] type = "legacy" @@ -47,7 +48,7 @@ python-versions = ">=3.5" dev = ["cloudpickle", "coverage[toml] (>=5.0.2)", "furo", "hypothesis", "mypy (>=0.900,!=0.940)", "pre-commit", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "sphinx", "sphinx-notfound-page", "zope.interface"] docs = ["furo", "sphinx", "sphinx-notfound-page", "zope.interface"] tests = ["cloudpickle", "coverage[toml] (>=5.0.2)", "hypothesis", "mypy (>=0.900,!=0.940)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "zope.interface"] -tests_no_zope = ["cloudpickle", "coverage[toml] (>=5.0.2)", "hypothesis", "mypy (>=0.900,!=0.940)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins"] +tests-no-zope = ["cloudpickle", "coverage[toml] (>=5.0.2)", "hypothesis", "mypy (>=0.900,!=0.940)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins"] [package.source] type = "legacy" @@ -186,7 +187,7 @@ optional = false python-versions = ">=3.6.0" [package.extras] -unicode_backport = ["unicodedata2"] +unicode-backport = ["unicodedata2"] [package.source] type = "legacy" @@ -222,6 +223,29 @@ type = "legacy" url = "https://pypi.python.org/simple" reference = "public-pypi" +[[package]] +name = "configspace" +version = "0.6.0" +description = "Creation and manipulation of parameter configuration spaces for automated algorithm configuration and hyperparameter tuning." +category = "dev" +optional = false +python-versions = ">=3.7" + +[package.dependencies] +cython = "*" +numpy = "*" +pyparsing = "*" +scipy = "*" +typing_extensions = "*" + +[package.extras] +dev = ["automl_sphinx_theme (>=0.1.11)", "mypy", "pre-commit", "pytest (>=4.6)", "pytest-cov"] + +[package.source] +type = "legacy" +url = "https://pypi.python.org/simple" +reference = "public-pypi" + [[package]] name = "coverage" version = "6.5.0" @@ -241,6 +265,19 @@ type = "legacy" url = "https://pypi.python.org/simple" reference = "public-pypi" +[[package]] +name = "cython" +version = "0.29.32" +description = "The Cython compiler for writing C extensions for the Python language." +category = "dev" +optional = false +python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*" + +[package.source] +type = "legacy" +url = "https://pypi.python.org/simple" +reference = "public-pypi" + [[package]] name = "defusedxml" version = "0.7.1" @@ -457,9 +494,9 @@ python-versions = ">=3.6.1,<4.0" [package.extras] colors = ["colorama (>=0.4.3,<0.5.0)"] -pipfile_deprecated_finder = ["pipreqs", "requirementslib"] +pipfile-deprecated-finder = ["pipreqs", "requirementslib"] plugins = ["setuptools"] -requirements_deprecated_finder = ["pip-api", "pipreqs"] +requirements-deprecated-finder = ["pip-api", "pipreqs"] [package.source] type = "legacy" @@ -615,7 +652,7 @@ python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, != 3.4.*" [package.extras] cssselect = ["cssselect (>=0.7)"] html5 = ["html5lib"] -htmlsoup = ["beautifulsoup4"] +htmlsoup = ["BeautifulSoup4"] source = ["Cython (>=0.29.7)"] [package.source] @@ -655,12 +692,12 @@ mdurl = ">=0.1,<1.0" [package.extras] benchmarking = ["psutil", "pytest", "pytest-benchmark (>=3.2,<4.0)"] -code_style = ["pre-commit (==2.6)"] +code-style = ["pre-commit (==2.6)"] compare = ["commonmark (>=0.9.1,<0.10.0)", "markdown (>=3.3.6,<3.4.0)", "mistletoe (>=0.8.1,<0.9.0)", "mistune (>=2.0.2,<2.1.0)", "panflute (>=2.1.3,<2.2.0)"] linkify = ["linkify-it-py (>=1.0,<2.0)"] plugins = ["mdit-py-plugins"] profiling = ["gprof2dot"] -rtd = ["attrs", "myst-parser", "pyyaml", "sphinx", "sphinx-book-theme", "sphinx-copybutton", "sphinx-design"] +rtd = ["attrs", "myst-parser", "pyyaml", "sphinx", "sphinx-copybutton", "sphinx-design", "sphinx_book_theme"] testing = ["coverage", "pytest", "pytest-cov", "pytest-regressions"] [package.source] @@ -706,7 +743,7 @@ python-versions = ">=3.7" markdown-it-py = ">=1.0.0,<3.0.0" [package.extras] -code_style = ["pre-commit"] +code-style = ["pre-commit"] rtd = ["attrs", "myst-parser (>=0.16.1,<0.17.0)", "sphinx-book-theme (>=0.1.0,<0.2.0)"] testing = ["coverage", "pytest", "pytest-cov", "pytest-regressions"] @@ -1114,6 +1151,9 @@ category = "dev" optional = false python-versions = ">=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*" +[package.dependencies] +setuptools = "*" + [package.source] type = "legacy" url = "https://pypi.python.org/simple" @@ -1444,6 +1484,7 @@ python-versions = ">= 3.6" [package.dependencies] pytest = ">=5.3" +setuptools = ">=40.0" [package.source] type = "legacy" @@ -1560,7 +1601,7 @@ urllib3 = ">=1.21.1,<1.27" [package.extras] socks = ["PySocks (>=1.5.6,!=1.5.7)"] -use_chardet_on_py3 = ["chardet (>=3.0.2,<6)"] +use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] [package.source] type = "legacy" @@ -1579,7 +1620,7 @@ python-versions = ">=3.8" numpy = ">=1.18.5,<1.26.0" [package.extras] -dev = ["flake8", "mypy", "pycodestyle", "typing-extensions"] +dev = ["flake8", "mypy", "pycodestyle", "typing_extensions"] doc = ["matplotlib (>2)", "numpydoc", "pydata-sphinx-theme (==0.9.0)", "sphinx (!=4.1.0)", "sphinx-panels (>=0.5.2)", "sphinx-tabs"] test = ["asv", "gmpy2", "mpmath", "pytest", "pytest-cov", "pytest-xdist", "scikit-umfpack", "threadpoolctl"] @@ -1588,6 +1629,24 @@ type = "legacy" url = "https://pypi.python.org/simple" reference = "public-pypi" +[[package]] +name = "setuptools" +version = "65.5.1" +description = "Easily download, build, install, upgrade, and uninstall Python packages" +category = "dev" +optional = false +python-versions = ">=3.7" + +[package.extras] +docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-hoverxref (<2)", "sphinx-inline-tabs", "sphinx-notfound-page (==0.8.3)", "sphinx-reredirects", "sphinxcontrib-towncrier"] +testing = ["build[virtualenv]", "filelock (>=3.4.0)", "flake8 (<5)", "flake8-2020", "ini2toml[lite] (>=0.9)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "pip (>=19.1)", "pip-run (>=8.8)", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-flake8", "pytest-mypy (>=0.9.1)", "pytest-perf", "pytest-timeout", "pytest-xdist", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"] +testing-integration = ["build[virtualenv]", "filelock (>=3.4.0)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "pytest", "pytest-enabler", "pytest-xdist", "tomli", "virtualenv (>=13.0.0)", "wheel"] + +[package.source] +type = "legacy" +url = "https://pypi.python.org/simple" +reference = "public-pypi" + [[package]] name = "six" version = "1.16.0" @@ -1639,7 +1698,7 @@ python-versions = ">=3.7" webencodings = ">=0.4" [package.extras] -doc = ["sphinx", "sphinx-rtd-theme"] +doc = ["sphinx", "sphinx_rtd_theme"] test = ["flake8", "isort", "pytest"] [package.source] @@ -1830,6 +1889,22 @@ type = "legacy" url = "https://pypi.python.org/simple" reference = "public-pypi" +[[package]] +name = "wheel" +version = "0.38.4" +description = "A built-package format for Python" +category = "dev" +optional = false +python-versions = ">=3.7" + +[package.extras] +test = ["pytest (>=3.0.0)"] + +[package.source] +type = "legacy" +url = "https://pypi.python.org/simple" +reference = "public-pypi" + [[package]] name = "wrapt" version = "1.14.1" @@ -1866,7 +1941,7 @@ examples = [] [metadata] lock-version = "1.1" python-versions = ">=3.8,<4.0" -content-hash = "4d3b59a7371431b689a2873f81f439337661950a4a83aee902a6338e8d395be7" +content-hash = "5ca4931d82946c59592819b7172c051fb0ceeabb6db475cb80f7f57057d90904" [metadata.files] astroid = [ @@ -2004,6 +2079,34 @@ colorama = [ {file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"}, {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, ] +configspace = [ + {file = "ConfigSpace-0.6.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:5aa865a30f11ae72779611f758835e8dd699546d748278fa984c9b10ce32a2fc"}, + {file = "ConfigSpace-0.6.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b42cfec268dac1b036cea2165effde4be5171b9ed71d7ebf75945b533ebab973"}, + {file = "ConfigSpace-0.6.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7279ad0571a04447feadd6598072a0b37893e1a536a062d1e8f249f7ecdbf3f0"}, + {file = "ConfigSpace-0.6.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:888a3d54def64b49f6bf4f96238a60d7415a1c0b5d93caf0050509766c22e7a5"}, + {file = "ConfigSpace-0.6.0-cp310-cp310-win_amd64.whl", hash = "sha256:1c79e529672c307aec7d497a2f242b546c2155f69117649840efc9213dade5ab"}, + {file = "ConfigSpace-0.6.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:810658b06d986c8c042ea5967b0f1cde8855ab94edd68a16fff4fff4d34f2b04"}, + {file = "ConfigSpace-0.6.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:701f7181094bc300670d954033ce040de0af9d3ecdaafcd4bf5c4744e0fc087d"}, + {file = "ConfigSpace-0.6.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d410c3509fd8b444c0015e4f3c0f8e3c704bbf50bead11d6149ba8168f452efb"}, + {file = "ConfigSpace-0.6.0-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5a6ef29b8823649ea74ebceaaee260c6caf68882b7f6b77d998108e0203bd8e5"}, + {file = "ConfigSpace-0.6.0-cp37-cp37m-win32.whl", hash = "sha256:34b731dcbcbb915f910792c1ee4e11bb8c9bfb8071054bf4d99edccb6b91931c"}, + {file = "ConfigSpace-0.6.0-cp37-cp37m-win_amd64.whl", hash = "sha256:47dd301c2818d789d41048d382a2da1b19f0952be6b3bec8a6353759a9218897"}, + {file = "ConfigSpace-0.6.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:f8228e86809d2cb53aa9db0dc77bd26199cf409dff4270831023a274af77e15f"}, + {file = "ConfigSpace-0.6.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:6a1a785bd1e34be6085a0359e62be11a0ff291c3dfce749d13f8d38c747af962"}, + {file = "ConfigSpace-0.6.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b4c599c89cdbb1277fbe4fdb2358406e764de597fe264a239cfc98c68f51f3ca"}, + {file = "ConfigSpace-0.6.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2644f50028934f66e8c4fb2dd08ba50393c10fc5aefd456526df51652b5946e7"}, + {file = "ConfigSpace-0.6.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c2e6ed3c788686808b9e86035dbabfa2fd2bea9498fd7207080a8629d05e248a"}, + {file = "ConfigSpace-0.6.0-cp38-cp38-win32.whl", hash = "sha256:c1316a86ad05343909360302abf681bea131f7079e641821645fad47e43d25f6"}, + {file = "ConfigSpace-0.6.0-cp38-cp38-win_amd64.whl", hash = "sha256:fa29cd617c5bbb19eb94dc25a305f124c492cfd34d4de22312547e390c7e7a47"}, + {file = "ConfigSpace-0.6.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:e0634d2c467674a39ae265103ae1ef89abda5065f5905081ba942c313731e006"}, + {file = "ConfigSpace-0.6.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:4a7a333ab453e5fdb4a45a25d4fb76c7135cfec12c31b0ff59fbec148843a56d"}, + {file = "ConfigSpace-0.6.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ad61a54477e0c621d8c910805544a0a8f4dd257ca8c3d18d5bd42eef2bc10bfe"}, + {file = "ConfigSpace-0.6.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3ddb5fd8e5184b4954abd35186bb0ff612ff5a838faa46b433320193b4bae4aa"}, + {file = "ConfigSpace-0.6.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4260935511e8340b72372124ced984c8d6426f26425cd479328be71e3cbfcdbf"}, + {file = "ConfigSpace-0.6.0-cp39-cp39-win32.whl", hash = "sha256:71a034fa61743efc6fc785af32f7362f2e24ba5b646409e3069194e46b103b78"}, + {file = "ConfigSpace-0.6.0-cp39-cp39-win_amd64.whl", hash = "sha256:34c63cbd56446e58acc7132b60a24f19c614583a3f3862ef91b6015079f531fd"}, + {file = "ConfigSpace-0.6.0.tar.gz", hash = "sha256:9b6c95d8839fcab220372673214b3129b45dcd8b1179829eb2c65746cacb72a9"}, +] coverage = [ {file = "coverage-6.5.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:ef8674b0ee8cc11e2d574e3e2998aea5df5ab242e012286824ea3c6970580e53"}, {file = "coverage-6.5.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:784f53ebc9f3fd0e2a3f6a78b2be1bd1f5575d7863e10c6e12504f240fd06660"}, @@ -2056,6 +2159,48 @@ coverage = [ {file = "coverage-6.5.0-pp36.pp37.pp38-none-any.whl", hash = "sha256:1431986dac3923c5945271f169f59c45b8802a114c8f548d611f2015133df77a"}, {file = "coverage-6.5.0.tar.gz", hash = "sha256:f642e90754ee3e06b0e7e51bce3379590e76b7f76b708e1a71ff043f87025c84"}, ] +cython = [ + {file = "Cython-0.29.32-cp27-cp27m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:39afb4679b8c6bf7ccb15b24025568f4f9b4d7f9bf3cbd981021f542acecd75b"}, + {file = "Cython-0.29.32-cp27-cp27m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:dbee03b8d42dca924e6aa057b836a064c769ddfd2a4c2919e65da2c8a362d528"}, + {file = "Cython-0.29.32-cp27-cp27mu-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:5ba622326f2862f9c1f99ca8d47ade49871241920a352c917e16861e25b0e5c3"}, + {file = "Cython-0.29.32-cp27-cp27mu-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:e6ffa08aa1c111a1ebcbd1cf4afaaec120bc0bbdec3f2545f8bb7d3e8e77a1cd"}, + {file = "Cython-0.29.32-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_24_aarch64.whl", hash = "sha256:97335b2cd4acebf30d14e2855d882de83ad838491a09be2011745579ac975833"}, + {file = "Cython-0.29.32-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl", hash = "sha256:06be83490c906b6429b4389e13487a26254ccaad2eef6f3d4ee21d8d3a4aaa2b"}, + {file = "Cython-0.29.32-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_24_i686.whl", hash = "sha256:eefd2b9a5f38ded8d859fe96cc28d7d06e098dc3f677e7adbafda4dcdd4a461c"}, + {file = "Cython-0.29.32-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:5514f3b4122cb22317122a48e175a7194e18e1803ca555c4c959d7dfe68eaf98"}, + {file = "Cython-0.29.32-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_24_aarch64.whl", hash = "sha256:656dc5ff1d269de4d11ee8542f2ffd15ab466c447c1f10e5b8aba6f561967276"}, + {file = "Cython-0.29.32-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl", hash = "sha256:cdf10af3e2e3279dc09fdc5f95deaa624850a53913f30350ceee824dc14fc1a6"}, + {file = "Cython-0.29.32-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_24_i686.whl", hash = "sha256:3875c2b2ea752816a4d7ae59d45bb546e7c4c79093c83e3ba7f4d9051dd02928"}, + {file = "Cython-0.29.32-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:79e3bab19cf1b021b613567c22eb18b76c0c547b9bc3903881a07bfd9e7e64cf"}, + {file = "Cython-0.29.32-cp35-cp35m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:b0595aee62809ba353cebc5c7978e0e443760c3e882e2c7672c73ffe46383673"}, + {file = "Cython-0.29.32-cp35-cp35m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:0ea8267fc373a2c5064ad77d8ff7bf0ea8b88f7407098ff51829381f8ec1d5d9"}, + {file = "Cython-0.29.32-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_24_aarch64.whl", hash = "sha256:c8e8025f496b5acb6ba95da2fb3e9dacffc97d9a92711aacfdd42f9c5927e094"}, + {file = "Cython-0.29.32-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl", hash = "sha256:afbce249133a830f121b917f8c9404a44f2950e0e4f5d1e68f043da4c2e9f457"}, + {file = "Cython-0.29.32-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_24_i686.whl", hash = "sha256:513e9707407608ac0d306c8b09d55a28be23ea4152cbd356ceaec0f32ef08d65"}, + {file = "Cython-0.29.32-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:e83228e0994497900af954adcac27f64c9a57cd70a9ec768ab0cb2c01fd15cf1"}, + {file = "Cython-0.29.32-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:ea1dcc07bfb37367b639415333cfbfe4a93c3be340edf1db10964bc27d42ed64"}, + {file = "Cython-0.29.32-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:8669cadeb26d9a58a5e6b8ce34d2c8986cc3b5c0bfa77eda6ceb471596cb2ec3"}, + {file = "Cython-0.29.32-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_24_aarch64.whl", hash = "sha256:ed087eeb88a8cf96c60fb76c5c3b5fb87188adee5e179f89ec9ad9a43c0c54b3"}, + {file = "Cython-0.29.32-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl", hash = "sha256:3f85eb2343d20d91a4ea9cf14e5748092b376a64b7e07fc224e85b2753e9070b"}, + {file = "Cython-0.29.32-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_24_i686.whl", hash = "sha256:63b79d9e1f7c4d1f498ab1322156a0d7dc1b6004bf981a8abda3f66800e140cd"}, + {file = "Cython-0.29.32-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:e1958e0227a4a6a2c06fd6e35b7469de50adf174102454db397cec6e1403cce3"}, + {file = "Cython-0.29.32-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:856d2fec682b3f31583719cb6925c6cdbb9aa30f03122bcc45c65c8b6f515754"}, + {file = "Cython-0.29.32-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:479690d2892ca56d34812fe6ab8f58e4b2e0129140f3d94518f15993c40553da"}, + {file = "Cython-0.29.32-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_24_aarch64.whl", hash = "sha256:67fdd2f652f8d4840042e2d2d91e15636ba2bcdcd92e7e5ffbc68e6ef633a754"}, + {file = "Cython-0.29.32-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl", hash = "sha256:4a4b03ab483271f69221c3210f7cde0dcc456749ecf8243b95bc7a701e5677e0"}, + {file = "Cython-0.29.32-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_24_i686.whl", hash = "sha256:40eff7aa26e91cf108fd740ffd4daf49f39b2fdffadabc7292b4b7dc5df879f0"}, + {file = "Cython-0.29.32-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:0bbc27abdf6aebfa1bce34cd92bd403070356f28b0ecb3198ff8a182791d58b9"}, + {file = "Cython-0.29.32-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:cddc47ec746a08603037731f5d10aebf770ced08666100bd2cdcaf06a85d4d1b"}, + {file = "Cython-0.29.32-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:eca3065a1279456e81c615211d025ea11bfe4e19f0c5650b859868ca04b3fcbd"}, + {file = "Cython-0.29.32-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_24_aarch64.whl", hash = "sha256:d968ffc403d92addf20b68924d95428d523436adfd25cf505d427ed7ba3bee8b"}, + {file = "Cython-0.29.32-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl", hash = "sha256:f3fd44cc362eee8ae569025f070d56208908916794b6ab21e139cea56470a2b3"}, + {file = "Cython-0.29.32-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_24_i686.whl", hash = "sha256:b6da3063c5c476f5311fd76854abae6c315f1513ef7d7904deed2e774623bbb9"}, + {file = "Cython-0.29.32-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:061e25151c38f2361bc790d3bcf7f9d9828a0b6a4d5afa56fbed3bd33fb2373a"}, + {file = "Cython-0.29.32-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:f9944013588a3543fca795fffb0a070a31a243aa4f2d212f118aa95e69485831"}, + {file = "Cython-0.29.32-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:07d173d3289415bb496e72cb0ddd609961be08fe2968c39094d5712ffb78672b"}, + {file = "Cython-0.29.32-py2.py3-none-any.whl", hash = "sha256:eeb475eb6f0ccf6c039035eb4f0f928eb53ead88777e0a760eccb140ad90930b"}, + {file = "Cython-0.29.32.tar.gz", hash = "sha256:8733cf4758b79304f2a4e39ebfac5e92341bce47bcceb26c1254398b2f8c1af7"}, +] defusedxml = [ {file = "defusedxml-0.7.1-py2.py3-none-any.whl", hash = "sha256:a352e7e428770286cc899e2542b6cdaedb2b4953ff269a210103ec58f6198a61"}, {file = "defusedxml-0.7.1.tar.gz", hash = "sha256:1bb3032db185915b62d7c6209c5a8792be6a32ab2fedacc84e01b52c51aa3e69"}, @@ -2690,6 +2835,10 @@ scipy = [ {file = "scipy-1.9.3-cp39-cp39-win_amd64.whl", hash = "sha256:5b88e6d91ad9d59478fafe92a7c757d00c59e3bdc3331be8ada76a4f8d683f58"}, {file = "scipy-1.9.3.tar.gz", hash = "sha256:fbc5c05c85c1a02be77b1ff591087c83bc44579c6d2bd9fb798bb64ea5e1a027"}, ] +setuptools = [ + {file = "setuptools-65.5.1-py3-none-any.whl", hash = "sha256:d0b9a8433464d5800cbe05094acf5c6d52a91bfac9b52bcfc4d41382be5d5d31"}, + {file = "setuptools-65.5.1.tar.gz", hash = "sha256:e197a19aa8ec9722928f2206f8de752def0e4c9fc6953527360d1c36d94ddb2f"}, +] six = [ {file = "six-1.16.0-py2.py3-none-any.whl", hash = "sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254"}, {file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"}, @@ -2786,6 +2935,10 @@ webencodings = [ {file = "webencodings-0.5.1-py2.py3-none-any.whl", hash = "sha256:a0af1213f3c2226497a97e2b3aa01a7e4bee4f403f95be16fc9acd2947514a78"}, {file = "webencodings-0.5.1.tar.gz", hash = "sha256:b36a1c245f2d304965eb4e0a82848379241dc04b865afcc4aab16748587e1923"}, ] +wheel = [ + {file = "wheel-0.38.4-py3-none-any.whl", hash = "sha256:b60533f3f5d530e971d6737ca6d58681ee434818fab630c83a734bb10c083ce8"}, + {file = "wheel-0.38.4.tar.gz", hash = "sha256:965f5259b566725405b05e7cf774052044b1ed30119b5d586b2703aafe8719ac"}, +] wrapt = [ {file = "wrapt-1.14.1-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:1b376b3f4896e7930f1f772ac4b064ac12598d1c38d04907e696cc4d794b43d3"}, {file = "wrapt-1.14.1-cp27-cp27m-manylinux1_i686.whl", hash = "sha256:903500616422a40a98a5a3c4ff4ed9d0066f3b4c951fa286018ecdf0750194ef"}, diff --git a/pyproject.toml b/pyproject.toml index 1ee7a68..9795aea 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "parameterspace" -version = "0.7.21" +version = "0.7.22" description = "Parametrized hierarchical spaces with flexible priors and transformations." readme = "README.md" repository = "https://github.com/boschresearch/parameterspace" @@ -47,6 +47,9 @@ pydocstyle = "^6.1.1" [tool.poetry.extras] examples = ["notebook", "matplotlib"] +[tool.poetry.group.dev.dependencies] +configspace = "^0.6.0" + [tool.pytest.ini_options] filterwarnings = ["error::DeprecationWarning", "error::PendingDeprecationWarning"] markers = ["integration_test: Execute API calls."] diff --git a/tests/test_configspace_utils.py b/tests/test_configspace_utils.py new file mode 100644 index 0000000..5e4e436 --- /dev/null +++ b/tests/test_configspace_utils.py @@ -0,0 +1,395 @@ +import json + +import numpy as np +import pytest +from ConfigSpace import ( + ConfigurationSpace, + EqualsCondition, + Float, + GreaterThanCondition, + InCondition, + LessThanCondition, + Normal, + NotEqualsCondition, +) +from ConfigSpace.read_and_write import json as cs_json +from scipy.stats import truncnorm as scipy_truncnorm + +from parameterspace.configspace_utils import parameterspace_from_configspace_dict +from parameterspace.priors.categorical import Categorical as CategoricalPrior +from parameterspace.priors.truncated_normal import ( + TruncatedNormal as TruncatedNormalPrior, +) +from parameterspace.transformations.log_zero_one import ( + LogZeroOneInteger as LogZeroOneIntegerTransformation, +) + +CS_CONDITIONS_JSON = """{ + "hyperparameters": [ + { + "name": "alpha", + "type": "uniform_float", + "log": true, + "lower": 0.001, + "upper": 1095.0, + "default": 1.0 + }, + { + "name": "booster", + "type": "categorical", + "choices": [ + "gblinear", + "gbtree", + "dart" + ], + "default": "gblinear", + "probabilities": null + }, + { + "name": "lambda", + "type": "uniform_float", + "log": true, + "lower": 0.0009118819655545162, + "upper": 1096.6331584284585, + "default": 1.0 + }, + { + "name": "nrounds", + "type": "uniform_int", + "log": true, + "lower": 8, + "upper": 2980, + "default": 122 + }, + { + "name": "repl", + "type": "uniform_int", + "log": false, + "lower": 1, + "upper": 10, + "default": 6 + }, + { + "name": "max_depth", + "type": "uniform_int", + "log": false, + "lower": 3, + "upper": 10, + "default": 3 + }, + { + "name": "rate_drop", + "type": "uniform_float", + "log": false, + "lower": 0.0, + "upper": 1.0, + "default": 0.2 + } + ], + "conditions": [ + { + "child": "max_depth", + "parent": "booster", + "type": "IN", + "values": [ + "dart", + "gbtree" + ] + }, + { + "child": "rate_drop", + "parent": "booster", + "type": "EQ", + "value": "dart" + } + ], + "forbiddens": [], + "python_module_version": "0.4.18", + "json_format_version": 0.2 +}""" + + +def _cs_to_dict(cs: ConfigurationSpace) -> dict: + return json.loads(cs_json.write(cs)) + + +def test_conditions_and_log_transform(): + cs_dict = json.loads(CS_CONDITIONS_JSON) + space = parameterspace_from_configspace_dict(cs_dict) + assert len(space) == 7 + assert space.has_conditions() + assert space.check_validity( + { + "alpha": 1.0, + "booster": "gbtree", + "lambda": 1.0, + "nrounds": 122, + "repl": 6, + "max_depth": 3, + } + ) + assert isinstance( + space._parameters["nrounds"]["parameter"]._transformation, + LogZeroOneIntegerTransformation, + ) + + _s = space.copy() + _s.fix(booster="dart") + assert len(_s.sample()) == 7 + + _s = space.copy() + _s.fix(booster="gblinear") + assert len(_s.sample()) == 5 + + alpha = cs_dict["hyperparameters"][0] + assert alpha["name"] == "alpha" + bounds = [alpha["lower"], alpha["upper"]] + assert list(space._parameters["alpha"]["parameter"].bounds) == bounds + + booster = cs_dict["hyperparameters"][1] + assert booster["name"] == "booster" + assert space._parameters["booster"]["parameter"].values == booster["choices"] + + +def test_continuous_with_normal_prior(): + cs_dict = json.loads( + """{ + "name": "myspace", + "hyperparameters": [ + { + "name": "p", + "type": "normal_float", + "log": false, + "mu": 8.0, + "sigma": 10.0, + "default": 6.2 + } + ], + "conditions": [], + "forbiddens": [], + "python_module_version": "0.6.0", + "json_format_version": 0.4 + } + """ + ) + space = parameterspace_from_configspace_dict(cs_dict) + assert len(space) == 1 + + param = space.get_parameter_by_name("p")["parameter"] + assert isinstance(param._prior, TruncatedNormalPrior) + + samples = np.array([space.sample()["p"] for _ in range(10_000)]) + assert abs(samples.mean() - cs_dict["hyperparameters"][0]["mu"]) < 0.1 + assert abs(samples.std() - cs_dict["hyperparameters"][0]["sigma"]) < 0.2 + + +def test_continuous_with_log_normal_prior_and_no_bounds_raises(): + cs = ConfigurationSpace( + space={ + "p": Float( + "p", + default=0.1, + log=True, + distribution=Normal(1.0, 0.6), + ), + }, + ) + with pytest.raises(ValueError): + parameterspace_from_configspace_dict(_cs_to_dict(cs)) + + +def test_continuous_with_log_normal_prior(): + mu = 1.0 + sigma = 0.6 + cs_dict = json.loads( + f"""{{ + "name": "myspace", + "hyperparameters": [ + {{ + "name": "p", + "type": "normal_float", + "lower": 1e-5, + "upper": 1e-1, + "log": true, + "mu": {mu}, + "sigma": {sigma}, + "default": 1.1 + }} + ], + "conditions": [], + "forbiddens": [], + "python_module_version": "0.6.0", + "json_format_version": 0.4 + }} + """ + ) + space = parameterspace_from_configspace_dict(cs_dict) + assert len(space) == 1 + + param = space.get_parameter_by_name("p")["parameter"] + assert isinstance(param._prior, TruncatedNormalPrior) + + samples = np.array([space.sample()["p"] for _ in range(10_000)]) + + a, b = (np.log(param.bounds) - mu) / sigma + expected_mean = scipy_truncnorm.stats(a, b, loc=mu, scale=sigma, moments="m") + assert abs(np.log(samples).mean() - expected_mean) < 0.1 + + expected_var = scipy_truncnorm.stats(a, b, loc=mu, scale=sigma, moments="v") + assert abs(np.log(samples).var() - expected_var) < 0.1 + + +def test_integer_with_normal_prior(): + cs_dict = json.loads( + """{ + "name": "myspace", + "hyperparameters": [ + { + "name": "p", + "type": "normal_int", + "log": false, + "mu": 8.0, + "sigma": 5.0, + "default": 2 + } + ], + "conditions": [], + "forbiddens": [], + "python_module_version": "0.6.0", + "json_format_version": 0.4 + } + """ + ) + space = parameterspace_from_configspace_dict(cs_dict) + assert len(space) == 1 + + param = space.get_parameter_by_name("p")["parameter"] + assert isinstance(param._prior, TruncatedNormalPrior) + + samples = np.array([space.sample()["p"] for _ in range(10_000)]) + + assert abs(samples.mean() - cs_dict["hyperparameters"][0]["mu"]) < 0.1 + assert abs(samples.std() - cs_dict["hyperparameters"][0]["sigma"]) < 0.3 + + +def test_categorical_with_custom_probabilities(): + cs_dict = json.loads( + """{ + "name": "myspace", + "hyperparameters": [ + { + "name": "c", + "type": "categorical", + "choices": [ + "red", + "green", + "blue" + ], + "default": "blue", + "weights": [ + 2, + 1, + 1 + ] + } + ], + "conditions": [], + "forbiddens": [], + "python_module_version": "0.6.0", + "json_format_version": 0.4 + }""" + ) + space = parameterspace_from_configspace_dict(cs_dict) + assert len(space) == 1 + + param = space.get_parameter_by_name("c")["parameter"] + assert isinstance(param._prior, CategoricalPrior) + reference_weights = np.array(cs_dict["hyperparameters"][0]["weights"]) + assert np.all( + param._prior.probabilities == reference_weights / reference_weights.sum() + ) + + +def test_equals_condition(): + cs = ConfigurationSpace({"a": [1, 2, 3], "b": (1.0, 8.0)}) + cond = EqualsCondition(cs["b"], cs["a"], 1) + cs.add_condition(cond) + + space = parameterspace_from_configspace_dict(_cs_to_dict(cs)) + assert len(space) == 2 + + _s = space.copy() + _s.fix(a=1) + assert len(_s.sample()) == 2 + + _s = space.copy() + _s.fix(a=2) + assert len(_s.sample()) == 1 + + +def test_not_equals_condition(): + cs = ConfigurationSpace({"a": [1, 2, 3], "b": (1.0, 8.0)}) + cond = NotEqualsCondition(cs["b"], cs["a"], 1) + cs.add_condition(cond) + + space = parameterspace_from_configspace_dict(_cs_to_dict(cs)) + assert len(space) == 2 + + _s = space.copy() + _s.fix(a=2) + assert len(_s.sample()) == 2 + + _s = space.copy() + _s.fix(a=1) + assert len(_s.sample()) == 1 + + +def test_less_than_condition(): + cs = ConfigurationSpace({"a": (0, 10), "b": (1.0, 8.0)}) + cond = LessThanCondition(cs["b"], cs["a"], 5) + cs.add_condition(cond) + + space = parameterspace_from_configspace_dict(_cs_to_dict(cs)) + assert len(space) == 2 + + _s = space.copy() + _s.fix(a=4) + assert len(_s.sample()) == 2 + + _s = space.copy() + _s.fix(a=6) + assert len(_s.sample()) == 1 + + +def test_greater_than_condition(): + cs = ConfigurationSpace({"a": (0, 10), "b": (1.0, 8.0)}) + cond = GreaterThanCondition(cs["b"], cs["a"], 5) + cs.add_condition(cond) + + space = parameterspace_from_configspace_dict(_cs_to_dict(cs)) + assert len(space) == 2 + + _s = space.copy() + _s.fix(a=6) + assert len(_s.sample()) == 2 + + _s = space.copy() + _s.fix(a=4) + assert len(_s.sample()) == 1 + + +def test_in_condition(): + cs = ConfigurationSpace({"a": (0, 10), "b": (1.0, 8.0)}) + cond = InCondition(cs["b"], cs["a"], [1, 2, 3, 4]) + cs.add_condition(cond) + + space = parameterspace_from_configspace_dict(_cs_to_dict(cs)) + assert len(space) == 2 + + _s = space.copy() + _s.fix(a=2) + assert len(_s.sample()) == 2 + + _s = space.copy() + _s.fix(a=5) + assert len(_s.sample()) == 1