Add experiments.mean_teacher & utils.mean_teacher
ctrl-q committed Aug 29, 2019
1 parent e9b33d2 commit 6646fdc
Showing 2 changed files with 215 additions and 0 deletions.
107 changes: 107 additions & 0 deletions experiments/mean_teacher.py
@@ -0,0 +1,107 @@
"""Mean Teach Experiment
Requires `trainset` , `train_labeled_set` and `validset` variables, of `type torch.utils.data.Dataset`
"""
from .base import PyTorchExperiment
from torch.optim.adam import Adam
from models.resnet import ResNet

from utils.mean_teacher import mt_dataloaders
from models.mean_teacher import MeanTeacherModel
from skopt.optimizer.optimizer import Optimizer

from . import parser, parse_args, mark_best
import os
from ast import literal_eval as le

NO_LABEL = -1
FOLDER = "mean-teacher"
os.makedirs(FOLDER, exist_ok=True)

if __name__ == "__main__":

parser.add_argument("--clas", type=le, default=0.68,
help="Classification correction")
parser.add_argument("--cons", type=le, default=141,
help="consistency correction")
parser.add_argument("--eps", type=le, default=9,
help="epsilon for VAT Loss")
parser.add_argument("--lbs", type=le, default=1, help="Labeled batch size")
parser.add_argument("--wd", type=le, default=0.00864,
help="Weight decay for Adam")

[
(
datapath,
max_epoch,
lr,
max_patience,
batch_size,
trials
),
args
] = parse_args(parser)

[
classification_correction,
consistency_correction,
vat_epsilon,
labeled_batch_size,
weight_decay
] = [value for _, value in sorted(args.items())]

opt = Optimizer((
max_epoch,
lr,
max_patience,
batch_size,
classification_correction,
consistency_correction,
vat_epsilon,
labeled_batch_size,
weight_decay
))

results = []

try:
for _ in range(trials):
# Get hyperparameters
params = opt.ask(strategy='cl_max')
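            # ask() proposes the next configuration to evaluate; its score is fed back via tell() below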
[
max_epoch,
lr,
max_patience,
batch_size,
classification_correction,
consistency_correction,
vat_epsilon,
labeled_batch_size,
weight_decay
] = params
# Setting up model
model = MeanTeacherModel(
ResNet(18, 17), classification_correction, consistency_correction, vat_epsilon)

# Setting up datasets and dataloaders
dataloaders = mt_dataloaders(trainset, train_labeled_set, validset, int(
batch_size), int(labeled_batch_size))

optimizer = Adam(model._model.parameters(),
lr=lr, weight_decay=weight_decay)

file_name = "-".join([
str(param if not isinstance(param, float) else round(param, 2))
for param in params
])
path = os.path.join(FOLDER, file_name + ".tar")
exp = PyTorchExperiment(path, model, optimizer)

exp.run(dataloaders, max_epoch, patience=max_patience)
exp.load("best")
score = exp.score

opt.tell(params, score)
results.append((params, score))

finally:
mark_best(results, FOLDER)
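
The training loop above follows scikit-optimize's ask/tell interface: ask() proposes a hyperparameter configuration, the experiment is trained and scored, and tell() reports the score back so the optimizer can refine its surrogate model. A minimal, self-contained sketch of that pattern (the two search dimensions and the toy quadratic objective are illustrative assumptions, not the experiment's):

from skopt import Optimizer
from skopt.space import Integer, Real

# Illustrative two-dimensional search space (learning rate, batch size)
opt = Optimizer([Real(1e-5, 1e-1, prior="log-uniform"), Integer(16, 256)])

for _ in range(10):
    params = opt.ask()            # propose the next point to evaluate
    lr, batch_size = params
    # Toy stand-in for the validation score produced by PyTorchExperiment
    score = (lr - 1e-3) ** 2 + ((batch_size - 64) / 100) ** 2
    opt.tell(params, score)       # report the observation back to the optimizer

print(opt.get_result().fun)       # best (lowest) score observed

In the experiment itself, params additionally carries the mean teacher hyperparameters, and the score passed to tell() is read from the best checkpoint reloaded by exp.load("best").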
108 changes: 108 additions & 0 deletions utils/mean_teacher.py
@@ -0,0 +1,108 @@
"""
Utility functions/classes for the mean teacher model
"""
import itertools
import numpy as np
from torch.utils.data.sampler import Sampler
from torch.utils.data import DataLoader

NO_LABEL = -1


def mt_dataloaders(unlabeled_dataset, train_dataset,
valid_dataset, batch_size=100, labeled_batch_size=10):
    # The training dataset needs to be laced (interleaved) with the unlabeled examples
    """
    Build the training and validation data loaders for the mean teacher model.
    Args:
        unlabeled_dataset: the unlabeled dataset to be loaded
        train_dataset: the labeled training dataset
        valid_dataset: the validation dataset
        batch_size: the total batch size (unlabeled + labeled examples)
        labeled_batch_size: the number of labeled examples per batch
    Returns:
        dataloaders: a dictionary with the training and validation data loaders
    """
concat_train_dataset = train_dataset + unlabeled_dataset

n_train_labeled = len(train_dataset)
labeled_idx = list(range(n_train_labeled))
unlabeled_idx = list(range(n_train_labeled, len(concat_train_dataset)))

batch_sampler = TwoStreamBatchSampler(
unlabeled_idx, labeled_idx, batch_size, labeled_batch_size)

train_loader = DataLoader(concat_train_dataset,
batch_sampler=batch_sampler,
num_workers=0,
pin_memory=True)

valid_loader = DataLoader(valid_dataset,
batch_size=batch_size,
shuffle=False,
num_workers=0,
pin_memory=True,
drop_last=False)

dataloaders = {
"training": train_loader,
"validation": valid_loader
}
return dataloaders


class TwoStreamBatchSampler(Sampler):
"""Iterate over two sets of indices
An 'epoch' is one iteration through the primary indices.
During the epoch, the secondary indices are iterated through
as many times as needed.
"""

def __init__(self, primary_indices, secondary_indices,
batch_size, secondary_batch_size):
"""
Args.:
primary_indices: the indices of the unlabeled data points,
secondary_indices: the indices of the labeled data points,
batch_size: batch size for an iteration over the unlabeled data points,
secondary_batch_size: batch size for an iteration over the labeled data points,
"""
self.primary_indices = primary_indices
self.secondary_indices = secondary_indices
self.secondary_batch_size = secondary_batch_size
self.primary_batch_size = batch_size - secondary_batch_size
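        # each yielded batch therefore contains `batch_size` indices in total (unlabeled + labeled)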

assert len(self.primary_indices) >= self.primary_batch_size > 0
assert len(self.secondary_indices) >= self.secondary_batch_size > 0

def __iter__(self):
primary_iter = self.iterate_once(self.primary_indices)
secondary_iter = self.iterate_eternally(self.secondary_indices)
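        # every batch concatenates primary_batch_size unlabeled indices with secondary_batch_size labeled ones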
return (
primary_batch + secondary_batch
for (primary_batch, secondary_batch)
in zip(self.grouper(primary_iter, self.primary_batch_size),
self.grouper(secondary_iter, self.secondary_batch_size))
)

def __len__(self):
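        # an epoch covers the unlabeled indices once; grouper() drops any incomplete
        # trailing batch, matching this floor division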
return len(self.primary_indices) // self.primary_batch_size

def iterate_once(self, iterable):
return np.random.permutation(iterable)

def iterate_eternally(self, indices):
def infinite_shuffles():
while True:
yield np.random.permutation(indices)
return itertools.chain.from_iterable(infinite_shuffles())

def grouper(self, iterable, n):
"Collect data into fixed-length chunks or blocks"
# grouper('ABCDEFG', 3) --> ABC DEF"
args = [iter(iterable)] * n
return zip(*args)
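
A minimal usage sketch of mt_dataloaders with tiny random TensorDatasets, run from the repository root so that utils is importable. The tensor shapes, dataset sizes and batch sizes are illustrative assumptions; the point is that every training batch mixes batch_size - labeled_batch_size unlabeled examples with labeled_batch_size labeled ones:

import torch
from torch.utils.data import TensorDataset

from utils.mean_teacher import NO_LABEL, mt_dataloaders

# Tiny stand-in datasets; the real experiment uses the project's image datasets
labeled = TensorDataset(torch.randn(20, 3, 32, 32), torch.randint(0, 10, (20,)))
unlabeled = TensorDataset(torch.randn(80, 3, 32, 32),
                          torch.full((80,), NO_LABEL, dtype=torch.long))
valid = TensorDataset(torch.randn(10, 3, 32, 32), torch.randint(0, 10, (10,)))

loaders = mt_dataloaders(unlabeled, labeled, valid,
                         batch_size=10, labeled_batch_size=2)

images, targets = next(iter(loaders["training"]))
print(images.shape)                        # torch.Size([10, 3, 32, 32])
print((targets == NO_LABEL).sum().item())  # 8 unlabeled examples in the batch

The validation loader, by contrast, is a plain sequential DataLoader over valid_dataset with shuffle=False.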
