refactored code, made utils module, bugfix cyclegan
zsdonghao committed Dec 8, 2020
1 parent 9b90558 commit 8c9b696
Showing 9 changed files with 147 additions and 128 deletions.
24 changes: 4 additions & 20 deletions cyclegan/models.py
@@ -33,19 +33,6 @@ def forward(self, x):
return self.resblock(x)


def initialize_weights(model, nonlinearity='leaky_relu'):
for m in model.modules():
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
nn.init.kaiming_normal_(
m.weight,
mode='fan_out',
nonlinearity=nonlinearity
)
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.normal_(m.weight, 0.0, 0.02)
nn.init.constant_(m.bias, 0)


def set_requires_grad(nets, requires_grad=False):
if not isinstance(nets, list):
nets = [nets]
@@ -101,8 +88,7 @@ def __init__(
hidden_dim * 2,
kernel_size=3,
padding=1,
stride=2,
padding_mode='reflect'
stride=2
),
nn.InstanceNorm2d(hidden_dim * 2),
nn.ReLU(),
@@ -111,8 +97,7 @@
hidden_dim * 4,
kernel_size=3,
padding=1,
stride=2,
padding_mode='reflect'
stride=2
),
nn.InstanceNorm2d(hidden_dim * 4),
nn.ReLU(),
@@ -194,9 +179,8 @@ def __init__(self, input_channels, hidden_dim):
nn.Conv2d(
hidden_dim * 8,
out_channels=1,
kernel_size=1,
padding=1,
stride=1,
kernel_size=4,
padding=1
)
)
init_weights(self)
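One of the CycleGAN fixes in this commit appears to be the discriminator head above: the final conv previously used kernel_size=1 with padding=1, which pads the patch map outward instead of reducing it, while the new kernel_size=4, padding=1 head gives the usual PatchGAN-style output. A quick standalone check (illustrative only; hidden_dim=64 and the 28x28 feature map are assumptions, not taken from the repo):

import torch
from torch import nn

# Feature map entering the discriminator head: (batch, hidden_dim * 8, H, W)
x = torch.randn(1, 64 * 8, 28, 28)
old_head = nn.Conv2d(64 * 8, 1, kernel_size=1, padding=1, stride=1)
new_head = nn.Conv2d(64 * 8, 1, kernel_size=4, padding=1)
print(old_head(x).shape)  # torch.Size([1, 1, 30, 30]) -- padding grows the patch map
print(new_head(x).shape)  # torch.Size([1, 1, 27, 27]) -- 4x4 PatchGAN-style head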
40 changes: 33 additions & 7 deletions cyclegan/train.py
@@ -11,7 +11,6 @@


parser = argparse.ArgumentParser()
parser.add_argument('-train', action="store_true", default=False, help='Train cycleGAN models')
parser.add_argument('--dim_A', type=int, default=3, help='Number of channels for class A')
parser.add_argument('--dim_B', type=int, default=3, help='Number of channels for class B')
parser.add_argument('--n_res_blocks', type=int, default=9, help='Number of ResNet Blocks for generators')
@@ -25,7 +24,7 @@
parser.add_argument('--load_shape', type=int, default=256, help='Initial image H or W')
parser.add_argument('--target_shape', type=int, default=224, help='Final image H or W')
parser.add_argument('--progress_interval', type=int, default=1, help='Save model and generated image every x epoch')
parser.add_argument('--sample_batches', type=int, default=25, help='How many generated images to sample')
parser.add_argument('--sample_batches', type=int, default=32, help='How many generated images to sample')
parser.add_argument('--batch_size', type=int, default=1, help='Batch size')
parser.add_argument('--lambda_identity', type=float, default=0.1, help='Identity loss weight')
parser.add_argument('--lambda_cycle', type=float, default=10., help='Cycle loss weight')
@@ -103,7 +102,14 @@
D_A.train()
D_B.train()
disc_A_losses, gen_A_losses, disc_B_losses, gen_B_losses = [], [], [], []
(gen_ad_A_losses,
gen_ad_B_losses,
gen_id_A_losses,
gen_id_B_losses,
gen_cyc_A_losses,
gen_cyc_B_losses) = [], [], [], [], [], []
real_As, real_Bs, fake_As, fake_Bs = [], [], [], []
identity_As, identity_Bs, cycle_As, cycle_Bs = [], [], [], []
for batch_idx, (real_A, real_B) in enumerate(dataloader):
real_A = torch.nn.functional.interpolate(real_A, size=args.target_shape).to(device)
real_B = torch.nn.functional.interpolate(real_B, size=args.target_shape).to(device)
@@ -123,9 +129,9 @@
real_B_logits = D_B(real_B)

# sample from queue
pool_fake_A = pool_A.sample(fake_A.detach())
pool_fake_B = pool_B.sample(fake_B.detach())
fake_pool_A_logits = D_B(pool_fake_A)
pool_fake_A = pool_A.sample(fake_A.clone().detach())
pool_fake_B = pool_B.sample(fake_B.clone().detach())
fake_pool_A_logits = D_A(pool_fake_A)
fake_pool_B_logits = D_B(pool_fake_B)

# disc loss
@@ -135,7 +141,7 @@
disc_B_fake_loss = criterion_GAN(fake_pool_B_logits, torch.zeros_like(fake_pool_B_logits))
disc_B_real_loss = criterion_GAN(real_B_logits, torch.ones_like(real_B_logits))
disc_B_loss = (disc_B_fake_loss + disc_B_real_loss) / 2
disc_loss = disc_A_loss + disc_A_loss
disc_loss = disc_A_loss + disc_B_loss

# generator loss
adversarial_A_loss = criterion_GAN(fake_A_logits, torch.ones_like(fake_A_logits))
@@ -165,11 +171,21 @@
gen_B_losses.append(gen_B_loss.item())
disc_A_losses.append(disc_A_loss.item())
disc_B_losses.append(disc_B_loss.item())
gen_ad_A_losses.append(adversarial_A_loss.item())
gen_ad_B_losses.append(adversarial_B_loss.item())
gen_id_A_losses.append(identity_A_loss.item())
gen_id_B_losses.append(identity_B_loss.item())
gen_cyc_A_losses.append(cycle_A_loss.item())
gen_cyc_B_losses.append(cycle_B_loss.item())
if batch_idx in sampled_idx:
real_As.append(real_A.detach().cpu())
real_Bs.append(real_B.detach().cpu())
fake_As.append(fake_A.detach().cpu())
fake_Bs.append(fake_B.detach().cpu())
identity_As.append(identity_A.detach().cpu())
identity_Bs.append(identity_B.detach().cpu())
cycle_As.append(cycle_A.detach().cpu())
cycle_Bs.append(cycle_B.detach().cpu())

lr_scheduler_D.step()
lr_scheduler_G.step()
@@ -179,12 +195,22 @@
'Discriminator A': sum(disc_A_losses) / len(disc_A_losses),
'Discriminator B': sum(disc_B_losses) / len(disc_B_losses),
'Generator A': sum(gen_A_losses) / len(gen_A_losses),
'Generator B': sum(gen_B_losses) / len(gen_B_losses)
'Generator B': sum(gen_B_losses) / len(gen_B_losses),
'Generator Adversarial A': sum(gen_ad_A_losses) / len(gen_ad_A_losses),
'Generator Adversarial B': sum(gen_ad_B_losses) / len(gen_ad_B_losses),
'Generator Cycle A': sum(gen_cyc_A_losses) / len(gen_cyc_A_losses),
'Generator Cycle B': sum(gen_cyc_B_losses) / len(gen_cyc_B_losses),
'Generator Identity A': sum(gen_id_A_losses) / len(gen_id_A_losses),
'Generator Identity B': sum(gen_id_B_losses) / len(gen_id_B_losses)
}, global_step=epoch)
writer.add_image('Fake A', make_images(fake_As), global_step=epoch)
writer.add_image('Fake B', make_images(fake_Bs), global_step=epoch)
writer.add_image('Real A', make_images(real_As), global_step=epoch)
writer.add_image('Real B', make_images(real_Bs), global_step=epoch)
writer.add_image('Identity A', make_images(identity_As), global_step=epoch)
writer.add_image('Identity B', make_images(identity_Bs), global_step=epoch)
writer.add_image('Cycle A', make_images(cycle_As), global_step=epoch)
writer.add_image('Cycle B', make_images(cycle_Bs), global_step=epoch)
if (epoch + 1) % 10 == 0:
if not os.path.isdir(args.checkpoint_dir):
os.makedirs(args.checkpoint_dir)
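Two notes on the hunks above: fake_pool_A_logits is now computed with D_A (the discriminator for class-A images) instead of D_B, and the total discriminator loss now sums disc_A_loss and disc_B_loss rather than counting disc_A_loss twice. The pool_A / pool_B objects are an image-history buffer whose implementation is not part of this diff; a minimal sketch of such a pool, assuming the usual CycleGAN replay-buffer behaviour (the repo's actual class may differ):

import random
import torch

class ImagePool:  # hypothetical stand-in for pool_A / pool_B
    def __init__(self, max_size=50):
        self.max_size = max_size
        self.images = []

    def sample(self, images):
        # Return a batch mixing current fakes with previously generated ones.
        out = []
        for img in images:
            img = img.unsqueeze(0)
            if len(self.images) < self.max_size:      # fill the buffer first
                self.images.append(img)
                out.append(img)
            elif random.random() > 0.5:               # half the time, swap in a stored fake
                idx = random.randrange(len(self.images))
                out.append(self.images[idx].clone())
                self.images[idx] = img
            else:                                     # otherwise keep the current fake
                out.append(img)
        return torch.cat(out, dim=0)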
52 changes: 27 additions & 25 deletions moco/main.py
@@ -5,25 +5,22 @@
from torch.utils.tensorboard import SummaryWriter
from torch.nn import functional as F

from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.exceptions import ConvergenceWarning

import argparse
from tqdm import tqdm
from pathlib import Path
from datetime import datetime
from warnings import simplefilter

from moco.model import MoCo
from moco.utils import (GaussianBlur, CIFAR10Pairs, MoCoLoss,
MemoryBank, momentum_update, get_momentum_encoder)
from networks.layers import Linear_Probe
from utils.contrastive import get_feature_label

simplefilter(action='ignore', category=ConvergenceWarning)
parser = argparse.ArgumentParser(description='Train MoCo')

# training configs
parser.add_argument('--lr', default=0.03, type=float, help='initial learning rate')
parser.add_argument('--continue_train', action="store_true", default=False, help='continue training')
parser.add_argument('--epochs', default=200, type=int, metavar='N', help='number of total epochs to run')
parser.add_argument('--batch_size', default=256, type=int, metavar='N', help='mini-batch size')
parser.add_argument('--wd', default=0.0001, type=float, metavar='W', help='weight decay')
@@ -88,7 +85,16 @@
memo_bank = MemoryBank(f_k, device, momentum_loader, args.K)
writer = SummaryWriter(args.logs_root + f'/{int(datetime.now().timestamp()*1e6)}')

pbar = tqdm(range(args.epochs))
start_epoch = 0
if args.continue_train:
state_dicts = torch.load(args.check_point)
start_epoch = state_dicts['start_epoch']
f_q.load_state_dict(state_dicts['f_q'])
f_k.load_state_dict(state_dicts['f_k'])
optimizer.load_state_dict(state_dicts['optimizer'])
del state_dicts

pbar = tqdm(range(start_epoch, args.epochs))
for epoch in pbar:
train_losses = []
for x1, x2 in train_loader:
@@ -107,32 +113,28 @@
pbar.set_postfix({'Loss': loss.item(), 'Learning Rate': scheduler.get_last_lr()[0]})

writer.add_scalar('Train Loss', sum(train_losses) / len(train_losses), global_step=epoch)
torch.save(f_q.state_dict(), args.check_point)
scheduler.step()

feature_bank, feature_labels = [], []
for data, target in momentum_loader:
with torch.no_grad():
features = f_q(data)
feature_bank.append(features)
feature_labels.append(target)
feature_bank = torch.cat(feature_bank).cpu().numpy()
feature_labels = torch.cat(feature_labels).numpy()
f_q.eval()
# extract features as training data
feature_bank, feature_labels = get_feature_label(f_q, momentum_loader, device, normalize=True)

linear_classifier = LogisticRegression(multi_class='multinomial', solver='lbfgs')
linear_classifier = Linear_Probe(len(momentum_data.classes), hidden_dim=args.feature_dim).to(device)
linear_classifier.fit(feature_bank, feature_labels)

y_preds, y_trues = [], []
for data, target in test_loader:
with torch.no_grad():
feature = f_q(data).cpu().numpy()
y_preds.extend(linear_classifier.predict(feature).tolist())
y_trues.append(target)
y_trues = torch.cat(y_trues, dim=0).numpy()
top1acc = accuracy_score(y_trues, y_preds) * 100
# using linear classifier to predict test data
y_preds, y_trues = get_feature_label(f_q, test_loader, device, normalize=True, predictor=linear_classifier)
top1acc = y_trues.eq(y_preds).sum().item() / y_preds.size(0)

writer.add_scalar('Top Acc @ 1', top1acc, global_step=epoch)
writer.add_scalar('Representation Standard Deviation', feature_bank.std(), global_step=epoch)
tqdm.write(f'Epoch: {epoch + 1}/{args.epochs}, \
Training Loss: {sum(train_losses) / len(train_losses)}, \
Top Accuracy @ 1: {top1acc}, \
Representation STD: {feature_bank.std()}')
torch.save({
'f_q': f_q.state_dict(),
'f_k': f_k.state_dict(),
'optimizer': optimizer.state_dict(),
'start_epoch': epoch + 1},
args.check_point)
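get_feature_label comes from the new utils.contrastive module, whose source is not included in the hunks shown on this page. Judging only from how it is called above, it runs the frozen encoder over a loader, optionally L2-normalizes the features, and returns either (features, labels) or, when a predictor is passed, (predictions, labels). A sketch under those assumptions (signature and internals are guesses, not the repo's code):

import torch
from torch.nn import functional as F

def get_feature_label(model, loader, device, normalize=False, predictor=None):
    features, labels = [], []
    model.eval()
    with torch.no_grad():
        for data, target in loader:
            feat = model(data.to(device))
            if normalize:
                feat = F.normalize(feat, dim=1)       # unit-length feature vectors
            features.append(feat)
            labels.append(target.to(device))
    features = torch.cat(features, dim=0)
    labels = torch.cat(labels, dim=0)
    if predictor is not None:
        return predictor.predict(features), labels    # (y_preds, y_trues)
    return features, labels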
40 changes: 40 additions & 0 deletions networks/layers.py
@@ -1,4 +1,7 @@
from torch import nn
import torch
from torch.utils.data import DataLoader
from networks.utils import SimpleDataset


class ConvNormAct(nn.Module):
@@ -90,3 +93,40 @@ def __init__(self, *args):

def forward(self, x):
return x.view(self.shape)


class Linear_Probe(nn.Module):
def __init__(self, num_classes, hidden_dim=256, lr=1e-3):
super().__init__()
self.fc = nn.Linear(hidden_dim, num_classes)
self.optimizer = torch.optim.SGD(self.parameters(), lr=lr,
momentum=0.9, weight_decay=0.0001)
self.criterion = nn.CrossEntropyLoss()
self.scheduler = torch.optim.lr_scheduler.MultiplicativeLR(
self.optimizer,
lr_lambda=lambda lr: 0.995)

def forward(self, x):
return self.fc(x)

def loss(self, y_hat, y):
return self.criterion(y_hat, y)

def fit(self, x, y, epochs=500):
dataset = SimpleDataset(x, y)
loader = DataLoader(dataset, batch_size=2056, shuffle=True)
self.train()
for _ in range(epochs):
for features, labels in loader:
y_hat = self.forward(features)
loss = self.loss(y_hat, labels)
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
self.scheduler.step()

def predict(self, x):
self.eval()
with torch.no_grad():
predictions = self.forward(x)
return torch.argmax(predictions, dim=1)
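Linear_Probe replaces the scikit-learn LogisticRegression probe used before, so the evaluation loop stays on torch tensors end to end. For reference, this is roughly how it is driven from moco/main.py and simsiam/main.py; the dummy tensors below are placeholders, not repo data:

import torch
from networks.layers import Linear_Probe

device = 'cuda' if torch.cuda.is_available() else 'cpu'
features = torch.randn(1000, 128).to(device)       # frozen encoder features (hidden_dim=128 assumed)
labels = torch.randint(0, 10, (1000,)).to(device)  # ground-truth class indices
probe = Linear_Probe(num_classes=10, hidden_dim=128).to(device)
probe.fit(features, labels, epochs=50)             # SGD on the small linear head
preds = probe.predict(features)                    # argmax class per sample
top1 = preds.eq(labels).sum().item() / labels.size(0)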
14 changes: 14 additions & 0 deletions networks/utils.py
@@ -1,4 +1,5 @@
from torch import nn
from torch.utils.data import Dataset


def initialize_modules(model, nonlinearity='leaky_relu'):
Expand All @@ -12,3 +13,16 @@ def initialize_modules(model, nonlinearity='leaky_relu'):
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm, nn.Linear)):
nn.init.normal_(m.weight, 0.0, 0.02)
nn.init.constant_(m.bias, 0)


class SimpleDataset(Dataset):
def __init__(self, x, y):
super().__init__()
self.x = x
self.y = y

def __getitem__(self, idx):
return self.x[idx], self.y[idx]

def __len__(self):
return self.x.size(0)
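SimpleDataset was moved here from simsiam/data.py so that both MoCo and SimSiam can reuse it; it simply wraps a pair of in-memory tensors so they can be consumed through a standard DataLoader, which is what Linear_Probe.fit does internally. A tiny illustration with arbitrary shapes:

import torch
from torch.utils.data import DataLoader
from networks.utils import SimpleDataset

x = torch.randn(8, 128)                  # features
y = torch.randint(0, 10, (8,))           # labels
loader = DataLoader(SimpleDataset(x, y), batch_size=4, shuffle=True)
for features, labels in loader:
    print(features.shape, labels.shape)  # torch.Size([4, 128]) torch.Size([4])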
13 changes: 0 additions & 13 deletions simsiam/data.py
@@ -31,16 +31,3 @@ def __call__(self, x):
sigma = random.uniform(self.sigma[0], self.sigma[1])
x = x.filter(ImageFilter.GaussianBlur(radius=sigma))
return x


class SimpleDataset(Dataset):
def __init__(self, x, y):
super().__init__()
self.x = x
self.y = y

def __getitem__(self, idx):
return self.x[idx], self.y[idx]

def __len__(self):
return self.x.size(0)
33 changes: 9 additions & 24 deletions simsiam/main.py
@@ -10,8 +10,10 @@
from pathlib import Path
from datetime import datetime

from simsiam.model import SimSiam, Linear_Classifier
from simsiam.model import SimSiam
from simsiam.data import GaussianBlur, CIFAR10Pairs
from networks.layers import Linear_Probe
from utils.contrastive import get_feature_label

parser = argparse.ArgumentParser(description='Train SimSiam')

@@ -104,31 +106,14 @@
writer.add_scalar('Train Loss', sum(train_losses) / len(train_losses), global_step=epoch)

model.eval()
feature_bank, targets = [], []
# get current feature maps & fit LR
for data, target in feature_loader:
data, target = data.to(device), target.to(device)
with torch.no_grad():
feature = model(data)
feature = F.normalize(feature, dim=1)
feature_bank.append(feature.clone().detach())
targets.append(target)
feature_bank = torch.cat(feature_bank, dim=0)
feature_labels = torch.cat(targets, dim=0)

linear_classifier = Linear_Classifier(args, len(feature_data.classes)).to(device)
# extract features as training data
feature_bank, feature_labels = get_feature_label(model, feature_loader, device, normalize=True)

linear_classifier = Linear_Probe(len(feature_data.classes), hidden_dim=args.hidden_dim).to(device)
linear_classifier.fit(feature_bank, feature_labels)

y_preds, y_trues = [], []
for data, target in test_loader:
data, target = data.to(device), target.to(device)
with torch.no_grad():
feature = model(data)
feature = F.normalize(feature, dim=1)
y_preds.append(linear_classifier.predict(feature.detach()))
y_trues.append(target)
y_trues = torch.cat(y_trues, dim=0)
y_preds = torch.cat(y_preds, dim=0)
# using linear classifier to predict test data
y_preds, y_trues = get_feature_label(model, test_loader, device, normalize=True, predictor=linear_classifier)
top1acc = y_trues.eq(y_preds).sum().item() / y_preds.size(0)
writer.add_scalar('Top Acc @ 1', top1acc, global_step=epoch)
writer.add_scalar('Representation Standard Deviation', feature_bank.std(), global_step=epoch)
(The diffs for the remaining two changed files were not loaded on this page.)
