Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Stylegan implementation #19

Merged
merged 14 commits into from
Nov 7, 2024
Prev Previous commit
Next Next commit
Porting stylegan to framework standards
  • Loading branch information
szmazurek committed Nov 2, 2024
commit ec226ed6f94282e1d91c25b8cdb48b4ab4e39458
184 changes: 142 additions & 42 deletions gandlf_synth/models/architectures/stylegan.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List
from typing import List, Type

from gandlf_synth.models.architectures.base_model import ModelBase
from gandlf_synth.models.configs.config_abc import AbstractModelConfig


class WeightScaledLinear(nn.Module):
Expand Down Expand Up @@ -76,14 +79,20 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:


class AdaptiveInstanceNormalization(nn.Module):
def __init__(self, channels: int, w_dim: int):
def __init__(self, norm_layer: nn.Module, channels: int, w_dim: int):
"""
Adaptive instance normalization layer. Applies instance normalization to the input
tensor and then scales and biases it using the style vector w.
Adaptive instance normalization layer.
Applies instance normalization to the input tensor and then
scales and biases it using the style vector w.

Args:
norm_layer (nn.Module): Normalization layer to be used (either a 2D or a 3D InstanceNorm layer).
channels (int): Number of channels in the input tensor.
w_dim (int): Dimensionality of the intermediate latent space w.
"""

super().__init__()
self.instance_norm = nn.InstanceNorm2d(channels)
self.instance_norm = norm_layer(channels)
self.style_scale = WeightScaledLinear(w_dim, channels)
self.style_bias = WeightScaledLinear(w_dim, channels)

Expand Down Expand Up @@ -112,9 +121,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return x + self.weight + noise


