From 8c9b69645577659e78efaeecff7ab21bc78d3536 Mon Sep 17 00:00:00 2001
From: zsdonghao
Date: Tue, 8 Dec 2020 13:42:07 +0800
Subject: [PATCH] refactored code, made utils module, bugfix cyclegan

---
 cyclegan/models.py   | 24 ++++----------------
 cyclegan/train.py    | 40 ++++++++++++++++++++++++++++------
 moco/main.py         | 52 +++++++++++++++++++++++---------------------
 networks/layers.py   | 40 ++++++++++++++++++++++++++++++++++
 networks/utils.py    | 14 ++++++++++++
 simsiam/data.py      | 13 -----------
 simsiam/main.py      | 33 ++++++++--------------------
 simsiam/model.py     | 39 ---------------------------------
 utils/contrastive.py | 20 +++++++++++++++++
 9 files changed, 147 insertions(+), 128 deletions(-)
 create mode 100644 utils/contrastive.py

diff --git a/cyclegan/models.py b/cyclegan/models.py
index 5f2f29e..5be33ee 100644
--- a/cyclegan/models.py
+++ b/cyclegan/models.py
@@ -33,19 +33,6 @@ def forward(self, x):
         return self.resblock(x)
 
 
-def initialize_weights(model, nonlinearity='leaky_relu'):
-    for m in model.modules():
-        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
-            nn.init.kaiming_normal_(
-                m.weight,
-                mode='fan_out',
-                nonlinearity=nonlinearity
-            )
-        elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
-            nn.init.normal_(m.weight, 0.0, 0.02)
-            nn.init.constant_(m.bias, 0)
-
-
 def set_requires_grad(nets, requires_grad=False):
     if not isinstance(nets, list):
         nets = [nets]
@@ -101,8 +88,7 @@ def __init__(
                 hidden_dim * 2,
                 kernel_size=3,
                 padding=1,
-                stride=2,
-                padding_mode='reflect'
+                stride=2
             ),
             nn.InstanceNorm2d(hidden_dim * 2),
             nn.ReLU(),
@@ -111,8 +97,7 @@
                 hidden_dim * 4,
                 kernel_size=3,
                 padding=1,
-                stride=2,
-                padding_mode='reflect'
+                stride=2
             ),
             nn.InstanceNorm2d(hidden_dim * 4),
             nn.ReLU(),
@@ -194,9 +179,8 @@ def __init__(self, input_channels, hidden_dim):
             nn.Conv2d(
                 hidden_dim * 8,
                 out_channels=1,
-                kernel_size=1,
-                padding=1,
-                stride=1,
+                kernel_size=4,
+                padding=1
             )
         )
         init_weights(self)
diff --git a/cyclegan/train.py b/cyclegan/train.py
index 4b7268f..e8c4cdf 100644
--- a/cyclegan/train.py
+++ b/cyclegan/train.py
@@ -11,7 +11,6 @@
 
 
 parser = argparse.ArgumentParser()
-parser.add_argument('-train', action="store_true", default=False, help='Train cycleGAN models')
 parser.add_argument('--dim_A', type=int, default=3, help='Numer of channels for class A')
 parser.add_argument('--dim_B', type=int, default=3, help='Numer of channels for class B')
 parser.add_argument('--n_res_blocks', type=int, default=9, help='Number of ResNet Blocks for generators')
@@ -25,7 +24,7 @@
 parser.add_argument('--load_shape', type=int, default=256, help='Initial image H or W')
 parser.add_argument('--target_shape', type=int, default=224, help='Final image H or W')
 parser.add_argument('--progress_interval', type=int, default=1, help='Save model and generated image every x epoch')
-parser.add_argument('--sample_batches', type=int, default=25, help='How many generated images to sample')
+parser.add_argument('--sample_batches', type=int, default=32, help='How many generated images to sample')
 parser.add_argument('--batch_size', type=int, default=1, help='Batch size')
 parser.add_argument('--lambda_identity', type=float, default=0.1, help='Identity loss weight')
 parser.add_argument('--lambda_cycle', type=float, default=10., help='Cycle loss weight')
@@ -103,7 +102,14 @@
     D_A.train()
     D_B.train()
     disc_A_losses, gen_A_losses, disc_B_losses, gen_B_losses = [], [], [], []
+    (gen_ad_A_losses,
+     gen_ad_B_losses,
+     gen_id_A_losses,
+     gen_id_B_losses,
+     gen_cyc_A_losses,
+     gen_cyc_B_losses) = [], [], [], [], [], []
     real_As, real_Bs, fake_As, fake_Bs = [], [], [], []
+    identity_As, identity_Bs, cycle_As, cycle_Bs = [], [], [], []
     for batch_idx, (real_A, real_B) in enumerate(dataloader):
         real_A = torch.nn.functional.interpolate(real_A, size=args.target_shape).to(device)
         real_B = torch.nn.functional.interpolate(real_B, size=args.target_shape).to(device)
@@ -123,9 +129,9 @@
         real_B_logits = D_B(real_B)
 
         # sample from queue
-        pool_fake_A = pool_A.sample(fake_A.detach())
-        pool_fake_B = pool_B.sample(fake_B.detach())
-        fake_pool_A_logits = D_B(pool_fake_A)
+        pool_fake_A = pool_A.sample(fake_A.clone().detach())
+        pool_fake_B = pool_B.sample(fake_B.clone().detach())
+        fake_pool_A_logits = D_A(pool_fake_A)
         fake_pool_B_logits = D_B(pool_fake_B)
 
         # disc loss
@@ -135,7 +141,7 @@
         disc_B_fake_loss = criterion_GAN(fake_pool_B_logits, torch.zeros_like(fake_pool_B_logits))
         disc_B_real_loss = criterion_GAN(real_B_logits, torch.ones_like(real_B_logits))
         disc_B_loss = (disc_B_fake_loss + disc_B_real_loss) / 2
-        disc_loss = disc_A_loss + disc_A_loss
+        disc_loss = disc_A_loss + disc_B_loss
 
         # generator loss
         adversarial_A_loss = criterion_GAN(fake_A_logits, torch.ones_like(fake_A_logits))
@@ -165,11 +171,21 @@
         gen_B_losses.append(gen_B_loss.item())
         disc_A_losses.append(disc_A_loss.item())
         disc_B_losses.append(disc_B_loss.item())
+        gen_ad_A_losses.append(adversarial_A_loss.item())
+        gen_ad_B_losses.append(adversarial_B_loss.item())
+        gen_id_A_losses.append(identity_A_loss.item())
+        gen_id_B_losses.append(identity_B_loss.item())
+        gen_cyc_A_losses.append(cycle_A_loss.item())
+        gen_cyc_B_losses.append(cycle_B_loss.item())
         if batch_idx in sampled_idx:
             real_As.append(real_A.detach().cpu())
             real_Bs.append(real_B.detach().cpu())
             fake_As.append(fake_A.detach().cpu())
             fake_Bs.append(fake_B.detach().cpu())
+            identity_As.append(identity_A.detach().cpu())
+            identity_Bs.append(identity_B.detach().cpu())
+            cycle_As.append(cycle_A.detach().cpu())
+            cycle_Bs.append(cycle_B.detach().cpu())
 
     lr_scheduler_D.step()
     lr_scheduler_G.step()
@@ -179,12 +195,22 @@
         'Discriminator A': sum(disc_A_losses) / len(disc_A_losses),
         'Discriminator B': sum(disc_B_losses) / len(disc_B_losses),
         'Generator A': sum(gen_A_losses) / len(gen_A_losses),
-        'Generator B': sum(gen_B_losses) / len(gen_B_losses)
+        'Generator B': sum(gen_B_losses) / len(gen_B_losses),
+        'Generator Adversarial A': sum(gen_ad_A_losses) / len(gen_ad_A_losses),
+        'Generator Adversarial B': sum(gen_ad_B_losses) / len(gen_ad_B_losses),
+        'Generator Cycle A': sum(gen_cyc_A_losses) / len(gen_cyc_A_losses),
+        'Generator Cycle B': sum(gen_cyc_B_losses) / len(gen_cyc_B_losses),
+        'Generator Identity A': sum(gen_id_A_losses) / len(gen_id_A_losses),
+        'Generator Identity B': sum(gen_id_B_losses) / len(gen_id_B_losses)
     }, global_step=epoch)
     writer.add_image('Fake A', make_images(fake_As), global_step=epoch)
     writer.add_image('Fake B', make_images(fake_Bs), global_step=epoch)
     writer.add_image('Real A', make_images(real_As), global_step=epoch)
     writer.add_image('Real B', make_images(real_Bs), global_step=epoch)
+    writer.add_image('Identity A', make_images(identity_As), global_step=epoch)
+    writer.add_image('Identity B', make_images(identity_Bs), global_step=epoch)
+    writer.add_image('Cycle A', make_images(cycle_As), global_step=epoch)
+    writer.add_image('Cycle B', make_images(cycle_Bs), global_step=epoch)
     if (epoch + 1) % 10 == 0:
         if not os.path.isdir(args.checkpoint_dir):
             os.makedirs(args.checkpoint_dir)
diff --git a/moco/main.py b/moco/main.py
index c8ed3f4..02192bd 100644
--- a/moco/main.py
+++ b/moco/main.py
@@ -5,25 +5,22 @@
 from torch.utils.tensorboard import SummaryWriter
 from torch.nn import functional as F
 
-from sklearn.linear_model import LogisticRegression
-from sklearn.metrics import accuracy_score
-from sklearn.exceptions import ConvergenceWarning
-
 import argparse
 from tqdm import tqdm
 from pathlib import Path
 from datetime import datetime
-from warnings import simplefilter
 
 from moco.model import MoCo
 from moco.utils import (GaussianBlur, CIFAR10Pairs, MoCoLoss,
                         MemoryBank, momentum_update, get_momentum_encoder)
+from networks.layers import Linear_Probe
+from utils.contrastive import get_feature_label
 
-simplefilter(action='ignore', category=ConvergenceWarning)
 
 parser = argparse.ArgumentParser(description='Train MoCo')
 # training configs
 parser.add_argument('--lr', default=0.03, type=float, help='initial learning rate')
+parser.add_argument('--continue_train', action="store_true", default=False, help='continue training')
 parser.add_argument('--epochs', default=200, type=int, metavar='N', help='number of total epochs to run')
 parser.add_argument('--batch_size', default=256, type=int, metavar='N', help='mini-batch size')
 parser.add_argument('--wd', default=0.0001, type=float, metavar='W', help='weight decay')
@@ -88,7 +85,16 @@
 memo_bank = MemoryBank(f_k, device, momentum_loader, args.K)
 writer = SummaryWriter(args.logs_root + f'/{int(datetime.now().timestamp()*1e6)}')
 
-pbar = tqdm(range(args.epochs))
+start_epoch = 0
+if args.continue_train:
+    state_dicts = torch.load(args.check_point)
+    start_epoch = state_dicts['start_epoch']
+    f_q.load_state_dict(state_dicts['f_q'])
+    f_k.load_state_dict(state_dicts['f_k'])
+    optimizer.load_state_dict(state_dicts['optimizer'])
+    del state_dicts
+
+pbar = tqdm(range(start_epoch, args.epochs))
 for epoch in pbar:
     train_losses = []
     for x1, x2 in train_loader:
@@ -107,32 +113,28 @@
         pbar.set_postfix({'Loss': loss.item(), 'Learning Rate': scheduler.get_last_lr()[0]})
 
     writer.add_scalar('Train Loss', sum(train_losses) / len(train_losses), global_step=epoch)
-    torch.save(f_q.state_dict(), args.check_point)
     scheduler.step()
 
-    feature_bank, feature_labels = [], []
-    for data, target in momentum_loader:
-        with torch.no_grad():
-            features = f_q(data)
-        feature_bank.append(features)
-        feature_labels.append(target)
-    feature_bank = torch.cat(feature_bank).cpu().numpy()
-    feature_labels = torch.cat(feature_labels).numpy()
+    f_q.eval()
+    # extract features as training data
+    feature_bank, feature_labels = get_feature_label(f_q, momentum_loader, device, normalize=True)
 
-    linear_classifier = LogisticRegression(multi_class='multinomial', solver='lbfgs')
+    linear_classifier = Linear_Probe(len(momentum_data.classes), hidden_dim=args.feature_dim).to(device)
     linear_classifier.fit(feature_bank, feature_labels)
 
-    y_preds, y_trues = [], []
-    for data, target in test_loader:
-        with torch.no_grad():
-            feature = f_q(data).cpu().numpy()
-        y_preds.extend(linear_classifier.predict(feature).tolist())
-        y_trues.append(target)
-    y_trues = torch.cat(y_trues, dim=0).numpy()
-    top1acc = accuracy_score(y_trues, y_preds) * 100
+    # using linear classifier to predict test data
+    y_preds, y_trues = get_feature_label(f_q, test_loader, device, normalize=True, predictor=linear_classifier)
+    top1acc = y_trues.eq(y_preds).sum().item() / y_preds.size(0)
+
     writer.add_scalar('Top Acc @ 1', top1acc, global_step=epoch)
     writer.add_scalar('Representation Standard Deviation', feature_bank.std(), global_step=epoch)
     tqdm.write(f'Epoch: {epoch + 1}/{args.epochs}, \
         Training Loss: {sum(train_losses) / len(train_losses)}, \
         Top Accuracy @ 1: {top1acc}, \
         Representation STD: {feature_bank.std()}')
+    torch.save({
+        'f_q': f_q.state_dict(),
+        'f_k': f_k.state_dict(),
+        'optimizer': optimizer.state_dict(),
+        'start_epoch': epoch + 1},
+        args.check_point)
diff --git a/networks/layers.py b/networks/layers.py
index fbd03c2..b1e8ab8 100644
--- a/networks/layers.py
+++ b/networks/layers.py
@@ -1,4 +1,7 @@
 from torch import nn
+import torch
+from torch.utils.data import DataLoader
+from networks.utils import SimpleDataset
 
 
 class ConvNormAct(nn.Module):
@@ -90,3 +93,40 @@ def __init__(self, *args):
 
     def forward(self, x):
         return x.view(self.shape)
+
+
+class Linear_Probe(nn.Module):
+    def __init__(self, num_classes, hidden_dim=256, lr=1e-3):
+        super().__init__()
+        self.fc = nn.Linear(hidden_dim, num_classes)
+        self.optimizer = torch.optim.SGD(self.parameters(), lr=lr,
+                                         momentum=0.9, weight_decay=0.0001)
+        self.criterion = nn.CrossEntropyLoss()
+        self.scheduler = torch.optim.lr_scheduler.MultiplicativeLR(
+            self.optimizer,
+            lr_lambda=lambda lr: 0.995)
+
+    def forward(self, x):
+        return self.fc(x)
+
+    def loss(self, y_hat, y):
+        return self.criterion(y_hat, y)
+
+    def fit(self, x, y, epochs=500):
+        dataset = SimpleDataset(x, y)
+        loader = DataLoader(dataset, batch_size=2056, shuffle=True)
+        self.train()
+        for _ in range(epochs):
+            for features, labels in loader:
+                y_hat = self.forward(features)
+                loss = self.loss(y_hat, labels)
+                self.optimizer.zero_grad()
+                loss.backward()
+                self.optimizer.step()
+            self.scheduler.step()
+
+    def predict(self, x):
+        self.eval()
+        with torch.no_grad():
+            predictions = self.forward(x)
+        return torch.argmax(predictions, dim=1)
diff --git a/networks/utils.py b/networks/utils.py
index f56c016..1375169 100644
--- a/networks/utils.py
+++ b/networks/utils.py
@@ -1,4 +1,5 @@
 from torch import nn
+from torch.utils.data import Dataset
 
 
 def initialize_modules(model, nonlinearity='leaky_relu'):
@@ -12,3 +13,16 @@ def initialize_modules(model, nonlinearity='leaky_relu'):
         elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm, nn.Linear)):
             nn.init.normal_(m.weight, 0.0, 0.02)
             nn.init.constant_(m.bias, 0)
+
+
+class SimpleDataset(Dataset):
+    def __init__(self, x, y):
+        super().__init__()
+        self.x = x
+        self.y = y
+
+    def __getitem__(self, idx):
+        return self.x[idx], self.y[idx]
+
+    def __len__(self):
+        return self.x.size(0)
diff --git a/simsiam/data.py b/simsiam/data.py
index 9e6369a..e5f5a5c 100644
--- a/simsiam/data.py
+++ b/simsiam/data.py
@@ -31,16 +31,3 @@ def __call__(self, x):
         sigma = random.uniform(self.sigma[0], self.sigma[1])
         x = x.filter(ImageFilter.GaussianBlur(radius=sigma))
         return x
-
-
-class SimpleDataset(Dataset):
-    def __init__(self, x, y):
-        super().__init__()
-        self.x = x
-        self.y = y
-
-    def __getitem__(self, idx):
-        return self.x[idx], self.y[idx]
-
-    def __len__(self):
-        return self.x.size(0)
diff --git a/simsiam/main.py b/simsiam/main.py
index 3514cf0..40ea7b1 100644
--- a/simsiam/main.py
+++ b/simsiam/main.py
@@ -10,8 +10,10 @@
 from pathlib import Path
 from datetime import datetime
 
-from simsiam.model import SimSiam, Linear_Classifier
+from simsiam.model import SimSiam
 from simsiam.data import GaussianBlur, CIFAR10Pairs
+from networks.layers import Linear_Probe
+from utils.contrastive import get_feature_label
 
 
 parser = argparse.ArgumentParser(description='Train SimSiam')
@@ -104,31 +106,14 @@
     writer.add_scalar('Train Loss', sum(train_losses) / len(train_losses), global_step=epoch)
 
     model.eval()
-    feature_bank, targets = [], []
-    # get current feature maps & fit LR
-    for data, target in feature_loader:
-        data, target = data.to(device), target.to(device)
-        with torch.no_grad():
-            feature = model(data)
-            feature = F.normalize(feature, dim=1)
-        feature_bank.append(feature.clone().detach())
-        targets.append(target)
-    feature_bank = torch.cat(feature_bank, dim=0)
-    feature_labels = torch.cat(targets, dim=0)
-
-    linear_classifier = Linear_Classifier(args, len(feature_data.classes)).to(device)
+    # extract features as training data
+    feature_bank, feature_labels = get_feature_label(model, feature_loader, device, normalize=True)
+
+    linear_classifier = Linear_Probe(len(feature_data.classes), hidden_dim=args.hidden_dim).to(device)
     linear_classifier.fit(feature_bank, feature_labels)
 
-    y_preds, y_trues = [], []
-    for data, target in test_loader:
-        data, target = data.to(device), target.to(device)
-        with torch.no_grad():
-            feature = model(data)
-            feature = F.normalize(feature, dim=1)
-        y_preds.append(linear_classifier.predict(feature.detach()))
-        y_trues.append(target)
-    y_trues = torch.cat(y_trues, dim=0)
-    y_preds = torch.cat(y_preds, dim=0)
+    # using linear classifier to predict test data
+    y_preds, y_trues = get_feature_label(model, test_loader, device, normalize=True, predictor=linear_classifier)
     top1acc = y_trues.eq(y_preds).sum().item() / y_preds.size(0)
     writer.add_scalar('Top Acc @ 1', top1acc, global_step=epoch)
     writer.add_scalar('Representation Standard Deviation', feature_bank.std(), global_step=epoch)
diff --git a/simsiam/model.py b/simsiam/model.py
index 1ff69ab..1245181 100644
--- a/simsiam/model.py
+++ b/simsiam/model.py
@@ -2,45 +2,6 @@
 from torch import nn
 from torch.nn import functional as F
 import torchvision
-from torch.utils.data import DataLoader
-from simsiam.data import SimpleDataset
-
-
-class Linear_Classifier(nn.Module):
-    def __init__(self, args, num_classes, epochs=500, lr=1e-3):
-        super().__init__()
-        self.fc = nn.Linear(args.hidden_dim, num_classes)
-        self.epochs = epochs
-        self.optimizer = torch.optim.Adam(self.parameters(), lr=lr)
-        self.criterion = nn.CrossEntropyLoss()
-        self.scheduler = torch.optim.lr_scheduler.MultiplicativeLR(
-            self.optimizer,
-            lr_lambda=lambda lr: 0.995)
-
-    def forward(self, x):
-        return self.fc(x)
-
-    def loss(self, y_hat, y):
-        return self.criterion(y_hat, y)
-
-    def fit(self, x, y):
-        dataset = SimpleDataset(x, y)
-        loader = DataLoader(dataset, batch_size=2056, shuffle=True)
-        self.train()
-        for _ in range(self.epochs):
-            for features, labels in loader:
-                y_hat = self.forward(features)
-                loss = self.loss(y_hat, labels)
-                self.optimizer.zero_grad()
-                loss.backward()
-                self.optimizer.step()
-            self.scheduler.step()
-
-    def predict(self, x):
-        self.eval()
-        with torch.no_grad():
-            predictions = self.forward(x)
-        return torch.argmax(predictions, dim=1)
 
 
 class SimSiam(nn.Module):
diff --git a/utils/contrastive.py b/utils/contrastive.py
new file mode 100644
index 0000000..91beaed
--- /dev/null
+++ b/utils/contrastive.py
@@ -0,0 +1,20 @@
+import torch
+from torch.nn import functional as F
+
+
+def get_feature_label(feature_extractor, feature_loader, device, normalize=True, predictor=None):
+    transformed_features, targets = [], []
+    for data, target in feature_loader:
+        data, target = data.to(device), target.to(device)
+        with torch.no_grad():
+            feature = feature_extractor(data)
+        if normalize:
+            feature = F.normalize(feature, dim=1)
+        if predictor is None:
+            transformed_features.append(feature.clone())
+        else:
+            transformed_features.append(predictor.predict(feature.clone()))
+        targets.append(target)
+    transformed_features = torch.cat(transformed_features, dim=0)
+    targets = torch.cat(targets, dim=0)
+    return transformed_features, targets
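
The patch centralizes linear evaluation in two shared pieces: networks.layers.Linear_Probe (a torch replacement for the removed sklearn LogisticRegression and the per-project Linear_Classifier) and utils.contrastive.get_feature_label (the shared feature-extraction/prediction loop). A minimal sketch of how the two compose once the patch is applied; the dummy encoder, random tensors, feature_dim and num_classes below are stand-ins for the real MoCo f_q / SimSiam backbone and the CIFAR-10 loaders used by the actual scripts:

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

from networks.layers import Linear_Probe
from utils.contrastive import get_feature_label

device = 'cpu'
feature_dim, num_classes = 128, 10

# Stand-in for a trained, frozen feature extractor (f_q in moco/main.py, model in simsiam/main.py).
encoder = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, feature_dim)).to(device).eval()

# Random images/labels standing in for the CIFAR-10 feature and test loaders.
images = torch.randn(256, 3, 32, 32)
labels = torch.randint(0, num_classes, (256,))
feature_loader = DataLoader(TensorDataset(images, labels), batch_size=64)
test_loader = DataLoader(TensorDataset(images, labels), batch_size=64)

# 1) Extract normalized features and labels to train the probe on.
feature_bank, feature_labels = get_feature_label(encoder, feature_loader, device, normalize=True)

# 2) Fit the linear probe on the frozen features (its own SGD loop runs inside .fit()).
probe = Linear_Probe(num_classes, hidden_dim=feature_dim).to(device)
probe.fit(feature_bank, feature_labels, epochs=5)

# 3) Passing the fitted probe as `predictor` makes get_feature_label return class
#    predictions instead of features, mirroring the evaluation blocks in both scripts.
y_preds, y_trues = get_feature_label(encoder, test_loader, device, normalize=True, predictor=probe)
print('top-1:', y_trues.eq(y_preds).sum().item() / y_preds.size(0))

Routing prediction back through get_feature_label via the predictor argument is what lets moco/main.py and simsiam/main.py compute top-1 accuracy identically, as y_trues.eq(y_preds).sum().item() / y_preds.size(0).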