
Hyperparameter Search

Jieyu Zhang edited this page Sep 5, 2021 · 4 revisions

The grid_search function in Wrench is based on Optuna with finer control over the search process!

Some basic concepts for easy reading:

  • grid: a grid is one specific configuration of the parameters. For example, given the search space {'a': [1, 2, 3], 'b': [1, 2, 3]}, one grid is {'a': 1, 'b': 2}.
  • run: a run trains and evaluates the model once with a given grid.
  • trial: a trial consists of multiple runs with the same grid, and returns the average of their evaluation values (computed on the validation set).
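To make the three terms concrete, here is a minimal, library-free sketch. The helpers `enumerate_grids` and `run_trial` are hypothetical names, not part of Wrench: the first expands a search space into all of its grids with `itertools.product`, and the second averages the values of `n_repeats` runs on one grid into a single trial value.

```python
import itertools
from statistics import mean
from typing import Callable, Dict, List, Sequence


def enumerate_grids(search_space: Dict[str, Sequence]) -> List[dict]:
    """Expand a search space into every grid (one value per parameter)."""
    names = list(search_space)
    return [dict(zip(names, values))
            for values in itertools.product(*(search_space[n] for n in names))]


def run_trial(grid: dict, n_repeats: int, run_fn: Callable[[dict, int], float]) -> float:
    """A trial: repeat the run n_repeats times on the same grid and average."""
    return mean(run_fn(grid, i) for i in range(n_repeats))


def toy_run(grid: dict, i: int) -> float:
    """Stand-in for 'train and evaluate'; i is the run index within the trial."""
    return grid['a'] + grid['b'] + 0.1 * i


search_space = {'a': [1, 2, 3], 'b': [1, 2, 3]}
grids = enumerate_grids(search_space)
print(len(grids))  # 9
print(grids[1])    # {'a': 1, 'b': 2}
print(run_trial({'a': 1, 'b': 2}, n_repeats=3, run_fn=toy_run))
```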

Basic Usage

import numpy as np
from wrench.dataset import load_dataset
from wrench.search import grid_search
from wrench.labelmodel import Snorkel

#### Load dataset
dataset_home = '../datasets'
data = 'youtube'
train_data, valid_data, test_data = load_dataset(dataset_home, data, extract_feature=False)

#### Specify the hyper-parameter search space for grid search
search_space = {
    'Snorkel': {
        'lr': np.logspace(-5, -1, num=5, base=10),
        'n_epochs': [5, 10, 50, 100, 200],
    }
}

# Specify the total number of trials, it's ok to exceed the number of possible grids
# since the search would be terminated when all the grids are explored
n_trials = 100
# Specify the number of repeat runs within each trial
n_repeats = 5

#### Search best hyper-parameters using validation set in parallel
label_model_name = 'Snorkel'
label_model = Snorkel()
searched_paras = grid_search(label_model, dataset_train=train_data, dataset_valid=valid_data,
                             metric='acc', direction='auto', search_space=search_space[label_model_name],
                             n_repeats=n_repeats, n_trials=n_trials, parallel=True)
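The comment about `n_trials` can be checked with a little arithmetic. The standard-library sketch below rebuilds the five `lr` values of `np.logspace(-5, -1, num=5, base=10)` (up to floating-point rounding) and counts the grids: 5 × 5 = 25, so with `n_trials = 100` the search should stop after 25 trials, i.e. 25 × `n_repeats` = 125 runs. It assumes every grid is valid (no `filter_fn`).

```python
import itertools

# Same five values as np.logspace(-5, -1, num=5, base=10): 1e-05, ..., 0.1
lrs = [10.0 ** e for e in range(-5, 0)]
n_epochs = [5, 10, 50, 100, 200]

n_grids = len(list(itertools.product(lrs, n_epochs)))
n_trials, n_repeats = 100, 5
trials_run = min(n_trials, n_grids)  # the search stops once every grid is explored
runs = trials_run * n_repeats        # each trial performs n_repeats runs

print(n_grids, trials_run, runs)  # 25 25 125
```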

What Are the Differences from Optuna?

Run-level Parallelism

Optuna provides trial-level parallelism: each process handles an independent trial, and the processes coordinate through a shared database. Without such a database, different processes may end up exploring the same grid repeatedly.

Instead, we provide run-level parallelism: we start n_repeats processes, each handling a single run within a trial. In our experience, this within-trial parallelism is good enough for most use cases of training machine learning models.
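A minimal sketch of run-level parallelism, assuming each run is identified by the pair `(suggestions, i)` that the default `process_fn` unpacks (see Handle Different Runs within a Trial). The runs of one trial are mapped over a pool of `n_repeats` workers and their values are averaged. `toy_run` and `parallel_trial` are hypothetical names, and this is an illustration rather than Wrench's implementation. It uses `ThreadPoolExecutor` so it runs anywhere; for CPU-bound training you would likely swap in `ProcessPoolExecutor`, which requires picklable, module-level functions.

```python
from concurrent.futures import ThreadPoolExecutor
from statistics import mean


def toy_run(item):
    """Stand-in for one run: train and evaluate with the suggested grid."""
    suggestions, i = item  # i is the run index within the trial
    return suggestions['lr'] * 100 + i


def parallel_trial(suggestions, n_repeats, run_fn=toy_run):
    """Execute all runs of one trial concurrently and return their average."""
    items = [(suggestions, i) for i in range(n_repeats)]
    # One worker per run, so every run of this trial executes at the same time.
    with ThreadPoolExecutor(max_workers=n_repeats) as pool:
        values = list(pool.map(run_fn, items))
    return mean(values)


print(parallel_trial({'lr': 0.01}, n_repeats=5))
```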

Shuffled Grids

The grids are shuffled before the search starts. This matters when the budget n_trials is smaller than the total number of grids: without shuffling, the search result can be heavily biased toward the first values of each parameter!

For example, given the search space {'a': [1, 2, 3], 'b': [1, 2, 3]} and n_trials=3, Optuna may explore an undesirable sequence of grids such as {'a':1, 'b':1}, {'a':1, 'b':2}, {'a':1, 'b':3}, which never tries a=2 or a=3.
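The effect is easy to reproduce. In this standard-library sketch (the seed `0` is an arbitrary choice), taking the first `n_trials = 3` grids in enumeration order only ever tests `a = 1`, while taking them from a shuffled copy spends the budget on a random subset of the space.

```python
import itertools
import random

search_space = {'a': [1, 2, 3], 'b': [1, 2, 3]}
names = list(search_space)
grids = [dict(zip(names, v)) for v in itertools.product(*search_space.values())]
n_trials = 3

# Without shuffling: the first three grids all share a = 1.
print(grids[:n_trials])
# [{'a': 1, 'b': 1}, {'a': 1, 'b': 2}, {'a': 1, 'b': 3}]

# With shuffling: the same budget is spent on a random subset of the grids.
shuffled = grids.copy()
random.Random(0).shuffle(shuffled)
print(shuffled[:n_trials])
```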

Filter Invalid Grids

Sometimes we want to filter out invalid grids. To do that, pass a filter_fn to grid_search.

For example, given the search space {'a': [1, 2, 3], 'b': [1, 2, 3]}, suppose we want to make sure a + b > 2. We write a filter function as below and pass it to grid_search.

def customized_filter_fn(grids, para_names):
    a, b = para_names.index('a'), para_names.index('b')
    return [grid for grid in grids if grid[a] + grid[b] > 2]

grid_search(..., filter_fn=customized_filter_fn)  # "..." stands for the other arguments, as in Basic Usage
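Applied to the example space, the filter removes only the grid a=1, b=1. The sketch below calls the filter directly rather than through `grid_search`, and assumes, as the indexing in `customized_filter_fn` suggests, that grids are passed as tuples of values ordered like `para_names`.

```python
import itertools


def customized_filter_fn(grids, para_names):
    a, b = para_names.index('a'), para_names.index('b')
    return [grid for grid in grids if grid[a] + grid[b] > 2]


para_names = ['a', 'b']
grids = list(itertools.product([1, 2, 3], [1, 2, 3]))  # 9 tuples (a, b)
valid = customized_filter_fn(grids, para_names)

print(len(grids), len(valid))  # 9 8
print((1, 1) in valid)         # False
```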

Handle Different Runs within a Trial

grid_search accepts a callable argument, process_fn, which initializes, fits, and evaluates the model for a single run. The default process_fn looks like:

def single_process(item, model, dataset_train, y_train, dataset_valid, y_valid, metric, direction, kwargs):
    suggestions, i = item  # suggested hyperparameters for this grid, index of the run within the trial
    kwargs = kwargs.copy()
    hyperparas = model.hyperparas
    m = model.__class__(**hyperparas)  # a fresh model instance for every run
    m.fit(dataset_train=dataset_train, y_train=y_train, dataset_valid=dataset_valid, y_valid=y_valid,
          verbose=False, metric=metric, direction=direction, **suggestions, **kwargs)
    value = m.test(dataset_valid, metric_fn=metric)  # evaluate on the validation set
    return value

What if we want to treat the runs within a trial differently, for example, giving each run a different seed? We can do this by passing a customized process_fn:

def single_process_with_seed(item, model, dataset_train, y_train, dataset_valid, y_valid, metric, direction, kwargs):
    suggestions, i = item
    kwargs = kwargs.copy()
    seeds = kwargs.pop('seeds')
    seed = seeds[i]  # run i uses the i-th seed, so provide at least n_repeats seeds
    hyperparas = model.hyperparas
    m = model.__class__(**hyperparas)
    m.fit(dataset_train=dataset_train, y_train=y_train, dataset_valid=dataset_valid, y_valid=y_valid,
          verbose=False, metric=metric, direction=direction, seed=seed, **suggestions, **kwargs)
    value = m.test(dataset_valid, metric_fn=metric)
    return value

grid_search(..., n_repeats=3, process_fn=single_process_with_seed, seeds=[1, 2, 3])  # "..." stands for the other arguments
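To see how the run index `i` and the extra `seeds` keyword fit together, here is a toy driver. It assumes (the page does not spell this out) that `grid_search` calls `process_fn` once per run with `item = (suggestions, i)` for `i` in `range(n_repeats)`, and forwards extra keyword arguments such as `seeds` in `kwargs`. `ToyModel` and `run_trial` are hypothetical stand-ins; the datasets are passed as `None` because the toy model ignores them.

```python
from statistics import mean


class ToyModel:
    """Hypothetical stand-in for a Wrench model."""

    def __init__(self, lr=0.01):
        self.hyperparas = {'lr': lr}
        self.seed = None

    def fit(self, dataset_train=None, y_train=None, dataset_valid=None, y_valid=None,
            verbose=False, metric=None, direction=None, seed=None, **kwargs):
        self.hyperparas.update(kwargs)  # apply the suggested hyperparameters
        self.seed = seed

    def test(self, dataset, metric_fn, y_true=None):
        # Fake score that depends on the seed, so the runs of a trial differ.
        return self.hyperparas['lr'] * 10 + self.seed


def single_process_with_seed(item, model, dataset_train, y_train, dataset_valid, y_valid,
                             metric, direction, kwargs):
    suggestions, i = item
    kwargs = kwargs.copy()
    seed = kwargs.pop('seeds')[i]  # run i uses the i-th seed
    m = model.__class__(**model.hyperparas)
    m.fit(dataset_train=dataset_train, y_train=y_train, dataset_valid=dataset_valid,
          y_valid=y_valid, verbose=False, metric=metric, direction=direction,
          seed=seed, **suggestions, **kwargs)
    return m.test(dataset_valid, metric_fn=metric)


def run_trial(process_fn, model, suggestions, n_repeats, **kwargs):
    """Assumed calling convention: one process_fn call per run, then average."""
    values = [process_fn((suggestions, i), model, None, None, None, None, 'acc', 'auto', kwargs)
              for i in range(n_repeats)]
    return mean(values)


print(run_trial(single_process_with_seed, ToyModel(), {'lr': 0.1}, n_repeats=3, seeds=[1, 2, 3]))
```

Here the three runs score 2.0, 3.0 and 4.0 (one per seed) and the trial returns their mean.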

Set Timeout for Each Trial

Optuna provides a timeout for the whole search process. However, we found it is sometimes important to kill a single trial that takes too long. Therefore, grid_search provides a trial_timeout argument: if it is larger than 0, a trial will be terminated after trial_timeout seconds. Note that this feature does not work for parallel search (parallel=True)!
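One way to enforce a per-trial time limit inside a single process is a Unix alarm signal. The sketch below is an assumption about the mechanism, not Wrench's code; `run_with_timeout`, `TrialTimeout` and `slow_trial` are hypothetical names. It is Unix only (`signal.setitimer` is unavailable on Windows), and because Python runs signal handlers only in the main thread, this approach does not extend directly to runs executing in worker threads.

```python
import signal
import time


class TrialTimeout(Exception):
    """Raised when a trial exceeds its time budget."""


def _on_alarm(signum, frame):
    raise TrialTimeout()


def run_with_timeout(trial_fn, trial_timeout):
    """Run trial_fn(); return None if it takes longer than trial_timeout seconds.

    A trial_timeout <= 0 disables the limit. Must be called from the main thread.
    """
    if trial_timeout <= 0:
        return trial_fn()
    old_handler = signal.signal(signal.SIGALRM, _on_alarm)
    try:
        try:
            signal.setitimer(signal.ITIMER_REAL, trial_timeout)  # start the countdown
            return trial_fn()
        finally:
            signal.setitimer(signal.ITIMER_REAL, 0)  # cancel any pending alarm
    except TrialTimeout:
        return None  # the trial was interrupted
    finally:
        # Restore the previous handler (fall back to the default if it was unknown).
        signal.signal(signal.SIGALRM, old_handler if old_handler is not None else signal.SIG_DFL)


def slow_trial():
    time.sleep(2)
    return 0.9


print(run_with_timeout(lambda: 0.8, trial_timeout=1.0))  # 0.8
print(run_with_timeout(slow_trial, trial_timeout=0.2))   # None
```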

Get Default Search Space for Models

For models implemented in Wrench, users can get the default search space (together with its filter function) by

from wrench.search_space import get_search_space
search_space, filter_fn = get_search_space('Snorkel')

Build Searchable Model

A searchable model should inherit from BaseModel. For example,

from wrench.basemodel import BaseModel

class NewModel(BaseModel):
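    def __init__(self, lr=0.01, a=0.5, b=0.5):
        super().__init__()
        # Only store hyperparameters here; the actual model is built in fit
        self.hyperparas = {
            'lr': lr,
            'a': a,
            'b': b,
        }

    def fit(self, dataset_train, y_train=None, dataset_valid=None, y_valid=None, verbose=False, **kwargs):
        # Must come first: apply the hyperparameters suggested by grid_search
        self._update_hyperparas(**kwargs)
        # ... build and train the model using self.hyperparas ...

    def test(self, dataset, metric_fn, y_true=None, **kwargs):
        # ... evaluate on dataset with metric_fn and return a scalar ...
        pass
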
  • __init__: the __init__ function takes default parameters, or parameters you want to fix during the search (the constructor can be called with no arguments if the defaults suffice). It does not build the model; it only initializes the parameters and stores them in self.hyperparas.

  • fit: before training starts, fit executes self._update_hyperparas(**kwargs) to update self.hyperparas with the input values. This step is critical for hyper-parameter search!

  • test: grid_search calls the test function to evaluate the model on the validation dataset. If y_true is None, the true labels should already be stored in the dataset.
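Since `BaseModel` requires Wrench, the sketch below uses a hypothetical stand-in, `ToyBaseModel`, whose `_update_hyperparas` overwrites only keys that are already hyperparameters; the real method may differ. It shows the two properties the search relies on: `model.__class__(**model.hyperparas)` rebuilds an equivalent unfitted model (as the default `process_fn` does), and keyword arguments passed to `fit` override the stored values while unrelated ones such as `metric` are ignored.

```python
class ToyBaseModel:
    """Hypothetical stand-in for wrench.basemodel.BaseModel."""

    def __init__(self):
        self.hyperparas = {}

    def _update_hyperparas(self, **kwargs):
        # Assumed behaviour: overwrite only keys that are already hyperparameters.
        for key in self.hyperparas:
            if key in kwargs:
                self.hyperparas[key] = kwargs[key]


class NewModel(ToyBaseModel):
    def __init__(self, lr=0.01, a=0.5, b=0.5):
        super().__init__()
        self.hyperparas = {'lr': lr, 'a': a, 'b': b}

    def fit(self, dataset_train=None, y_train=None, dataset_valid=None, y_valid=None,
            verbose=False, **kwargs):
        self._update_hyperparas(**kwargs)
        # ... build and train the real model from self.hyperparas here ...
        return self

    def test(self, dataset, metric_fn, y_true=None, **kwargs):
        return 0.0  # placeholder score


template = NewModel(a=0.9)                           # fix a=0.9 for the whole search
clone = template.__class__(**template.hyperparas)    # what the default process_fn does
clone.fit(lr=0.1, b=0.2, metric='acc')               # suggested values for one grid
print(template.hyperparas)  # {'lr': 0.01, 'a': 0.9, 'b': 0.5}
print(clone.hyperparas)     # {'lr': 0.1, 'a': 0.9, 'b': 0.2}
```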
