Skip to content

Commit

Permalink
added data parallel opt
Browse files Browse the repository at this point in the history
  • Loading branch information
Andrewzh112 committed Dec 24, 2020
1 parent f93b8c4 commit 8759cc3
Showing 1 changed file with 9 additions and 7 deletions.
16 changes: 9 additions & 7 deletions vae/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
parser.add_argument('--betas', type=tuple, default=(0.5, 0.999), help='Betas for Adam optimizer')
parser.add_argument('--n_epochs', type=int, default=200, help='Number of epochs')
parser.add_argument('--batch_size', type=int, default=512, help='Batch size')
parser.add_argument('--data_parallel', action="store_true", default=False, help='train with data parallel')

# logging
parser.add_argument('--log_dir', type=str, default='vae/logs', help='Path to where log files will be saved')
Expand Down Expand Up @@ -67,14 +68,12 @@

# initialize model, instantiate opt & scheduler & loss fn
if args.plus:
# model = torch.nn.DataParallel(
# VAE_Plus(args.z_dim, args.model_dim, args.img_size, args.img_channels),
# device_ids=args.device_ids).to(device)
model = VAE_Plus(args.z_dim, args.model_dim, args.img_size, args.img_channels).to(device)
else:
model = torch.nn.DataParallel(
VAE(args.z_dim, args.model_dim, args.img_size, args.img_channels, args.n_res_blocks),
device_ids=args.device_ids).to(device)
model = VAE(args.z_dim, args.model_dim, args.img_size, args.img_channels, args.n_res_blocks).to(device)

if args.data_parallel:
model = torch.nn.DataParallel(model, device_ids=args.device_ids).to(device)

model.apply(initialize_modules)
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, betas=args.betas)
Expand Down Expand Up @@ -137,7 +136,10 @@
# decode fixed z latent vectors
model.eval()
with torch.no_grad():
sampled_images = model.module.sample(fixed_z)
if args.data_parallel:
sampled_images = model.module.sample(fixed_z)
else:
sampled_images = model.sample(fixed_z)
sampled_images = (sampled_images + 1) / 2

# log images and losses & save model parameters
Expand Down

0 comments on commit 8759cc3

Please sign in to comment.