class WeightScaledConv2d(nn.Module):
class WeightScaledConv(nn.Module):
def __init__(
self,
conv_layer: nn.Module,
in_channels: int,
out_channels: int,
kernel_size: int = 3,
Expand All @@ -126,14 +136,15 @@ def __init__(
number of input channels before the convolution operation.

Args:
conv_layer (nn.Module): Convolutional layer to be used.
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
kernel_size (int): Size of the convolutional kernel.
stride (int): Stride of the convolution operation.
padding (int): Padding of the input tensor.
"""
super().__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
self.conv = conv_layer(in_channels, out_channels, kernel_size, stride, padding)
self.scale = (2 / (in_channels * (kernel_size**2))) ** 0.5
self.bias = self.conv.bias
self.conv.bias = None
Expand All @@ -146,19 +157,21 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:


class ConvBlock(nn.Module):
def __init__(self, in_channels: int, out_channels: int):
def __init__(self, conv_layer: nn.Module, in_channels: int, out_channels: int):
"""
Convolutional block for the discriminator network. Consists of two weight-scaled
convolutional layers with leaky ReLU activation functions.
Convolutional block for the discriminator network. Consists of
two weight-scaled convolutional layers with leaky ReLU activation
functions.

Args:
conv_layer (nn.Module): Convolutional layer to be used.
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
"""

super().__init__()
self.conv1 = WeightScaledConv2d(in_channels, out_channels)
self.conv2 = WeightScaledConv2d(out_channels, out_channels)
self.conv1 = WeightScaledConv(conv_layer, in_channels, out_channels)
self.conv2 = WeightScaledConv(conv_layer, out_channels, out_channels)
self.leaky = nn.LeakyReLU(0.2)

def forward(self, x: torch.Tensor) -> torch.Tensor:
Expand All @@ -168,25 +181,33 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:


class GeneratorBlock(nn.Module):
def __init__(self, in_channel: int, out_channel: int, w_dim: int):
def __init__(
self,
conv_layer: nn.Module,
norm_layer: nn.Module,
in_channel: int,
out_channel: int,
w_dim: int,
):
"""
Generator block for the generator network. Consists of two weight-scaled
convolutional layers with adaptive instance normalization and leaky ReLU.

Args:
in_channel (int): Number of input channels.
out_channel (int): Number of output channels.
conv_layer (nn.Module): Convolutional layer to be used.
norm_layer (nn.Module): Normalization layer to be used.
in_channel (int): Number of input channels in the first convolutional layer.
out_channel (int): Number of output channels of both convolutional layers in the block.
w_dim (int): Dimensionality of the intermediate latent space w.

"""
super().__init__()
self.conv1 = WeightScaledConv2d(in_channel, out_channel)
self.conv2 = WeightScaledConv2d(out_channel, out_channel)
self.conv1 = WeightScaledConv(conv_layer, in_channel, out_channel)
self.conv2 = WeightScaledConv(conv_layer, out_channel, out_channel)
self.leaky = nn.LeakyReLU(0.2, inplace=True)
self.inject_noise1 = LearnableNoiseInjector(out_channel)
self.inject_noise2 = LearnableNoiseInjector(out_channel)
self.adain1 = AdaptiveInstanceNormalization(out_channel, w_dim)
self.adain2 = AdaptiveInstanceNormalization(out_channel, w_dim)
self.adain1 = AdaptiveInstanceNormalization(norm_layer, out_channel, w_dim)
self.adain2 = AdaptiveInstanceNormalization(norm_layer, out_channel, w_dim)

def forward(self, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
x = self.adain1(self.leaky(self.inject_noise1(self.conv1(x))), w)
Expand All @@ -197,6 +218,8 @@ def forward(self, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
class StyleGanGenerator(nn.Module):
def __init__(
self,
conv_layer: nn.Module,
norm_layer: nn.Module,
z_dim: int,
w_dim: int,
in_channels: int,
Expand All @@ -207,27 +230,33 @@ def __init__(
StyleGAN generator network. Generates images from the latent vector z.

Args:
conv_layer (nn.Module): Convolutional layer to be used.
norm_layer (nn.Module): Normalization layer to be used.
z_dim (int): Dimensionality of the latent vector z.
w_dim (int): Dimensionality of the intermediate latent space w.
in_channels (int): Number of input channels.
in_channels (int): Number of input channels in the first convolutional layer.
img_channels (int): Number of output channels.
progressive_layers_scaling_factors (List[float]): List of scaling factors for
channels in consecutive convolutional layers.
"""
super().__init__()
self.starting_cte = nn.Parameter(torch.ones(1, in_channels, 4, 4))
self.map = MappingNetwork(z_dim, w_dim)
self.initial_adain1 = AdaptiveInstanceNormalization(in_channels, w_dim)
self.initial_adain2 = AdaptiveInstanceNormalization(in_channels, w_dim)
self.initial_adain1 = AdaptiveInstanceNormalization(
norm_layer, in_channels, w_dim
)
self.initial_adain2 = AdaptiveInstanceNormalization(
norm_layer, in_channels, w_dim
)
self.initial_noise1 = LearnableNoiseInjector(in_channels)
self.initial_noise2 = LearnableNoiseInjector(in_channels)
self.initial_conv = nn.Conv2d(
self.initial_conv = conv_layer(
in_channels, in_channels, kernel_size=3, stride=1, padding=1
)
self.leaky = nn.LeakyReLU(0.2, inplace=True)

self.initial_rgb = WeightScaledConv2d(
in_channels, img_channels, kernel_size=1, stride=1, padding=0
self.initial_rgb = WeightScaledConv(
conv_layer, in_channels, img_channels, kernel_size=1, stride=1, padding=0
)
self.prog_blocks, self.rgb_layers = (
nn.ModuleList([]),
Expand All @@ -237,9 +266,11 @@ def __init__(
for i in range(len(progressive_layers_scaling_factors) - 1):
conv_in_c = int(in_channels * progressive_layers_scaling_factors[i])
conv_out_c = int(in_channels * progressive_layers_scaling_factors[i + 1])
self.prog_blocks.append(GeneratorBlock(conv_in_c, conv_out_c, w_dim))
self.prog_blocks.append(
GeneratorBlock(conv_layer, norm_layer, conv_in_c, conv_out_c, w_dim)
)
self.rgb_layers.append(
WeightScaledConv2d(
WeightScaledConv(
conv_out_c, img_channels, kernel_size=1, stride=1, padding=0
)
)
Expand Down Expand Up @@ -279,6 +310,9 @@ def forward(self, noise: torch.Tensor, alpha: float, steps: int) -> torch.Tensor
class StyleGanDiscriminator(nn.Module):
def __init__(
self,
conv_layer: nn.Module,
norm_layer: nn.Module,
pool_layer: nn.Module,
in_channels: int,
img_channels: int,
progressive_layers_scaling_factors: List[float],
Expand All @@ -287,6 +321,9 @@ def __init__(
StyleGAN discriminator network.

Args:
conv_layer (nn.Module): Convolutional layer to be used.
norm_layer (nn.Module): Normalization layer to be used.
pool_layer (nn.Module): Pooling layer to be used.
in_channels (int): Number of input channels.
img_channels (int): Number of channels in the input image.
progressive_layers_scaling_factors (List[float]): List of scaling factors for
Expand All @@ -298,28 +335,37 @@ def __init__(
for i in range(len(progressive_layers_scaling_factors) - 1, 0, -1):
conv_in = int(in_channels * progressive_layers_scaling_factors[i])
conv_out = int(in_channels * progressive_layers_scaling_factors[i - 1])
self.prog_blocks.append(ConvBlock(conv_in, conv_out))
self.prog_blocks.append(ConvBlock(conv_layer, conv_in, conv_out))
self.rgb_layers.append(
WeightScaledConv2d(
img_channels, conv_in, kernel_size=1, stride=1, padding=0
WeightScaledConv(
conv_layer,
img_channels,
conv_in,
kernel_size=1,
stride=1,
padding=0,
)
)

self.initial_conv = WeightScaledConv2d(
img_channels, in_channels, kernel_size=1, stride=1, padding=0
self.initial_conv = WeightScaledConv(
conv_layer, img_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.rgb_layers.append(self.initial_rgb)
self.avg_pool = nn.AvgPool2d(kernel_size=2, stride=2)
self.rgb_layers.append(self.initial_conv)
self.avg_pool = pool_layer(kernel_size=2, stride=2)

self.final_block = nn.Sequential(
# +1 to in_channels because we concatenate from MiniBatch std
WeightScaledConv2d(in_channels + 1, in_channels, kernel_size=3, padding=1),
WeightScaledConv(
conv_layer, in_channels + 1, in_channels, kernel_size=3, padding=1
),
nn.LeakyReLU(0.2),
WeightScaledConv2d(
in_channels, in_channels, kernel_size=4, padding=0, stride=1
WeightScaledConv(
conv_layer, in_channels, in_channels, kernel_size=4, padding=0, stride=1
),
nn.LeakyReLU(0.2),
WeightScaledConv2d(in_channels, 1, kernel_size=1, padding=0, stride=1),
WeightScaledConv(
conv_layer, in_channels, 1, kernel_size=1, padding=0, stride=1
),
)

def fade_in(self, alpha: float, downscaled: torch.Tensor, out: torch.Tensor):
Expand All @@ -334,7 +380,7 @@ def fade_in(self, alpha: float, downscaled: torch.Tensor, out: torch.Tensor):

return alpha * out + (1 - alpha) * downscaled

def minibatch_std(self, x: torch.Tensor) -> torch.Tensor:
def calculate_minibatch_std(self, x: torch.Tensor) -> torch.Tensor:
"""
Computes the minibatch standard deviation of the input tensor and concatenates it
to the input tensor.
Expand All @@ -349,7 +395,7 @@ def forward(self, x: torch.Tensor, alpha: float, steps: int) -> torch.Tensor:
out = self.leaky(self.rgb_layers[cur_step](x))

if steps == 0:
out = self.minibatch_std(out)
out = self.calculate_minibatch_std(out)
return self.final_block(out).view(out.shape[0], -1)

downscaled = self.leaky(self.rgb_layers[cur_step + 1](self.avg_pool(x)))
Expand All @@ -360,5 +406,59 @@ def forward(self, x: torch.Tensor, alpha: float, steps: int) -> torch.Tensor:
out = self.prog_blocks[step](out)
out = self.avg_pool(out)

out = self.minibatch_std(out)
out = self.calculate_minibatch_std(out)
return self.final_block(out).view(out.shape[0], -1)


class StyleGan(ModelBase):
    """
    StyleGAN model bundling a generator and a discriminator network.

    Both sub-networks are built from the layer types exposed on the instance
    (``self.Conv``, ``self.Norm``, ``self.AvgPool``) and from the channel and
    latent-space settings read from ``model_config``.

    NOTE(review): ``self.Conv``, ``self.Norm``, ``self.AvgPool`` and
    ``self.n_channels`` are assumed to be set by ``ModelBase.__init__`` —
    confirm against the base class.
    """

    def __init__(self, model_config: Type[AbstractModelConfig]):
        """
        Build the generator and the discriminator from the model config.

        Args:
            model_config (Type[AbstractModelConfig]): Configuration providing
                ``z_dim``, ``w_dim``, ``in_channels`` and
                ``progressive_layers_scaling_factors``.
        """
        ModelBase.__init__(self, model_config)
        # Keyword arguments keep the binding explicit. The generator also
        # needs the number of image channels, which mirrors the value handed
        # to the discriminator below.
        self.generator = StyleGanGenerator(
            conv_layer=self.Conv,
            norm_layer=self.Norm,
            z_dim=model_config.z_dim,
            w_dim=model_config.w_dim,
            in_channels=model_config.in_channels,
            img_channels=self.n_channels,
            progressive_layers_scaling_factors=(
                model_config.progressive_layers_scaling_factors
            ),
        )
        self.discriminator = StyleGanDiscriminator(
            conv_layer=self.Conv,
            norm_layer=self.Norm,
            pool_layer=self.AvgPool,
            in_channels=model_config.in_channels,
            img_channels=self.n_channels,
            progressive_layers_scaling_factors=(
                model_config.progressive_layers_scaling_factors
            ),
        )

    def generator_forward(self, noise: torch.Tensor, alpha: float, steps: int):
        """
        Forward pass of the generator network.

        Args:
            noise (torch.Tensor): Latent vector z.
            alpha (float): Alpha parameter for fading in the images.
            steps (int): Number of steps in the progressive growing of the GAN.

        Returns:
            The output of ``self.generator`` for the given arguments.
        """
        return self.generator(noise, alpha, steps)

    def discriminator_forward(self, x: torch.Tensor, alpha: float, steps: int):
        """
        Forward pass of the discriminator network.

        Args:
            x (torch.Tensor): Input image tensor.
            alpha (float): Alpha parameter for fading in the images.
            steps (int): Number of steps in the progressive growing of the GAN.

        Returns:
            The output of ``self.discriminator`` for the given arguments.
        """
        return self.discriminator(x, alpha, steps)

    def forward(self, noise: torch.Tensor, alpha: float, steps: int):
        """
        Forward pass of the StyleGAN network, defined as the forward pass of
        the generator network.

        Args:
            noise (torch.Tensor): Latent vector z.
            alpha (float): Alpha parameter for fading in the images.
            steps (int): Number of steps in the progressive growing of the GAN.

        Returns:
            The output of ``generator_forward`` for the given arguments.
        """
        return self.generator_forward(noise, alpha, steps)