4 changes: 2 additions & 2 deletions configs/simclr_cifar.yaml
@@ -16,9 +16,9 @@ train:
warmup_lr: 0
base_lr: 0.3
final_lr: 0
num_epochs: 800 # this parameter influences the lr decay
num_epochs: 200 # this parameter influences the lr decay
stop_at_epoch: 100 # has to be smaller than num_epochs
batch_size: 256
batch_size: 512
knn_monitor: False # knn monitor will take more time
knn_interval: 1
knn_k: 200
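Note on the num_epochs comment above: the cosine decay is computed over num_epochs even when training stops earlier at stop_at_epoch. A minimal sketch of that relationship, assuming the repo's LR_Scheduler uses standard cosine annealing (cosine_lr is an illustrative helper, not the repo's API):

import math

def cosine_lr(step, total_steps, base_lr, final_lr=0.0):
    # cosine annealing from base_lr down to final_lr over total_steps
    return final_lr + 0.5 * (base_lr - final_lr) * (1 + math.cos(math.pi * step / total_steps))

# With num_epochs=200 the decay spans 200 epochs, so stopping at
# stop_at_epoch=100 leaves the lr mid-curve rather than at final_lr:
print(cosine_lr(100, 200, base_lr=0.3))  # 0.15, not 0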
6 changes: 1 addition & 5 deletions configs/simsiam_cifar.yaml
@@ -43,8 +43,4 @@ logger:
seed: null # None type for yaml file
# two things other than the seed might lead to stochastic behavior:
# worker_init_fn from the dataloader and torch.nn.functional.interpolate
# (keep this in mind if you want fully deterministic runs)
# (keep this in mind if you want fully deterministic runs)
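On the determinism caveat above: a common recipe, sketched here under the usual PyTorch assumptions (seed_worker is an illustrative name, not part of this repo), is to seed each DataLoader worker and pin cuDNN; nondeterministic kernels such as some interpolate backward passes can be surfaced by torch.use_deterministic_algorithms(True):

import random
import numpy as np
import torch

def seed_worker(worker_id):
    # derive a per-worker seed from the main process seed so each
    # DataLoader worker gets reproducible numpy/random state
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

torch.manual_seed(0)
torch.backends.cudnn.deterministic = True  # trades speed for reproducibility
torch.backends.cudnn.benchmark = False

dataset = torch.utils.data.TensorDataset(torch.arange(10.0))  # stand-in dataset
loader = torch.utils.data.DataLoader(dataset, batch_size=2, num_workers=2,
                                     worker_init_fn=seed_worker)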
4 changes: 0 additions & 4 deletions configs/simsiam_cifar_eval_sgd.yaml
@@ -31,7 +31,3 @@ seed: null # None type for yaml file
# two things other than the seed might lead to stochastic behavior:
# worker_init_fn from the dataloader and torch.nn.functional.interpolate
# (keep this in mind if you want fully deterministic runs)
37 changes: 37 additions & 0 deletions configs/simsiam_image100_eval.yaml
@@ -0,0 +1,37 @@
name: simsiam-imagenet100-experiment-resnet50
dataset:
name: imagenet100
image_size: 224
num_workers: 4

model:
name: simsiam
backbone: resnet50
proj_layers: 2

train: null

eval: # linear evaluation; set to False to turn off automatic evaluation after training
optimizer:
name: sgd
weight_decay: 0
momentum: 0.9
warmup_lr: 0
warmup_epochs: 0
base_lr: 10 #30
final_lr: 0
batch_size: 128
num_epochs: 60

logger:
tensorboard: False
matplotlib: False

seed: null # None type for yaml file
# two things other than the seed might lead to stochastic behavior:
# worker_init_fn from the dataloader and torch.nn.functional.interpolate
# (keep this in mind if you want fully deterministic runs)
46 changes: 46 additions & 0 deletions configs/simsiam_imagenet.yaml
@@ -0,0 +1,46 @@
name: simsiam-imagenet-experiment-resnet50
dataset:
name: imagenet
image_size: 224
num_workers: 16

model:
name: simsiam
backbone: resnet50
proj_layers: 2

train:
optimizer:
name: sgd
weight_decay: 0.0002
momentum: 0.9
warmup_epochs: 0
warmup_lr: 0
base_lr: 0.05
final_lr: 0
num_epochs: 200 # this parameter influences the lr decay
stop_at_epoch: 200 # must not exceed num_epochs
batch_size: 512
knn_monitor: False # knn monitor will take more time
knn_interval: 50
knn_k: 200
eval: # linear evaluation; set to False to turn off automatic evaluation after training
optimizer:
name: sgd
weight_decay: 0
momentum: 0.9
warmup_lr: 0
warmup_epochs: 0
base_lr: 30
final_lr: 0
batch_size: 256
num_epochs: 100

logger:
tensorboard: True
matplotlib: True

seed: null # None type for yaml file
# two things other than the seed might lead to stochastic behavior:
# worker_init_fn from the dataloader and torch.nn.functional.interpolate
# (keep this in mind if you want fully deterministic runs)
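For reference, linear_eval.py (later in this diff) applies the linear scaling rule lr = base_lr * batch_size / 256, so the values in this config translate as follows (a toy illustration, not repo code):

def scaled_lr(base_lr, batch_size):
    # the linear scaling rule applied in linear_eval.py
    return base_lr * batch_size / 256

print(scaled_lr(0.05, 512))  # pretraining: effective lr 0.1
print(scaled_lr(30, 256))    # linear eval: effective lr 30.0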
50 changes: 50 additions & 0 deletions configs/simsiam_imagenet100.yaml
@@ -0,0 +1,50 @@
name: simsiam-imagenet100-experiment-resnet50
dataset:
name: imagenet100
image_size: 224
num_workers: 8

model:
name: simsiam
backbone: resnet50
proj_layers: 2

train:
optimizer:
name: sgd
weight_decay: 0.0001
momentum: 0.9
warmup_epochs: 10
warmup_lr: 0
base_lr: 0.05
final_lr: 0
num_epochs: 200 # this parameter influences the lr decay
stop_at_epoch: 200 # must not exceed num_epochs
batch_size: 512
knn_monitor: True # knn monitor will take more time
knn_interval: 40
knn_k: 200
eval: # linear evaluation; set to False to turn off automatic evaluation after training
optimizer:
name: sgd
weight_decay: 0
momentum: 0.9
warmup_lr: 0
warmup_epochs: 0
base_lr: 30
final_lr: 0
batch_size: 256
num_epochs: 100

logger:
tensorboard: True
matplotlib: True

seed: null # None type for yaml file
# two things other than the seed might lead to stochastic behavior:
# worker_init_fn from the dataloader and torch.nn.functional.interpolate
# (keep this in mind if you want fully deterministic runs)
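Since this config enables knn_monitor, here is a rough sketch of how such a monitor is typically implemented (InstDisc-style weighted kNN over a feature bank); knn_predict and its signature are illustrative, not the repo's exact code:

import torch
import torch.nn.functional as F

def knn_predict(features, feature_bank, bank_labels, num_classes, knn_k=200, knn_t=0.1):
    # cosine-similarity kNN with soft (temperature-weighted) voting,
    # a cheap online probe of self-supervised feature quality
    features = F.normalize(features, dim=1)
    sim = features @ F.normalize(feature_bank, dim=1).T         # [B, N]
    weights, idx = sim.topk(knn_k, dim=1)                       # k nearest neighbors
    weights = (weights / knn_t).exp()                           # soften votes
    one_hot = F.one_hot(bank_labels[idx], num_classes).float()  # [B, k, C]
    scores = (one_hot * weights.unsqueeze(-1)).sum(dim=1)       # weighted class votes
    return scores.argmax(dim=1)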
37 changes: 37 additions & 0 deletions configs/simsiam_imagenet_eval.yaml
@@ -0,0 +1,37 @@
name: simsiam-imagenet-experiment-resnet50
dataset:
name: imagenet
image_size: 224
num_workers: 8

model:
name: simsiam
backbone: resnet50
proj_layers: 2

train: null

eval: # linear evaluation; set to False to turn off automatic evaluation after training
optimizer:
name: sgd
weight_decay: 0
momentum: 0.9
warmup_lr: 0
warmup_epochs: 0
base_lr: 30
final_lr: 0
batch_size: 256
num_epochs: 100

logger:
tensorboard: False
matplotlib: False

seed: null # None type for yaml file
# two things other than the seed might lead to stochastic behavior:
# worker_init_fn from the dataloader and torch.nn.functional.interpolate
# (keep this in mind if you want fully deterministic runs)
14 changes: 10 additions & 4 deletions datasets/__init__.py
@@ -3,17 +3,23 @@
from .random_dataset import RandomDataset


def get_dataset(dataset, data_dir, transform, train=True, download=False, debug_subset_size=None):
def get_dataset(dataset, data_dir, transform, train=True, download=True, debug_subset_size=None):
if dataset == 'mnist':
dataset = torchvision.datasets.MNIST(data_dir, train=train, transform=transform, download=download)
elif dataset == 'stl10':
dataset = torchvision.datasets.STL10(data_dir, split='train+unlabeled' if train else 'test', transform=transform, download=download)
elif dataset == 'cifar10':
dataset = torchvision.datasets.CIFAR10(data_dir, train=train, transform=transform, download=download)
dataset = torchvision.datasets.CIFAR10(data_dir, train=train, transform=transform, download=True)
elif dataset == 'cifar100':
dataset = torchvision.datasets.CIFAR100(data_dir, train=train, transform=transform, download=download)
elif dataset == 'imagenet':
dataset = torchvision.datasets.ImageNet(data_dir, split='train' if train == True else 'val', transform=transform, download=download)
elif dataset in ('imagenet', 'imagenet100'):
    # assumes data_dir ends with a path separator and contains 'train' and 'val' subfolders
    dataset = torchvision.datasets.ImageFolder(data_dir + ('train' if train else 'val'), transform=transform)
elif dataset == 'random':
dataset = RandomDataset()
else:
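A hypothetical call illustrating the directory layout the new ImageFolder branches assume (data_dir ends with a path separator and contains train/ and val/ class-folder trees; the paths below are placeholders):

from torchvision import transforms
from datasets import get_dataset

transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
])
train_set = get_dataset('imagenet100', '/data/imagenet100/', transform, train=True)
val_set = get_dataset('imagenet100', '/data/imagenet100/', transform, train=False)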
111 changes: 49 additions & 62 deletions linear_eval.py
@@ -11,114 +11,101 @@
from datasets import get_dataset
from optimizers import get_optimizer, LR_Scheduler

def main(args):

train_loader = torch.utils.data.DataLoader(
dataset=get_dataset(
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

def main(gpu, args):
rank = args.nr * args.gpus + gpu
dist.init_process_group("nccl", rank=rank, world_size=args.world_size)
torch.manual_seed(0)
torch.cuda.set_device(gpu)
train_dataset = get_dataset(
transform=get_aug(train=False, train_classifier=True, **args.aug_kwargs),
train=True,
**args.dataset_kwargs
),
batch_size=args.eval.batch_size,
shuffle=True,
)
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas=args.world_size, rank=rank)
train_loader = torch.utils.data.DataLoader(
dataset=train_dataset,
batch_size=(args.eval.batch_size//args.gpus),
shuffle=False,
sampler=train_sampler,
**args.dataloader_kwargs
)
test_loader = torch.utils.data.DataLoader(
dataset=get_dataset(
test_dataset = get_dataset(
transform=get_aug(train=False, train_classifier=False, **args.aug_kwargs),
train=False,
**args.dataset_kwargs
),
batch_size=args.eval.batch_size,
)
test_loader = torch.utils.data.DataLoader(
dataset=test_dataset,
batch_size=(args.eval.batch_size//args.gpus),
shuffle=False,
**args.dataloader_kwargs
)


model = get_backbone(args.model.backbone)
classifier = nn.Linear(in_features=model.output_dim, out_features=10, bias=True).to(args.device)

model = get_backbone(args.model.backbone)
classifier = nn.Linear(in_features=model.output_dim, out_features=100, bias=True).to(args.device)
assert args.eval_from is not None
save_dict = torch.load(args.eval_from, map_location='cpu')
msg = model.load_state_dict({k[9:]:v for k, v in save_dict['state_dict'].items() if k.startswith('backbone.')}, strict=True)

# print(msg)
model = model.to(args.device)
model = torch.nn.DataParallel(model)

# if torch.cuda.device_count() > 1: classifier = torch.nn.SyncBatchNorm.convert_sync_batchnorm(classifier)
classifier = torch.nn.DataParallel(classifier)
model = model.to(args.device)
model = DDP(model, device_ids=[gpu], find_unused_parameters=True)
classifier = torch.nn.SyncBatchNorm.convert_sync_batchnorm(classifier)
classifier = DDP(classifier, device_ids=[gpu], find_unused_parameters=True)
# define optimizer
optimizer = get_optimizer(
args.eval.optimizer.name, classifier,
lr=args.eval.base_lr*args.eval.batch_size/256,
momentum=args.eval.optimizer.momentum,
weight_decay=args.eval.optimizer.weight_decay)

# define lr scheduler
lr_scheduler = LR_Scheduler(
optimizer,
args.eval.warmup_epochs, args.eval.warmup_lr*args.eval.batch_size/256,
args.eval.num_epochs, args.eval.base_lr*args.eval.batch_size/256, args.eval.final_lr*args.eval.batch_size/256,
len(train_loader),
)

loss_meter = AverageMeter(name='Loss')
acc_meter = AverageMeter(name='Accuracy')

# Start training
global_progress = tqdm(range(0, args.eval.num_epochs), desc=f'Evaluating')
for epoch in global_progress:
loss_meter.reset()
model.eval()
classifier.train()
local_progress = tqdm(train_loader, desc=f'Epoch {epoch}/{args.eval.num_epochs}', disable=True)

for idx, (images, labels) in enumerate(local_progress):

classifier.zero_grad()
with torch.no_grad():
feature = model(images.to(args.device))

preds = classifier(feature)

loss = F.cross_entropy(preds, labels.to(args.device))

loss.backward()
optimizer.step()
loss_meter.update(loss.item())
lr = lr_scheduler.step()
local_progress.set_postfix({'lr':lr, "loss":loss_meter.val, 'loss_avg':loss_meter.avg})

classifier.eval()
correct, total = 0, 0
acc_meter.reset()
for idx, (images, labels) in enumerate(test_loader):
with torch.no_grad():
feature = model(images.to(args.device))
preds = classifier(feature).argmax(dim=1)
correct = (preds == labels.to(args.device)).sum().item()
acc_meter.update(correct/preds.shape[0])
print(f'Accuracy = {acc_meter.avg*100:.2f}')




if gpu == 0 and (epoch + 1) == args.eval.num_epochs:  # evaluate once, at the final epoch
    print('epoch:', epoch + 1)
    classifier.eval()
    acc_meter.reset()
    for idx, (images, labels) in enumerate(test_loader):
        with torch.no_grad():
            feature = model(images.to(args.device))
            preds = classifier(feature).argmax(dim=1)
            correct = (preds == labels.to(args.device)).sum().item()
            acc_meter.update(correct / preds.shape[0])
    print(f'Accuracy = {acc_meter.avg*100:.2f}')
    break

dist.destroy_process_group()
if __name__ == "__main__":
main(args=get_args())
args = get_args()
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "3367"
args.world_size = args.gpus * args.nodes
mp.spawn(main, args=(args,), nprocs=args.gpus, join=True)
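Not shown in this diff: get_args() must now expose the distributed flags consumed above (args.nodes, args.gpus, args.nr). A minimal argparse sketch with names inferred from their usage:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--nodes', type=int, default=1, help='number of machines')
parser.add_argument('--gpus', type=int, default=1, help='GPUs per machine')
parser.add_argument('--nr', type=int, default=0, help='rank of this machine among the nodes')
args = parser.parse_args()
args.world_size = args.gpus * args.nodes  # matches the computation in __main__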