Add experiments.mean_teacher & utils.mean_teacher
Showing 2 changed files with 215 additions and 0 deletions.
experiments/mean_teacher.py
@@ -0,0 +1,107 @@
"""Mean Teach Experiment | ||
Requires `trainset` , `train_labeled_set` and `validset` variables, of `type torch.utils.data.Dataset` | ||
""" | ||
from .base import PyTorchExperiment
from torch.optim.adam import Adam
from models.resnet import ResNet

from utils.mean_teacher import mt_dataloaders
from models.mean_teacher import MeanTeacherModel
from skopt.optimizer.optimizer import Optimizer

from . import parser, parse_args, mark_best
import os
from ast import literal_eval as le

NO_LABEL = -1
FOLDER = "mean-teacher"
os.makedirs(FOLDER, exist_ok=True)

if __name__ == "__main__":

    parser.add_argument("--clas", type=le, default=0.68,
                        help="Classification correction")
    parser.add_argument("--cons", type=le, default=141,
                        help="Consistency correction")
    parser.add_argument("--eps", type=le, default=9,
                        help="Epsilon for the VAT loss")
    parser.add_argument("--lbs", type=le, default=1, help="Labeled batch size")
    parser.add_argument("--wd", type=le, default=0.00864,
                        help="Weight decay for Adam")

    # Common settings from the shared parser, plus the extra arguments
    [(datapath, max_epoch, lr, max_patience, batch_size, trials),
     args] = parse_args(parser)

    # Extra arguments unpacked in sorted key order: clas, cons, eps, lbs, wd
    [classification_correction,
     consistency_correction,
     vat_epsilon,
     labeled_batch_size,
     weight_decay] = [value for _, value in sorted(args.items())]

    # skopt optimizer over the hyperparameter search space
    opt = Optimizer((
        max_epoch,
        lr,
        max_patience,
        batch_size,
        classification_correction,
        consistency_correction,
        vat_epsilon,
        labeled_batch_size,
        weight_decay
    ))

    results = []

    try:
        for _ in range(trials):
            # Get hyperparameters for this trial
            params = opt.ask(strategy='cl_max')
            [max_epoch,
             lr,
             max_patience,
             batch_size,
             classification_correction,
             consistency_correction,
             vat_epsilon,
             labeled_batch_size,
             weight_decay] = params

            # Set up the model
            model = MeanTeacherModel(
                ResNet(18, 17), classification_correction,
                consistency_correction, vat_epsilon)

            # Set up datasets and dataloaders
            dataloaders = mt_dataloaders(
                trainset, train_labeled_set, validset,
                int(batch_size), int(labeled_batch_size))

            optimizer = Adam(model._model.parameters(),
                             lr=lr, weight_decay=weight_decay)

            # Checkpoint path encodes this trial's hyperparameters
            file_name = "-".join([
                str(param if not isinstance(param, float) else round(param, 2))
                for param in params
            ])
            path = os.path.join(FOLDER, file_name + ".tar")
            exp = PyTorchExperiment(path, model, optimizer)

            exp.run(dataloaders, max_epoch, patience=max_patience)
            exp.load("best")
            score = exp.score

            # Report the result back to the optimizer
            opt.tell(params, score)
            results.append((params, score))

    finally:
        # Mark the best trial even if the search is interrupted
        mark_best(results, FOLDER)
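The script assumes `trainset`, `train_labeled_set` and `validset` already exist as module-level variables of type `torch.utils.data.Dataset`. A minimal launcher sketch, not part of this commit; the `runpy` launch mechanism, the toy tensor shapes, and the 17-class count matching `ResNet(18, 17)` are all assumptions:

```python
# Hypothetical launcher: pre-populate the three required dataset variables and
# then execute experiments.mean_teacher as __main__ (its CLI flags still apply).
import runpy
import torch
from torch.utils.data import TensorDataset

NO_LABEL = -1  # unlabeled targets, matching the script's constant

unlabeled = TensorDataset(torch.randn(900, 3, 32, 32),
                          torch.full((900,), NO_LABEL, dtype=torch.long))
labeled = TensorDataset(torch.randn(100, 3, 32, 32), torch.randint(0, 17, (100,)))
valid = TensorDataset(torch.randn(200, 3, 32, 32), torch.randint(0, 17, (200,)))

runpy.run_module(
    "experiments.mean_teacher",
    init_globals={
        "trainset": unlabeled,         # unlabeled pool
        "train_labeled_set": labeled,  # labeled subset
        "validset": valid,
    },
    run_name="__main__",
)
```

Defining the three variables near the top of the script itself would work just as well; the docstring only requires that they be in scope when the `__main__` block runs.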
utils/mean_teacher.py
@@ -0,0 +1,108 @@
""" | ||
Utility functions/classes for the mean teacher model | ||
""" | ||
import itertools | ||
import numpy as np | ||
from torch.utils.data.sampler import Sampler | ||
from torch.utils.data import DataLoader | ||
|
||
NO_LABEL = -1 | ||
|
||
|
||
def mt_dataloaders(unlabeled_dataset, train_dataset, | ||
valid_dataset, batch_size=100, labeled_batch_size=10): | ||
# Need to lace training dataset with unlabeled examples | ||
""" | ||
Data Loader for the mean teacher | ||
Args: | ||
unlabeled_dataset: The unlabeled dataset to be loaded | ||
train_dataset: the train dataset | ||
valid_dataset: the valid dataset | ||
batch_size: the batch size to use for the unlabeled dataset | ||
labeled_batch_size: the batch size to use for the labeled dataset | ||
Returns: | ||
dataloaders: a dictionary of data loaders for the train and validation dataset | ||
""" | ||
    # Labeled examples occupy the first indices of the concatenated dataset
    concat_train_dataset = train_dataset + unlabeled_dataset

    n_train_labeled = len(train_dataset)
    labeled_idx = list(range(n_train_labeled))
    unlabeled_idx = list(range(n_train_labeled, len(concat_train_dataset)))

    batch_sampler = TwoStreamBatchSampler(
        unlabeled_idx, labeled_idx, batch_size, labeled_batch_size)

    train_loader = DataLoader(concat_train_dataset,
                              batch_sampler=batch_sampler,
                              num_workers=0,
                              pin_memory=True)

    valid_loader = DataLoader(valid_dataset,
                              batch_size=batch_size,
                              shuffle=False,
                              num_workers=0,
                              pin_memory=True,
                              drop_last=False)

    dataloaders = {
        "training": train_loader,
        "validation": valid_loader
    }
    return dataloaders


class TwoStreamBatchSampler(Sampler):
    """Iterate over two sets of indices.

    An 'epoch' is one iteration through the primary indices.
    During the epoch, the secondary indices are iterated through
    as many times as needed.
    """

    def __init__(self, primary_indices, secondary_indices,
                 batch_size, secondary_batch_size):
        """
        Args:
            primary_indices: indices of the unlabeled data points
            secondary_indices: indices of the labeled data points
            batch_size: total batch size (primary + secondary examples)
            secondary_batch_size: number of labeled examples per batch
        """
        self.primary_indices = primary_indices
        self.secondary_indices = secondary_indices
        self.secondary_batch_size = secondary_batch_size
        self.primary_batch_size = batch_size - secondary_batch_size

        assert len(self.primary_indices) >= self.primary_batch_size > 0
        assert len(self.secondary_indices) >= self.secondary_batch_size > 0

    def __iter__(self):
        primary_iter = self.iterate_once(self.primary_indices)
        secondary_iter = self.iterate_eternally(self.secondary_indices)
        return (
            primary_batch + secondary_batch
            for (primary_batch, secondary_batch)
            in zip(self.grouper(primary_iter, self.primary_batch_size),
                   self.grouper(secondary_iter, self.secondary_batch_size))
        )

    def __len__(self):
        return len(self.primary_indices) // self.primary_batch_size

    def iterate_once(self, iterable):
        # One shuffled pass over the primary (unlabeled) indices
        return np.random.permutation(iterable)

    def iterate_eternally(self, indices):
        # Endless stream of reshuffled secondary (labeled) indices
        def infinite_shuffles():
            while True:
                yield np.random.permutation(indices)
        return itertools.chain.from_iterable(infinite_shuffles())

    def grouper(self, iterable, n):
        "Collect data into fixed-length chunks or blocks"
        # grouper('ABCDEFG', 3) --> ABC DEF
        args = [iter(iterable)] * n
        return zip(*args)
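A small usage sketch for `mt_dataloaders`, not part of this commit; the toy `TensorDataset`s, tensor shapes, and class count are assumptions. With `batch_size=10` and `labeled_batch_size=2`, each training batch holds 8 unlabeled examples followed by 2 labeled ones.

```python
# Hypothetical usage of mt_dataloaders / TwoStreamBatchSampler with toy data.
import torch
from torch.utils.data import TensorDataset
from utils.mean_teacher import mt_dataloaders, NO_LABEL

unlabeled = TensorDataset(torch.randn(90, 3, 32, 32),
                          torch.full((90,), NO_LABEL, dtype=torch.long))
labeled = TensorDataset(torch.randn(20, 3, 32, 32), torch.randint(0, 17, (20,)))
valid = TensorDataset(torch.randn(30, 3, 32, 32), torch.randint(0, 17, (30,)))

loaders = mt_dataloaders(unlabeled, labeled, valid,
                         batch_size=10, labeled_batch_size=2)

inputs, targets = next(iter(loaders["training"]))
print(inputs.shape)                         # torch.Size([10, 3, 32, 32])
print((targets == NO_LABEL).sum().item())   # 8 unlabeled examples per batch
print(len(loaders["training"]))             # 90 // (10 - 2) = 11 batches per epoch
```

Since `__iter__` shuffles the unlabeled indices once per epoch while cycling the labeled indices indefinitely, an epoch is `len(unlabeled_idx) // (batch_size - labeled_batch_size)` batches and labeled examples recur several times within it.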