Skip to content

Commit

Permalink
Change the last conv layer of the generators to have no activation
Browse files Browse the repository at this point in the history
  • Loading branch information
Andrewzh112 committed Dec 25, 2020
1 parent 258df15 commit 34bd250
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions vae/model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from torch import nn
import torch
from networks.layers import ConvNormAct, ResBlock, SA_Conv2d
from networks.layers import ConvNormAct, ResBlock, SA_Conv2d, SN_ConvTranspose2d


class VAE(nn.Module):
Expand Down Expand Up @@ -29,8 +29,8 @@ def __init__(self, z_dim, model_dim, img_size, img_channels, n_res_blocks=0):
ConvNormAct(model_dim * 8, model_dim * 4, 'basic', 'up', activation='lrelu'),
ConvNormAct(model_dim * 4, model_dim * 2, 'basic', 'up', activation='lrelu'),
ConvNormAct(model_dim * 2, model_dim, 'basic', 'up', activation='lrelu'),
ConvNormAct(model_dim, img_channels, 'basic', 'up', activation='lrelu'),
nn.Sigmoid()
nn.ConvTranspose2d(model_dim, img_channels, kernel_size=4, stride=2, padding=1),
nn.Tanh()
)

def reparameterize(self, mu, logvar):
Expand Down Expand Up @@ -79,7 +79,7 @@ def __init__(self, z_dim, model_dim, img_size, img_channels):
ConvNormAct(model_dim * 4, model_dim * 2, 'sn', 'up', activation='lrelu'),
ConvNormAct(model_dim * 2, model_dim, 'sn', 'up', activation='lrelu'),
SA_Conv2d(model_dim),
ConvNormAct(model_dim, img_channels, 'sn', 'up', activation='lrelu'),
SN_ConvTranspose2d(in_channels=model_dim, out_channels=img_channels, kernel_size=4, stride=2, padding=1),
nn.Tanh()
)

Expand Down

0 comments on commit 34bd250

Please sign in to comment.