diff --git a/vae/model.py b/vae/model.py
index d5e6f2c..6b7915e 100644
--- a/vae/model.py
+++ b/vae/model.py
@@ -1,6 +1,6 @@
 from torch import nn
 import torch
-from networks.layers import ConvNormAct, ResBlock, SA_Conv2d
+from networks.layers import ConvNormAct, ResBlock, SA_Conv2d, SN_ConvTranspose2d
 
 
 class VAE(nn.Module):
@@ -29,8 +29,8 @@ def __init__(self, z_dim, model_dim, img_size, img_channels, n_res_blocks=0):
             ConvNormAct(model_dim * 8, model_dim * 4, 'basic', 'up', activation='lrelu'),
             ConvNormAct(model_dim * 4, model_dim * 2, 'basic', 'up', activation='lrelu'),
             ConvNormAct(model_dim * 2, model_dim, 'basic', 'up', activation='lrelu'),
-            ConvNormAct(model_dim, img_channels, 'basic', 'up', activation='lrelu'),
-            nn.Sigmoid()
+            nn.ConvTranspose2d(model_dim, img_channels, kernel_size=4, stride=2, padding=1),
+            nn.Tanh()
         )
 
     def reparameterize(self, mu, logvar):
@@ -79,7 +79,7 @@ def __init__(self, z_dim, model_dim, img_size, img_channels):
             ConvNormAct(model_dim * 4, model_dim * 2, 'sn', 'up', activation='lrelu'),
             ConvNormAct(model_dim * 2, model_dim, 'sn', 'up', activation='lrelu'),
             SA_Conv2d(model_dim),
-            ConvNormAct(model_dim, img_channels, 'sn', 'up', activation='lrelu'),
+            SN_ConvTranspose2d(in_channels=model_dim, out_channels=img_channels, kernel_size=4, stride=2, padding=1),
             nn.Tanh()
         )
 
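
Note that swapping nn.Sigmoid() for nn.Tanh() on the decoder output assumes the training images are normalized to [-1, 1] rather than [0, 1]. The new SN_ConvTranspose2d layer is imported from networks.layers but its definition is not part of this diff; the sketch below is only an assumption of what it looks like, namely an nn.ConvTranspose2d wrapped in PyTorch spectral normalization to mirror the 'sn' variant of ConvNormAct.

# Hypothetical sketch, not the actual networks.layers implementation:
# assumes SN_ConvTranspose2d is a spectral-normalized nn.ConvTranspose2d.
import torch
from torch import nn
from torch.nn.utils import spectral_norm


class SN_ConvTranspose2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super().__init__()
        # spectral_norm constrains the layer's largest singular value,
        # which helps stabilize GAN-style training.
        self.conv = spectral_norm(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size,
                               stride=stride, padding=padding)
        )

    def forward(self, x):
        return self.conv(x)


# Usage as in the diff: maps model_dim feature maps to img_channels while
# doubling spatial resolution (kernel_size=4, stride=2, padding=1).
if __name__ == "__main__":
    layer = SN_ConvTranspose2d(in_channels=64, out_channels=3,
                               kernel_size=4, stride=2, padding=1)
    x = torch.randn(1, 64, 32, 32)
    print(layer(x).shape)  # torch.Size([1, 3, 64, 64])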