Skip to content

Commit

Permalink
vae bug fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
zsdonghao committed Dec 3, 2020
1 parent 40967e1 commit 3481421
Show file tree
Hide file tree
Showing 6 changed files with 27 additions and 21 deletions.
2 changes: 1 addition & 1 deletion moco/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@
test_loader = DataLoader(test_data, batch_size=args.batch_size, shuffle=False, num_workers=28)

Path(args.check_point.split('/')[1]).mkdir(parents=True, exist_ok=True)
Path(args.logs_root.split('/')[1]).mkdir(parents=True, exist_ok=True)
Path(args.logs_root).mkdir(parents=True, exist_ok=True)

f_q = torch.nn.DataParallel(MoCo(args), device_ids=[0, 1]).to(device)
f_k = get_momentum_encoder(f_q)
Expand Down
2 changes: 1 addition & 1 deletion networks/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def __init__(self, in_channels, activation, normalization):
padding=1,
),
norm,
activation,
act,
nn.Conv2d(
in_channels,
in_channels,
Expand Down
2 changes: 1 addition & 1 deletion simsiam/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def cosine_loss(p, z):
writer = SummaryWriter(args.logs_root)
model = SimSiam(args).to(device)
Path(args.check_point.split('/')[1]).mkdir(parents=True, exist_ok=True)
Path(args.logs_root.split('/')[1]).mkdir(parents=True, exist_ok=True)
Path(args.logs_root).mkdir(parents=True, exist_ok=True)

optimizer = torch.optim.SGD(model.parameters(), lr=args.lr,
momentum=args.momentum, weight_decay=args.wd)
Expand Down
8 changes: 4 additions & 4 deletions vae/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@


class VAELoss(nn.Module):
def __init__(self, args):
def __init__(self, recon=None):
super().__init__()
if args.recon == 'l1':
self.recon = nn.L1Loss()
elif args.recon == 'l2':
if recon == 'l2':
self.recon = nn.MSELoss()
else:
self.recon = nn.L1Loss()

def _KL_Loss(self, mu, logvar):
return torch.mean(-0.5 * torch.sum(1 + logvar - mu ** 2 - logvar.exp(), dim=1), dim=0)
Expand Down
11 changes: 6 additions & 5 deletions vae/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,25 +20,26 @@
parser.add_argument('--n_res_blocks', type=int, default=9, help='Number of ResNet Blocks for generators')
parser.add_argument('--lr', type=float, default=0.0002, help='Learning rate for generators')
parser.add_argument('--betas', type=tuple, default=(0.5, 0.999), help='Betas for Adam optimizer')
parser.add_argument('--epochs', type=int, default=200, help='Number of epochs')
parser.add_argument('--n_epochs', type=int, default=200, help='Number of epochs')
parser.add_argument('--batch_size', type=int, default=256, help='Batch size')
parser.add_argument('--sample_size', type=int, default=32, help='Size of sampled images')
parser.add_argument('--log_dir', type=str, default='vae/logs', help='Path to where log files will be saved')
parser.add_argument('--data_path', type=str, default='data/img_align_celeba', help='Path to where image data is located')
parser.add_argument('--device_ids', type=list, default=[0, 1], help='List of GPU devices')
parser.add_argument('--img_ext', type=str, default='.jpg', help='Image extentions')
parser.add_argument('--checkpoint_dir', type=str, default='vae/model_weights', help='Path to where model weights will be saved')
args = parser.parse_args()


if __name__ == '__main__':
writer = SummaryWriter(args.log_dir)
Path(args.check_point.split('/')[1]).mkdir(parents=True, exist_ok=True)
Path(args.log_dir.split('/')[1]).mkdir(parents=True, exist_ok=True)
Path(args.checkpoint_dir).mkdir(parents=True, exist_ok=True)
Path(args.log_dir).mkdir(parents=True, exist_ok=True)

loader = get_loaders(args)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = VAE(args).to(device)
model = torch.nn.DataParallel(VAE(args), device_ids=args.device_ids).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, betas=args.betas)
scheduler = torch.optim.lr_scheduler.MultiplicativeLR(optimizer, lambda epoch: 0.95)
fixed_z = torch.randn(args.sample_size, args.z_dim).to(device)
Expand Down Expand Up @@ -70,4 +71,4 @@
torch.save({
'model': model.state_dict(),
'optimizer': optimizer.state_dict(),
}, f"{opt.checkpoint_dir}/VAE.pth")
}, f"{args.checkpoint_dir}/VAE.pth")
23 changes: 14 additions & 9 deletions vae/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

class VAE(nn.Module):
def __init__(self, args):
super().__init__()
self.args = args
self.encoder = nn.Sequential(
ConvNormAct(args.img_channels, args.model_dim, 'down'),
Expand All @@ -16,16 +17,11 @@ def __init__(self, args):
nn.Flatten(),
nn.Linear((args.img_size // (2**4))**2 * args.model_dim * 8, args.z_dim * 2)
)
self.decoder = nn.Sequential(
self.projector = nn.Sequential(
nn.Linear(args.z_dim, (args.img_size // (2**4))**2 * args.model_dim * 8),
nn.BatchNorm1d((args.img_size // (2**4))**2 * args.model_dim * 8),
nn.ReLU(),
Reshape(
args.batch_size,
args.model_dim * 8,
args.img_size // (2**4),
args.img_size // (2**4)
),
nn.ReLU())
self.decoder = nn.Sequential(
ConvNormAct(args.model_dim * 8, args.model_dim * 4, 'up'),
ConvNormAct(args.model_dim * 4, args.model_dim * 2, 'up'),
ConvNormAct(args.model_dim * 2, args.model_dim, 'up'),
Expand All @@ -35,17 +31,26 @@ def __init__(self, args):

def reparameterize(self, mu, logvar):
batch_size = mu.size(0)
z = torch.randn(batch_size, self.args.z_dim) * torch.sqrt(torch.exp(logvar)) + mu
z = torch.randn(
batch_size,
self.args.z_dim,
device=mu.device) * torch.sqrt(torch.exp(logvar)) + mu
return z

def forward(self, x):
batch_size = x.size(0)
z = self.encoder(x)
z = z.view(-1, 2, self.args.z_dim)
mu, logvar = z[:, 0, :], z[:, 1, :]
z = self.reparameterize(mu, logvar)
z = self.projector(z).view(batch_size, self.args.model_dim * 8,
self.args.img_size // (2**4), self.args.img_size // (2**4))
return self.decoder(z), mu, logvar

def sample(self, z=None, num_samples=50):
if z is None:
z = torch.randn(num_samples, self.args.z_dim)
num_samples = z.size(0)
z = self.projector(z).view(num_samples, self.args.model_dim * 8,
self.args.img_size // (2**4), self.args.img_size // (2**4))
return self.decoder(z)

0 comments on commit 3481421

Please sign in to comment.