Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added src/sygnet/__pycache__/__init__.cpython-38.pyc
Binary file not shown.
Binary file added src/sygnet/__pycache__/dataloaders.cpython-38.pyc
Binary file not shown.
Binary file added src/sygnet/__pycache__/interface.cpython-38.pyc
Binary file not shown.
Binary file added src/sygnet/__pycache__/loader.cpython-38.pyc
Binary file not shown.
Binary file added src/sygnet/__pycache__/models.cpython-38.pyc
Binary file not shown.
Binary file added src/sygnet/__pycache__/requirements.cpython-38.pyc
Binary file not shown.
Binary file added src/sygnet/__pycache__/train.cpython-38.pyc
Binary file not shown.
Binary file added src/sygnet/__pycache__/tune.cpython-38.pyc
Binary file not shown.
12 changes: 10 additions & 2 deletions src/sygnet/tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,15 +60,23 @@ def tune(
"""

logger.warning(
"This function is still in development. Only 'wgan' modelling has been implemented thus far, and all hyperparameter searches will use random sampling rather than an exhaustive grid seach"
"This function is still in development. Only 'wgan' and 'cgan' modelling has been implemented thus far, and all hyperparameter searches will use random sampling rather than an exhaustive grid seach"
)

torch.manual_seed(seed)
random.seed(seed)

if mode != "wgan":
if mode not in ["wgan", "cgan"]:
return None

if mode == "cgan" and 'cond_cols' not in fit_opts:
logger.warning(
"Since you are using Conditional arcgitecture, you need to specify conditional columns as 'cond_cols' in fit_opts dictionary. Example: fit_opts = {'save_model': False, 'cond_cols' : ['name']}"
)

if mode == "cgan" and not isinstance(fit_opts['cond_cols'],list):
raise Exception("Conditional columns 'cond_cols' must be a list!")

if type(parameter_dict) is not dict:
logger.error("`parameter_dict` must be a dictionary with hyperparameter arguments as keys and lists of options to try as values. \n \
Tunable hyperparameters across sygnet are currently: \n \
Expand Down