diff --git a/experiments/mean_teacher.py b/experiments/mean_teacher.py
new file mode 100644
index 0000000..74d1e17
--- /dev/null
+++ b/experiments/mean_teacher.py
@@ -0,0 +1,113 @@
+"""Mean Teacher Experiment
+Requires `trainset`, `train_labeled_set` and `validset` variables of type `torch.utils.data.Dataset`.
+"""
+from .base import PyTorchExperiment
+from torch.optim.adam import Adam
+from models.resnet import ResNet
+
+from utils.mean_teacher import mt_dataloaders
+from models.mean_teacher import MeanTeacherModel
+from skopt.optimizer.optimizer import Optimizer
+
+from . import parser, parse_args, mark_best
+import os
+from ast import literal_eval as le
+
+NO_LABEL = -1
+FOLDER = "mean-teacher"
+os.makedirs(FOLDER, exist_ok=True)
+
+if __name__ == "__main__":
+
+    parser.add_argument("--clas", type=le, default=0.68,
+                        help="Classification correction")
+    parser.add_argument("--cons", type=le, default=141,
+                        help="Consistency correction")
+    parser.add_argument("--eps", type=le, default=9,
+                        help="Epsilon for the VAT loss")
+    parser.add_argument("--lbs", type=le, default=1,
+                        help="Labeled batch size")
+    parser.add_argument("--wd", type=le, default=0.00864,
+                        help="Weight decay for Adam")
+
+    # parse_args returns the shared arguments as a tuple and the
+    # experiment-specific ones in `args`.
+    [
+        (
+            datapath,
+            max_epoch,
+            lr,
+            max_patience,
+            batch_size,
+            trials
+        ),
+        args
+    ] = parse_args(parser)
+
+    # Experiment-specific values, sorted alphabetically by flag name.
+    [
+        classification_correction,
+        consistency_correction,
+        vat_epsilon,
+        labeled_batch_size,
+        weight_decay
+    ] = [value for _, value in sorted(args.items())]
+
+    # Hyperparameter search space handed to the scikit-optimize optimizer.
+    opt = Optimizer((
+        max_epoch,
+        lr,
+        max_patience,
+        batch_size,
+        classification_correction,
+        consistency_correction,
+        vat_epsilon,
+        labeled_batch_size,
+        weight_decay
+    ))
+
+    results = []
+
+    try:
+        for _ in range(trials):
+            # Get the next set of hyperparameters to evaluate
+            params = opt.ask(strategy='cl_max')
+            [
+                max_epoch,
+                lr,
+                max_patience,
+                batch_size,
+                classification_correction,
+                consistency_correction,
+                vat_epsilon,
+                labeled_batch_size,
+                weight_decay
+            ] = params
+
+            # Set up the model
+            model = MeanTeacherModel(
+                ResNet(18, 17), classification_correction,
+                consistency_correction, vat_epsilon)
+
+            # Set up the datasets and data loaders
+            dataloaders = mt_dataloaders(trainset, train_labeled_set, validset,
+                                         int(batch_size), int(labeled_batch_size))
+
+            optimizer = Adam(model._model.parameters(),
+                             lr=lr, weight_decay=weight_decay)
+
+            file_name = "-".join([
+                str(param if not isinstance(param, float) else round(param, 2))
+                for param in params
+            ])
+            path = os.path.join(FOLDER, file_name + ".tar")
+            exp = PyTorchExperiment(path, model, optimizer)
+
+            exp.run(dataloaders, max_epoch, patience=max_patience)
+            exp.load("best")
+            score = exp.score
+
+            opt.tell(params, score)
+            results.append((params, score))
+
+    finally:
+        mark_best(results, FOLDER)
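Note (not part of the diff): the experiment above drives scikit-optimize's `Optimizer` through an ask/tell loop. The sketch below is a minimal, self-contained version of that pattern, assuming each hyperparameter flag is parsed by `literal_eval` into a `(low, high)` tuple that skopt can use as a search dimension; the dimensions, trial count and stand-in score are illustrative only.

from skopt.optimizer.optimizer import Optimizer

# Illustrative search space: each (low, high) tuple becomes one dimension,
# e.g. the result of passing "--lr (1e-4, 1e-1)" through ast.literal_eval.
opt = Optimizer([
    (10, 100),       # max_epoch
    (1e-4, 1e-1),    # lr
    (2, 10),         # max_patience
    (32, 256),       # batch_size
])

for _ in range(5):
    params = opt.ask()               # propose the next hyperparameter point
    score = sum(params)              # stand-in for the validation score to minimize
    result = opt.tell(params, score)

print(result.x, result.fun)          # best point seen so far and its score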
diff --git a/utils/mean_teacher.py b/utils/mean_teacher.py
new file mode 100644
index 0000000..9e09060
--- /dev/null
+++ b/utils/mean_teacher.py
@@ -0,0 +1,111 @@
+"""
+Utility functions/classes for the mean teacher model.
+"""
+import itertools
+import numpy as np
+from torch.utils.data.sampler import Sampler
+from torch.utils.data import DataLoader
+
+NO_LABEL = -1
+
+
+def mt_dataloaders(unlabeled_dataset, train_dataset,
+                   valid_dataset, batch_size=100, labeled_batch_size=10):
+    """
+    Build the training and validation data loaders for the mean teacher model.
+
+    Args:
+        unlabeled_dataset: the unlabeled dataset
+        train_dataset: the labeled training dataset
+        valid_dataset: the validation dataset
+        batch_size: total batch size (labeled + unlabeled examples per batch)
+        labeled_batch_size: number of labeled examples in each training batch
+
+    Returns:
+        dataloaders: a dictionary of data loaders for the training and validation datasets
+    """
+    # Lace the labeled training set with the unlabeled examples so that
+    # every batch mixes labeled and unlabeled data points.
+    concat_train_dataset = train_dataset + unlabeled_dataset
+
+    n_train_labeled = len(train_dataset)
+    labeled_idx = list(range(n_train_labeled))
+    unlabeled_idx = list(range(n_train_labeled, len(concat_train_dataset)))
+
+    batch_sampler = TwoStreamBatchSampler(
+        unlabeled_idx, labeled_idx, batch_size, labeled_batch_size)
+
+    train_loader = DataLoader(concat_train_dataset,
+                              batch_sampler=batch_sampler,
+                              num_workers=0,
+                              pin_memory=True)
+
+    valid_loader = DataLoader(valid_dataset,
+                              batch_size=batch_size,
+                              shuffle=False,
+                              num_workers=0,
+                              pin_memory=True,
+                              drop_last=False)
+
+    dataloaders = {
+        "training": train_loader,
+        "validation": valid_loader
+    }
+    return dataloaders
+
+
+class TwoStreamBatchSampler(Sampler):
+    """Iterate over two sets of indices.
+
+    An 'epoch' is one iteration through the primary indices.
+    During the epoch, the secondary indices are iterated through
+    as many times as needed.
+    """
+
+    def __init__(self, primary_indices, secondary_indices,
+                 batch_size, secondary_batch_size):
+        """
+        Args:
+            primary_indices: the indices of the unlabeled data points
+            secondary_indices: the indices of the labeled data points
+            batch_size: total batch size; the unlabeled part of each batch is
+                batch_size - secondary_batch_size
+            secondary_batch_size: number of labeled data points in each batch
+        """
+        self.primary_indices = primary_indices
+        self.secondary_indices = secondary_indices
+        self.secondary_batch_size = secondary_batch_size
+        self.primary_batch_size = batch_size - secondary_batch_size
+
+        assert len(self.primary_indices) >= self.primary_batch_size > 0
+        assert len(self.secondary_indices) >= self.secondary_batch_size > 0
+
+    def __iter__(self):
+        # Primary (unlabeled) indices are shuffled once per epoch; secondary
+        # (labeled) indices are reshuffled and recycled as often as needed.
+        primary_iter = self.iterate_once(self.primary_indices)
+        secondary_iter = self.iterate_eternally(self.secondary_indices)
+        return (
+            primary_batch + secondary_batch
+            for (primary_batch, secondary_batch)
+            in zip(self.grouper(primary_iter, self.primary_batch_size),
+                   self.grouper(secondary_iter, self.secondary_batch_size))
+        )
+
+    def __len__(self):
+        return len(self.primary_indices) // self.primary_batch_size
+
+    def iterate_once(self, iterable):
+        return np.random.permutation(iterable)
+
+    def iterate_eternally(self, indices):
+        def infinite_shuffles():
+            while True:
+                yield np.random.permutation(indices)
+        return itertools.chain.from_iterable(infinite_shuffles())
+
+    def grouper(self, iterable, n):
+        "Collect data into fixed-length chunks or blocks."
+        # grouper('ABCDEFG', 3) --> ABC DEF (leftover elements are dropped)
+        args = [iter(iterable)] * n
+        return zip(*args)
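Note (not part of the diff): a minimal usage sketch for `mt_dataloaders`, assuming small random `TensorDataset` objects stand in for the unlabeled, labeled and validation sets; the tensor shapes, the class count of 17 and the batch sizes are illustrative only.

import torch
from torch.utils.data import TensorDataset

from utils.mean_teacher import mt_dataloaders

# 400 unlabeled examples (labels filled with NO_LABEL = -1),
# 40 labeled examples and a small validation set.
unlabeled = TensorDataset(torch.randn(400, 3, 32, 32),
                          torch.full((400,), -1, dtype=torch.long))
labeled = TensorDataset(torch.randn(40, 3, 32, 32),
                        torch.randint(0, 17, (40,)))
valid = TensorDataset(torch.randn(20, 3, 32, 32),
                      torch.randint(0, 17, (20,)))

dataloaders = mt_dataloaders(unlabeled, labeled, valid,
                             batch_size=100, labeled_batch_size=10)

# Each training batch mixes 90 unlabeled and 10 labeled examples.
images, labels = next(iter(dataloaders["training"]))
print(images.shape, int((labels == -1).sum()))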