
Commit

modified path creation and simsiam loss fn
zsdonghao committed Dec 4, 2020
1 parent 1341dbb commit b097757
Showing 6 changed files with 53 additions and 28 deletions.
2 changes: 1 addition & 1 deletion moco/main.py
@@ -73,7 +73,7 @@
test_data = CIFAR10(root=args.data_root, train=False, transform=test_transform, download=True)
test_loader = DataLoader(test_data, batch_size=args.batch_size, shuffle=False, num_workers=28)

-Path(args.check_point.split('/')[1]).mkdir(parents=True, exist_ok=True)
+Path('/'.join(args.check_point.split('/')[:-1])).mkdir(parents=True, exist_ok=True)
Path(args.logs_root).mkdir(parents=True, exist_ok=True)

f_q = torch.nn.DataParallel(MoCo(args), device_ids=[0, 1]).to(device)
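A note on this change: the old line kept only the second component of the checkpoint path, so a path like 'moco/results/model.pth' created 'results' in the working directory rather than 'moco/results'. Joining every component but the file name creates the real parent directory. pathlib expresses the same intent directly; a minimal sketch, assuming args.check_point is a relative file path such as 'moco/results/model.pth':

# Hypothetical equivalent of the new line, not what the commit uses:
from pathlib import Path
Path(args.check_point).parent.mkdir(parents=True, exist_ok=True)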
1 change: 1 addition & 0 deletions simsiam/data.py
@@ -8,6 +8,7 @@


class CIFAR10Pairs(CIFAR10):
+    """Outputs two versions of the same image through two different transforms"""
    def __getitem__(self, index):
        img = self.data[index]
        img = Image.fromarray(img)
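The diff is truncated here, but the docstring describes the standard contrastive-pair pattern: run the same image through the augmentation pipeline twice. A sketch of how such a __getitem__ typically finishes, assuming the class reuses CIFAR10's transform attribute (the repo's actual code may differ):

# Sketch only; the real implementation continues past the truncation above.
class CIFAR10Pairs(CIFAR10):
    def __getitem__(self, index):
        img = Image.fromarray(self.data[index])
        # two independent draws from the same random transform -> two views
        return self.transform(img), self.transform(img)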
20 changes: 9 additions & 11 deletions simsiam/main.py
@@ -26,6 +26,8 @@
parser.add_argument('--batch_size', default=128, type=int, metavar='N', help='mini-batch size')
parser.add_argument('--wd', default=5e-4, type=float, metavar='W', help='weight decay')
parser.add_argument('--momentum', default=0.9, type=float, help='momentum for optimizer')
+parser.add_argument('--symmetric', action="store_true", default=True, help='loss function is symmetric')
+parser.add_argument('--device_ids', type=list, default=[0, 1], help='List of GPU devices')

# simsiam model configs
parser.add_argument('-a', '--backbone', default='resnet18')
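Two caveats on the new arguments: action="store_true" with default=True means --symmetric can never be turned off from the command line, and type=list splits a string into characters, so --device_ids 01 parses as ['0', '1'] (strings, not ints); only the defaults behave as intended. A hedged sketch of more conventional definitions, not what the commit contains:

# Hypothetical alternative shown for illustration only:
parser.add_argument('--asymmetric', action='store_true', help='use the one-sided loss')
parser.add_argument('--device_ids', type=int, nargs='+', default=[0, 1], help='List of GPU devices')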
@@ -42,12 +42,6 @@

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

-def cosine_loss(p, z):
-    z = z.detach()
-    p = F.normalize(p, dim=1)
-    z = F.normalize(z, dim=1)
-    return -(p @ z.T).mean()
-

if __name__ == '__main__':
"""https://github.com/facebookresearch/moco"""
@@ -77,8 +73,8 @@ def cosine_loss(p, z):
test_loader = DataLoader(test_data, batch_size=args.batch_size, shuffle=False, num_workers=28)

writer = SummaryWriter(args.logs_root)
-model = SimSiam(args).to(device)
-Path(args.check_point.split('/')[1]).mkdir(parents=True, exist_ok=True)
+model = torch.nn.DataParallel(SimSiam(args), device_ids=args.device_ids).to(device)
+Path('/'.join(args.check_point.split('/')[:-1])).mkdir(parents=True, exist_ok=True)
Path(args.logs_root).mkdir(parents=True, exist_ok=True)

optimizer = torch.optim.SGD(model.parameters(), lr=args.lr,
@@ -92,8 +88,10 @@ def cosine_loss(p, z):
for x1, x2 in train_loader:
    x1, x2 = x1.to(device), x2.to(device)
    z1, z2, p1, p2 = model(x1, x2)
-    # symmetric loss
-    loss = (cosine_loss(p1, z2) + cosine_loss(p2, z1)) / 2
+    if args.symmetric:
+        loss = (model.module.cosine_loss(p1, z2) + model.module.cosine_loss(p2, z1)) / 2
+    else:
+        loss = model.module.cosine_loss(p1, z2)
    train_losses.append(loss.item())
    optimizer.zero_grad()
    loss.backward()
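The symmetric branch matches the objective in the SimSiam paper (Chen & He, 2020): L = D(p1, stopgrad(z2))/2 + D(p2, stopgrad(z1))/2, where D is negative cosine similarity and the stop-gradient is the detach inside cosine_loss; the else branch keeps only the first term.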
@@ -134,7 +132,7 @@ def cosine_loss(p, z):
f'Epoch {epoch + 1}/{args.epochs}, \
Train Loss: {sum(train_losses) / len(train_losses):.3f}, \
Top Acc @ 1: {top1acc:.3f}, \
-Learning Rate: {scheduler.get_last_lr()}'
+Learning Rate: {scheduler.get_last_lr()[0]}'
)
torch.save(model.state_dict(), args.check_point)
scheduler.step()
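The [0] added to the log line is because scheduler.get_last_lr() returns a list with one learning rate per parameter group, so indexing logs a scalar instead of a single-element list.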
7 changes: 7 additions & 0 deletions simsiam/model.py
@@ -1,6 +1,8 @@
"""https://github.com/facebookresearch/moco"""

+import torch
from torch import nn
+from torch.nn import functional as F
import torchvision


@@ -49,3 +51,8 @@ def forward(self, x1, x2=None, istrain=True):
            return z1, z2, p1, p2
        else:
            return self.encoder(x1)
+
+    def cosine_loss(self, p, z):
+        p = F.normalize(p, dim=1)
+        z = F.normalize(z, dim=1).detach()
+        return -torch.einsum('ij,ij->i', p, z).mean()
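The einsum 'ij,ij->i' is a row-wise dot product, so after normalization each projection p is compared only with its own target z; the deleted module-level cosine_loss returned -(p @ z.T).mean(), which averaged the entire batch-by-batch similarity matrix, mismatched pairs included. An equivalent formulation, a sketch rather than part of the commit:

# Hypothetical one-liner; F.cosine_similarity normalizes internally and
# differs from the committed method only by its small eps guard.
def cosine_loss(self, p, z):
    return -F.cosine_similarity(p, z.detach(), dim=1).mean()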
40 changes: 31 additions & 9 deletions vae/main.py
@@ -14,29 +14,40 @@


parser = argparse.ArgumentParser()
-# training

+# image settings
parser.add_argument('--img_channels', type=int, default=3, help='Number of channels for images')
-parser.add_argument('--model_dim', type=float, default=128, help='model dimensions multiplier')
-parser.add_argument('--z_dim', type=float, default=100, help='dimension of random noise latent vector')
parser.add_argument('--img_size', type=int, default=64, help='H, W of the input images')
parser.add_argument('--crop_size', type=int, default=128, help='H, W of the input images')

+# model params
+parser.add_argument('--z_dim', type=float, default=100, help='dimension of random noise latent vector')
parser.add_argument('--n_res_blocks', type=int, default=1, help='Number of ResNet Blocks for generators')
-parser.add_argument('--lr', type=float, default=0.0005, help='Learning rate for generators')
-parser.add_argument('--betas', type=tuple, default=(0.5, 0.999), help='Betas for Adam optimizer')
+parser.add_argument('--model_dim', type=float, default=128, help='model dimensions multiplier')

+# loss fn
parser.add_argument('--beta', type=float, default=1., help='Beta hyperparam for KLD Loss')
parser.add_argument('--recon', type=str, default='bce', help='Reconstruction loss type [bce, l2]')

+# training hyperparams
+parser.add_argument('--device_ids', type=list, default=[0, 1], help='List of GPU devices')
+parser.add_argument('--lr', type=float, default=0.0005, help='Learning rate for generators')
+parser.add_argument('--betas', type=tuple, default=(0.5, 0.999), help='Betas for Adam optimizer')
parser.add_argument('--n_epochs', type=int, default=200, help='Number of epochs')
parser.add_argument('--batch_size', type=int, default=512, help='Batch size')
-parser.add_argument('--sample_size', type=int, default=64, help='Size of sampled images')

+# logging
parser.add_argument('--log_dir', type=str, default='vae/logs', help='Path to where log files will be saved')
parser.add_argument('--data_path', type=str, default='data/img_align_celeba', help='Path to where image data is located')
-parser.add_argument('--device_ids', type=list, default=[0, 1], help='List of GPU devices')
parser.add_argument('--sample_path', type=str, default='vae/samples', help='Path to where samples are saved')
parser.add_argument('--img_ext', type=str, default='.jpg', help='Image extensions')
parser.add_argument('--checkpoint_dir', type=str, default='vae/model_weights', help='Path to where model weights will be saved')

-# for sampler
+# for sampling
+parser.add_argument('--sample_size', type=int, default=64, help='Size of sampled images')
parser.add_argument('--sample', action="store_true", default=False, help='Sample from VAE')
parser.add_argument('--walk', action="store_true", default=False, help='Walk through a feature & sample')

args = parser.parse_args()
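The same argparse caveats noted for simsiam/main.py apply here: type=list splits a command-line string into characters, and --z_dim and --model_dim declare type=float even though both are used as integer sizes, so only their defaults are safe to rely on.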


@@ -48,19 +59,27 @@
loader = get_loaders(args)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

+# initialize model, instantiate opt & scheduler & loss fn
model = torch.nn.DataParallel(VAE(args), device_ids=args.device_ids).to(device)
model.apply(initialize_modules)
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, betas=args.betas)
scheduler = torch.optim.lr_scheduler.MultiplicativeLR(optimizer, lambda epoch: 0.995)
-fixed_z = torch.randn(args.sample_size, args.z_dim).to(device)
criterion = VAELoss(args)

+# fixed z to see how the model changes on the same latent vectors
+fixed_z = torch.randn(args.sample_size, args.z_dim).to(device)

pbar = tqdm(range(args.n_epochs))
for epoch in pbar:
losses, kdls, rls = [], [], []
model.train()
for img in loader:
    x = img.to(device)

+    # x_hat for recon loss, mu & logvar for KLD loss
    x_hat, mu, logvar = model(x)

+    # returns KLD & recon loss for logging purposes
    loss, recon_loss, kld_loss = criterion(x, x_hat, mu, logvar)
    losses.append(loss.item())
    kdls.append(kld_loss.item())
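For context on the criterion call above: a beta-weighted VAE loss conventionally returns recon + beta * KLD together with both parts for logging, where KLD = -0.5 * sum(1 + logvar - mu^2 - exp(logvar)). A minimal sketch matching the (loss, recon_loss, kld_loss) signature, assuming the 'bce'/'l2' choice from --recon; the repo's actual VAELoss may differ:

# Sketch of a beta-VAE criterion; reductions and names are assumptions.
from torch import nn
from torch.nn import functional as F

class VAELossSketch(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.beta = args.beta
        self.recon = args.recon

    def forward(self, x, x_hat, mu, logvar):
        if self.recon == 'bce':
            recon_loss = F.binary_cross_entropy(x_hat, x, reduction='sum') / x.size(0)
        else:  # 'l2'
            recon_loss = F.mse_loss(x_hat, x, reduction='sum') / x.size(0)
        # per-sample KL divergence to the unit Gaussian, averaged over the batch
        kld_loss = (-0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum(dim=1)).mean()
        return recon_loss + self.beta * kld_loss, recon_loss, kld_loss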
@@ -77,9 +96,12 @@
writer.add_scalar('KLD Loss', sum(kdls) / len(kdls), global_step=epoch)
writer.add_scalar('Reconstruction Loss', sum(rls) / len(rls), global_step=epoch)

+# decode fixed z latent vectors
model.eval()
with torch.no_grad():
sampled_images = model.module.sample(fixed_z)

+# log images and losses & save model parameters
writer.add_image('Fixed Generated Images', torchvision.utils.make_grid(sampled_images), global_step=epoch)
writer.add_image('Reconstructed Images', torchvision.utils.make_grid(x_hat.detach()), global_step=epoch)
writer.add_image('Original Images', torchvision.utils.make_grid(x.detach()), global_step=epoch)
11 changes: 4 additions & 7 deletions vae/sample.py
@@ -10,22 +10,19 @@


class Sampler:
-    def __init__(self, sample_path='vae/samples', ext='.jpg'):
+    def __init__(self):
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.vae = torch.nn.DataParallel(VAE(args), device_ids=args.device_ids).to(self.device)
        self.vae.load_state_dict(torch.load(f"{args.checkpoint_dir}/VAE.pth")['model'])
        self.vae.eval()
-        Path(sample_path).mkdir(parents=True, exist_ok=True)
-
-        self.sample_path = sample_path
-        self.ext = ext
+        Path(args.sample_path).mkdir(parents=True, exist_ok=True)

    def sample(self):
        with torch.no_grad():
            samples = self.vae.module.sample(num_samples=args.sample_size)
            torchvision.utils.save_image(
                samples,
-                self.sample_path + f'/sample_{int(datetime.now().timestamp()*1e6)}' + self.ext)
+                args.sample_path + f'/sample_{int(datetime.now().timestamp()*1e6)}' + args.img_ext)

    def generate_walk_z(self):
        z = torch.randn(args.z_dim, device=self.device)
@@ -40,7 +37,7 @@ def walk(self):
            samples = self.vae.module.sample(z=z)
            torchvision.utils.save_image(
                samples,
-                self.sample_path + f'/walk_{int(datetime.now().timestamp()*1e6)}' + self.ext)
+                args.sample_path + f'/walk_{int(datetime.now().timestamp()*1e6)}' + args.img_ext)


if __name__ == '__main__':
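generate_walk_z is truncated above; the idea behind a latent walk is to hold one sampled z fixed and sweep a single coordinate, so the saved grid shows what that latent dimension controls. A sketch under assumed step count and sweep range, not the repo's exact method:

# Hypothetical helper; dim, steps, and span are illustrative choices.
def walk_z(z, dim=0, steps=8, span=3.0):
    zs = z.repeat(steps, 1)  # (steps, z_dim) copies of the base latent vector
    zs[:, dim] = torch.linspace(-span, span, steps, device=z.device)  # sweep one coordinate
    return zs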
