Hyperparameter Search
The `grid_search` function in Wrench is based on Optuna, with finer control over the search process!
Some basic concepts used throughout this page:
- `grid`: a grid is a specific configuration of parameters. For example, given a search space such as `{'a': [1, 2, 3], 'b': [1, 2, 3]}`, a grid could be `{'a': 1, 'b': 2}`.
- `run`: a run is a single process in which the model is trained and evaluated on a given grid.
- `trial`: a trial consists of multiple runs, and the average test value over those runs is returned.
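For intuition, the set of all grids is simply the Cartesian product of the parameter lists. Below is a minimal sketch (plain Python, not Wrench internals) that enumerates the grids of the toy search space above; the tuple-per-grid layout mirrors what the `filter_fn` described later on this page receives.

```python
from itertools import product

search_space = {'a': [1, 2, 3], 'b': [1, 2, 3]}

# One grid per combination of parameter values: 3 x 3 = 9 grids in total.
para_names = list(search_space.keys())
grids = list(product(*(search_space[name] for name in para_names)))

print(len(grids))   # 9
print(grids[0])     # (1, 1), i.e. {'a': 1, 'b': 1}
```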
```python
import numpy as np
from wrench.dataset import load_dataset
from wrench.search import grid_search
from wrench.labelmodel import Snorkel

#### Load dataset
dataset_home = '../datasets'
data = 'youtube'
train_data, valid_data, test_data = load_dataset(dataset_home, data, extract_feature=False)

#### Specify the hyper-parameter search space for grid search
label_model_name = 'Snorkel'
search_space = {
    'Snorkel': {
        'lr': np.logspace(-5, -1, num=5, base=10),
        'n_epochs': [5, 10, 50, 100, 200],
    }
}

# Specify the total number of trials; it is OK to exceed the number of possible grids,
# since the search terminates once all grids have been explored
n_trials = 100
# Specify the number of repeated runs within each trial
n_repeats = 5

#### Search for the best hyper-parameters using the validation set, in parallel
label_model = Snorkel()
searched_paras = grid_search(label_model, dataset_train=train_data, dataset_valid=valid_data,
                             metric='acc', direction='auto', search_space=search_space[label_model_name],
                             n_repeats=n_repeats, n_trials=n_trials, parallel=True)
```
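The returned `searched_paras` is the best hyper-parameter configuration found. A typical follow-up, sketched below under the assumption that `Snorkel`'s `fit`/`test` follow the `BaseModel` signatures shown later on this page, is to re-fit the model with these values and evaluate it on the test set.

```python
# Re-fit the label model with the searched hyper-parameters (a sketch; signatures
# assumed to match the BaseModel interface described below).
label_model = Snorkel(**searched_paras)
label_model.fit(dataset_train=train_data, dataset_valid=valid_data)
test_value = label_model.test(test_data, metric_fn='acc')
print(f'test acc: {test_value}')
```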
Optuna provides trial-level parallelism: each process handles an independent trial, and the processes coordinate with each other through a database. If you don't have a database set up for parallelism, some grids may be explored repeatedly by multiple processes.
Instead, we provide run-level parallelism: we start `n_repeats` processes, each handling a single run within a trial. We found that in-trial parallelism is good enough for most use cases of training machine learning models.
The grids are shuffled before the search starts, which is important when the budget `n_trials` is smaller than the total number of grids. Without shuffling, the search result would be highly biased! For example, given the search space `{'a': [1, 2, 3], 'b': [1, 2, 3]}` and `n_trials=3`, Optuna may explore an undesired sequence of grids such as `{'a': 1, 'b': 1}`, `{'a': 1, 'b': 2}`, `{'a': 1, 'b': 3}`, so the value of `'a'` is never varied.
Sometimes we may want to filter out invalid grids. We can do this by passing a `filter_fn` to `grid_search`.
For example, given the search space `{'a': [1, 2, 3], 'b': [1, 2, 3]}`, suppose we want to make sure that `a + b > 2`. We write a filter function as below and feed it to `grid_search`:
```python
def customized_filter_fn(grids, para_names):
    # Each grid is a tuple of parameter values, ordered according to para_names.
    a, b = para_names.index('a'), para_names.index('b')
    return [grid for grid in grids if grid[a] + grid[b] > 2]

grid_search(..., filter_fn=customized_filter_fn)
```
`grid_search` accepts a callable argument `process_fn`, which initializes, fits, and evaluates the model. The default `process_fn` looks like:
```python
def single_process(item, model, dataset_train, y_train, dataset_valid, y_valid, metric, direction, kwargs):
    # item = (suggested hyper-parameters for this trial, index of this run within the trial)
    suggestions, i = item
    kwargs = kwargs.copy()
    hyperparas = model.hyperparas
    m = model.__class__(**hyperparas)
    m.fit(dataset_train=dataset_train, y_train=y_train, dataset_valid=dataset_valid, y_valid=y_valid,
          verbose=False, metric=metric, direction=direction, **suggestions, **kwargs)
    value = m.test(dataset_valid, metric_fn=metric)
    return value
```
What if we want to handle the runs within a trial differently, for example, running each with a different seed? We can do this by passing in a customized `process_fn`:
```python
def single_process_with_seed(item, model, dataset_train, y_train, dataset_valid, y_valid, metric, direction, kwargs):
    suggestions, i = item
    kwargs = kwargs.copy()
    # Pop the extra 'seeds' keyword passed to grid_search and pick the seed for this run.
    seeds = kwargs.pop('seeds')
    seed = seeds[i]
    hyperparas = model.hyperparas
    m = model.__class__(**hyperparas)
    m.fit(dataset_train=dataset_train, y_train=y_train, dataset_valid=dataset_valid, y_valid=y_valid,
          verbose=False, metric=metric, direction=direction, seed=seed, **suggestions, **kwargs)
    value = m.test(dataset_valid, metric_fn=metric)
    return value

grid_search(..., process_fn=single_process_with_seed, seeds=[1, 2, 3])
```
Optuna provides a timeout for the whole search process; however, we found it is sometimes important to kill a single trial that takes too long. Therefore, `grid_search` provides a `trial_timeout` argument. If it is larger than 0, a trial will be terminated after `trial_timeout` seconds. Note that this feature does not work for parallel search!
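For example, a call like the following (a sketch; the 600-second budget is arbitrary) kills any trial that runs longer than ten minutes:

```python
# Terminate any single trial after 600 seconds (does not work with parallel=True).
grid_search(..., trial_timeout=600, parallel=False)
```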
`grid_search` provides a `study_patience` argument, which stops the search when the metric value has not improved for `study_patience` consecutive trials.
This feature is always coupled with `min_trials` to guarantee a minimum number of searched trials.
Both `study_patience` and `min_trials` can be either float or int; if float, the value is interpreted as a fraction of the number of possible grids (not of `n_trials`!).
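For instance, the following call (a sketch; the specific fractions are illustrative) searches at least 20% of the possible grids and then stops once 10% of them have passed in a row without improvement:

```python
# Floats are interpreted as fractions of the number of possible grids.
grid_search(..., study_patience=0.1, min_trials=0.2)
```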
`grid_search` provides a `prune_threshold` argument, which prunes a trial when the result returned by its first run is not promising, i.e., when (best value - current value) > `prune_threshold` * best value.
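As an illustration (the 0.1 threshold is arbitrary), the following prunes any trial whose first run falls more than 10% short of the best value seen so far:

```python
# Prune a trial if its first run is more than 10% worse than the current best value.
grid_search(..., prune_threshold=0.1)
```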
For models implemented in Wrench, users can get the default search space via:

```python
from wrench.search_space import get_search_space

search_space, filter_fn = get_search_space('Snorkel')
```
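These defaults can then be passed straight back into `grid_search` (a sketch reusing the data and the `Snorkel` label model from the earlier example):

```python
# Search with the default search space and filter function for Snorkel.
searched_paras = grid_search(Snorkel(), dataset_train=train_data, dataset_valid=valid_data,
                             metric='acc', direction='auto',
                             search_space=search_space, filter_fn=filter_fn,
                             n_repeats=5, n_trials=100, parallel=True)
```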
A searchable model should inherit from `BaseModel`. For example:
```python
from wrench.basemodel import BaseModel

class NewModel(BaseModel):
    def __init__(self, lr=0.01, a=0.5, b=0.5):
        super().__init__()
        # Only store the hyper-parameters here; do not build the model yet.
        self.hyperparas = {
            'lr': lr,
            'a': a,
            'b': b,
        }

    def fit(self, dataset_train, y_train=None, dataset_valid=None, y_valid=None, verbose=False, **kwargs):
        # Update self.hyperparas with the suggested values before training.
        self._update_hyperparas(**kwargs)
        pass

    def test(self, dataset, metric_fn, y_true=None, **kwargs):
        pass
```
- `__init__`: the `__init__` function takes the default parameters, or parameters you want to fix during the search (it can be empty if default values are set). It does not initialize the model; instead, it only initializes the parameters and stores them in `self.hyperparas`.
- `fit`: before `fit` starts, it executes `self._update_hyperparas(**kwargs)` to update `self.hyperparas` with the input. This step is critical for hyper-parameter search!
- `test`: `grid_search` will call the `test` function to evaluate the model on the validation dataset. If `y_true` is None, the true labels should already be stored in the `dataset`.
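Once `NewModel` follows this contract, it can be searched like any built-in model. The snippet below is a sketch: the search space values are illustrative, and `NewModel`'s `fit`/`test` bodies are assumed to be filled in with real training and evaluation logic.

```python
# Hypothetical search over the custom model defined above.
search_space = {
    'lr': [1e-3, 1e-2, 1e-1],
    'a': [0.1, 0.5, 0.9],
    'b': [0.1, 0.5, 0.9],
}
searched_paras = grid_search(NewModel(), dataset_train=train_data, dataset_valid=valid_data,
                             metric='acc', direction='auto', search_space=search_space,
                             n_repeats=3, n_trials=10, parallel=False)
```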