Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for overwriting parameters via CLI #79

Merged
merged 3 commits into from
Dec 13, 2021
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions seml/add.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def add_configs(collection, seml_config, slurm_config, configs, source_files=Non
collection.insert_many(db_dicts)


def add_experiments(db_collection_name, config_file, force_duplicates, no_hash=False, no_sanity_check=False,
def add_experiments(db_collection_name, config_file, force_duplicates, overwrite_params=None, no_hash=False, no_sanity_check=False,
no_code_checkpoint=False):
"""
Add configurations from a config file into the database.
Expand All @@ -120,6 +120,7 @@ def add_experiments(db_collection_name, config_file, force_duplicates, no_hash=F
db_collection_name: the MongoDB collection name.
config_file: path to the YAML configuration.
force_duplicates: if True, disable duplicate detection.
overwrite_params: optional flat dictionary to overwrite parameters in all configs.
no_hash: if True, disable hashing of the configurations for duplicate detection. This is much slower, so use only
if you have a good reason to.
no_sanity_check: if True, do not check the config for missing/unused arguments.
Expand All @@ -136,7 +137,7 @@ def add_experiments(db_collection_name, config_file, force_duplicates, no_hash=F
if 'conda_environment' not in seml_config:
seml_config['conda_environment'] = os.environ.get('CONDA_DEFAULT_ENV')

# Set Slurm config with default parameters as fall-back option
# Set Slurm config with default parameters as fall-back option
slurm_config = merge_dicts(SETTINGS.SLURM_DEFAULT, slurm_config)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On a related note: can we also make it possible to overwrite slurm parameters when starting experiments? Sometimes I screw up the partition or so and then I have to re-queue everything.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That sounds like a good idea. However, I think we should open a follow-up issue on that and have a separate PR since the logic will be distinct from this one.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wouldn't this use the same interface? e.g. --with slurm.sbatch_options.mem=25G or so?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would use a separate argument and unify the style with seml jupyter where you have the -sb option to supply sbatch parameters. However, there you need to provide a dictionary which is not very user-friendly. In conjunction with this PR it could look something like:

seml <collection> add <yaml> -o dataset=imagenet -sb mem=25G

We could reuse the key-value parsing we introduce with this PR for the sbatch parameters but integrating both into one argument makes parsing more difficult and user always have to nest their parameters and sbatch options with sbatch.xyz and config.xyz.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree that nesting with config. each time is not super user friendly. I suggest we do -sb mem=25G partition=gpu_all etc. separately (which is I think what you suggested)?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes :)


# Check for and use sbatch options template
Expand All @@ -149,7 +150,7 @@ def add_experiments(db_collection_name, config_file, force_duplicates, no_hash=F
slurm_config['sbatch_options'][k] = v

slurm_config['sbatch_options'] = remove_prepended_dashes(slurm_config['sbatch_options'])
configs = generate_configs(experiment_config)
configs = generate_configs(experiment_config, overwrite_params=overwrite_params)
collection = get_collection(db_collection_name)

batch_id = get_max_in_collection(collection, "batch_id")
Expand Down
16 changes: 14 additions & 2 deletions seml/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from seml.sources import import_exe
from seml.parameters import sample_random_configs, generate_grid, cartesian_product_dict
from seml.utils import merge_dicts, flatten, unflatten
from seml.utils import Hashabledict, merge_dicts, flatten, unflatten
from seml.errors import ConfigError, ExecutableError
from seml.settings import SETTINGS

Expand Down Expand Up @@ -133,7 +133,7 @@ def detect_duplicate_parameters(inverted_config: dict, sub_config_name: str = No
raise ConfigError(error_str.format(p1=p1, p2=p2))


def generate_configs(experiment_config):
def generate_configs(experiment_config, overwrite_params=None):
"""Generate parameter configurations based on an input configuration.

Input is a nested configuration where on each level there can be 'fixed', 'grid', and 'random' parameters.
Expand All @@ -155,6 +155,8 @@ def generate_configs(experiment_config):
experiment_config: dict
Dictionary that specifies the "search space" of parameters that will be enumerated. Should be
parsed from a YAML file.
overwrite_params: Optional[dict]
Flat dictionary that overwrites configs. Resulting duplicates will be removed.

Returns
-------
Expand Down Expand Up @@ -235,6 +237,16 @@ def generate_configs(experiment_config):
for k, v in config.items()}
for config in all_configs]

if overwrite_params is not None:
all_configs = [merge_dicts(config, overwrite_params) for config in all_configs]
base_length = len(all_configs)
# We use a dictionary instead a set because dictionary keys are ordered as of Python 3
all_configs = list({Hashabledict(**config): None for config in all_configs})
new_length = len(all_configs)
if base_length != new_length:
diff = base_length - new_length
logging.warn(f'Parameter overwrite caused {diff} identical configs. Duplicates were removed.')

all_configs = [unflatten(conf) for conf in all_configs]
return all_configs

Expand Down
16 changes: 16 additions & 0 deletions seml/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,17 @@ def parse_args(parser, commands):
return commands


class ParameterAction(argparse.Action):
def __init__(self, option_strings, dest, **kwargs):
super().__init__(option_strings, dest, **kwargs)

def __call__(self, parser, namespace, values, option_string=None):
setattr(namespace, self.dest, {
value.split('=')[0]: eval('='.join(value.split('=')[1:]))
for value in values
})


def main():
parser = argparse.ArgumentParser(
description="Manage experiments for the given configuration. "
Expand Down Expand Up @@ -98,6 +109,11 @@ def main():
'-f', '--force-duplicates', action='store_true',
help="Add experiments to the database even when experiments with identical configurations "
"are already in the database.")
parser_add.add_argument(
'-o', '--overwrite-params', action=ParameterAction, nargs='+', default={},
help="Specifies parameters that overwrite their respective values in all configs."
"Format: <param>=<value>, use flat dictionary notation with key1.key2=value."
)
parser_add.set_defaults(func=add_experiments)

parser_start = subparsers.add_parser(
Expand Down
14 changes: 14 additions & 0 deletions seml/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from collections.abc import Iterable
import logging
import json
import copy
Expand Down Expand Up @@ -254,3 +255,16 @@ def format(self, record):
log_fmt = self.FORMATS.get(record.levelno, self.FORMATS['DEFAULT'])
formatter = logging.Formatter(log_fmt)
return formatter.format(record)


class Hashabledict(dict):

def hashable_values(self):
for value in self.values():
if isinstance(value, Iterable):
yield tuple(value)
else:
yield value

def __hash__(self):
return hash((frozenset(self), frozenset(self.hashable_values())))
13 changes: 13 additions & 0 deletions test/resources/config/config_with_grid.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
grid:
dataset:
type: choice
options:
- small
- big

lr:
type: choice
options:
- 0.1
- 0.01

18 changes: 18 additions & 0 deletions test/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ class TestParseConfigDicts(unittest.TestCase):
CONFIG_WITH_DUPLICATE_RDM_PARAMETERS_2 = "resources/config/config_with_duplicate_random_parameters_1.yaml"
CONFIG_WITH_ALL_TYPES = "resources/config/config_with_all_types.yaml"
CONFIG_WITH_EMPTY_DICT = "resources/config/config_with_empty_dictionary.yaml"
CONFIG_WITH_GRID = "resources/config/config_with_grid.yaml"

def load_config_dict(self, path):
with open(path, 'r') as conf:
Expand Down Expand Up @@ -80,6 +81,23 @@ def test_empty_dictionary(self):
}
}
self.assertEqual(configs, expected_config)

def test_overwrite_parameters(self):
config_dict = self.load_config_dict(self.CONFIG_WITH_GRID)
configs = config.generate_configs(config_dict, {
'dataset': 'small'
})
expected_configs = [
{
'dataset': 'small',
'lr': 0.1
},
{
'dataset': 'small',
'lr': 0.01
}
]
self.assertEqual(configs, expected_configs)

def test_duplicate_parameters(self):
config_dict = self.load_config_dict(self.CONFIG_WITH_DUPLICATE_PARAMETERS_1)
Expand Down