Stylegan implementation #19

Merged
merged 14 commits on Nov 7, 2024
Generator 3D operational
szmazurek committed Nov 2, 2024
commit 9cfb2c9b152b8f5cdf4c1471bab6482fe3a2d7e5
60 changes: 47 additions & 13 deletions gandlf_synth/models/architectures/stylegan.py
@@ -27,7 +27,6 @@ def __init__(self, in_features: int, out_features: int):
nn.init.zeros_(self.bias)

def forward(self, x: torch.Tensor) -> torch.Tensor:
print(f"WeightScaledLinear: x.shape={x.shape}")
return self.linear(x * self.scale) + self.bias


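Context for the hunk above: WeightScaledLinear implements the equalized-learning-rate trick, where weights are initialized from N(0, 1) and rescaled at runtime by the He constant sqrt(2 / fan_in). A minimal, self-contained sketch of the idea (the class name and layout here are illustrative, not this repo's API):

import torch
import torch.nn as nn

class EqualizedLinear(nn.Module):
    """Equalized-learning-rate linear layer (illustrative sketch)."""

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features, bias=False)
        self.bias = nn.Parameter(torch.zeros(out_features))
        self.scale = (2 / in_features) ** 0.5  # He constant
        nn.init.normal_(self.linear.weight)  # plain N(0, 1) init

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Scaling the input is equivalent to scaling the weights,
        # so every layer sees a comparable effective learning rate.
        return self.linear(x * self.scale) + self.bias

print(EqualizedLinear(512, 512)(torch.randn(4, 512)).shape)  # torch.Size([4, 512])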
@@ -87,7 +86,7 @@ def __init__(self, norm_layer: nn.Module, channels: int, w_dim: int):
scales and biases it using the style vector w.

Args:
norm_layer (nn.Module): Normalization layer to be used (either 2D or 3DInstanceNorm layer).
norm_layer (nn.Module): Normalization layer to be used (either 2D or 3D InstanceNorm layer).
channels (int): Number of channels in the input tensor.
w_dim (int): Dimensionality of the intermediate latent space w.
"""
@@ -101,24 +100,39 @@ def forward(self, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
x = self.instance_norm(x)
style_scale = self.style_scale(w).unsqueeze(2).unsqueeze(3)
style_bias = self.style_bias(w).unsqueeze(2).unsqueeze(3)
if x.ndim == 5:
style_scale = style_scale.permute(0, 4, 1, 2, 3)
style_bias = style_bias.permute(0, 4, 1, 2, 3)
return style_scale * x + style_bias


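A quick shape walkthrough of the new 5D branch above, assuming (as the module code later suggests) that 3D models pass a style vector w of shape (B, 1, w_dim). This is an editor's sketch, not repo code:

import torch

B, C = 2, 8
style = torch.randn(B, 1, C)             # style_scale(w) output for a 3D model
style = style.unsqueeze(2).unsqueeze(3)  # (B, 1, 1, 1, C)
style = style.permute(0, 4, 1, 2, 3)     # (B, C, 1, 1, 1): channels in place
x = torch.randn(B, C, 4, 4, 4)           # 5D feature map
print((style * x + style).shape)         # broadcasts to (B, C, 4, 4, 4)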
class LearnableNoiseInjector(nn.Module):
def __init__(self, channels: int):
def __init__(self, n_dimensions: int, channels: int):
"""
Learnable noise injector layer. Adds a noise tensor to the input tensor,
controlling the amount of noise added using a learnable weight parameter.

Args:
n_dimensions (int): Number of dimensions of the input tensor.
channels (int): Number of channels in the input tensor.
"""

super().__init__()
self.weight = nn.Parameter(torch.zeros(1, channels, 1, 1))
weight_size = (
(1, channels, 1, 1) if n_dimensions == 2 else (1, channels, 1, 1, 1)
)
self.n_dimensions = n_dimensions
self.weight = nn.Parameter(torch.zeros(weight_size))

def forward(self, x: torch.Tensor) -> torch.Tensor:
noise = torch.randn((x.shape[0], 1, x.shape[2], x.shape[3]), device=x.device)
if self.n_dimensions == 2:
noise = torch.randn(
(x.shape[0], 1, x.shape[2], x.shape[3]), device=x.device
)
else:
noise = torch.randn(
(x.shape[0], 1, x.shape[2], x.shape[3], x.shape[4]), device=x.device
)
return x + self.weight * noise


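For reference, standard StyleGAN noise injection draws one noise map per sample, broadcasts it across channels, and gates it with a per-channel weight initialized to zero — which is why the multiplicative form in the return line matters. A dimension-agnostic sketch (names are illustrative, not this repo's API):

import torch
import torch.nn as nn

class NoiseInjection(nn.Module):
    """Sketch of per-channel gated noise injection for 2D or 3D maps."""

    def __init__(self, n_dimensions: int, channels: int):
        super().__init__()
        # (1, C, 1, 1) for 2D inputs, (1, C, 1, 1, 1) for 3D inputs.
        self.weight = nn.Parameter(torch.zeros((1, channels) + (1,) * n_dimensions))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # One shared noise map per sample, broadcast over channels.
        noise = torch.randn((x.shape[0], 1) + tuple(x.shape[2:]), device=x.device)
        return x + self.weight * noise  # zero-initialized weight gates the noise

print(NoiseInjection(3, 8)(torch.randn(2, 8, 4, 4, 4)).shape)  # (2, 8, 4, 4, 4)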
@@ -154,7 +168,12 @@ def __init__(
nn.init.zeros_(self.bias)

def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.conv(x * self.scale) + self.bias.view(1, self.bias.shape[0], 1, 1)
bias_view_params = (
(1, self.bias.shape[0], 1, 1)
if x.ndim == 4
else (1, self.bias.shape[0], 1, 1, 1)
)
return self.conv(x * self.scale) + self.bias.view(bias_view_params)


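The explicit 4D/5D tuples above could also be derived from x.ndim directly; a small sketch of the equivalent dimension-agnostic reshape (editor's illustration only):

import torch

bias = torch.zeros(16)
for ndim in (4, 5):  # 2D and 3D feature maps respectively
    view_shape = (1, bias.shape[0]) + (1,) * (ndim - 2)
    print(ndim, bias.view(view_shape).shape)
# 4 torch.Size([1, 16, 1, 1])
# 5 torch.Size([1, 16, 1, 1, 1])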
class ConvBlock(nn.Module):
@@ -184,6 +203,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
class GeneratorBlock(nn.Module):
def __init__(
self,
n_dimensions: int,
conv_layer: nn.Module,
norm_layer: nn.Module,
in_channel: int,
@@ -195,6 +215,7 @@ def __init__(
convolutional layers with adaptive instance normalization and leaky ReLU.

Args:
n_dimensions (int): Number of dimensions of the input images.
conv_layer (nn.Module): Convolutional layer to be used.
norm_layer (nn.Module): Normalization layer to be used.
in_channel (int): Number of input channels in the first convolutional layer.
@@ -205,8 +226,8 @@ def __init__(
self.conv1 = WeightScaledConv(conv_layer, in_channel, out_channel)
self.conv2 = WeightScaledConv(conv_layer, out_channel, out_channel)
self.leaky = nn.LeakyReLU(0.2, inplace=True)
self.inject_noise1 = LearnableNoiseInjector(out_channel)
self.inject_noise2 = LearnableNoiseInjector(out_channel)
self.inject_noise1 = LearnableNoiseInjector(n_dimensions, out_channel)
self.inject_noise2 = LearnableNoiseInjector(n_dimensions, out_channel)
self.adain1 = AdaptiveInstanceNormalization(norm_layer, out_channel, w_dim)
self.adain2 = AdaptiveInstanceNormalization(norm_layer, out_channel, w_dim)

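The block's forward pass is only hinted at by the next hunk header; in the reference StyleGAN layout it presumably chains conv → noise → LeakyReLU → AdaIN twice, roughly (a hedged sketch, not the repo's confirmed code):

# def forward(self, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
#     x = self.adain1(self.leaky(self.inject_noise1(self.conv1(x))), w)
#     x = self.adain2(self.leaky(self.inject_noise2(self.conv2(x))), w)
#     return x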
@@ -219,6 +240,7 @@ def forward(self, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
class StyleGanGenerator(nn.Module):
def __init__(
self,
n_dimensions: int,
conv_layer: nn.Module,
norm_layer: nn.Module,
z_dim: int,
@@ -231,6 +253,7 @@ def __init__(
StyleGAN generator network. Generates images from the latent vector z.

Args:
n_dimensions (int): Number of dimensions of the input images.
conv_layer (nn.Module): Convolutional layer to be used.
norm_layer (nn.Module): Normalization layer to be used.
z_dim (int): Dimensionality of the latent vector z.
@@ -241,16 +264,19 @@ def __init__(
channels in consecutive convolutional layers.
"""
super().__init__()
self.starting_cte = nn.Parameter(torch.ones(1, in_channels, 4, 4))
cte_shape = (
(1, in_channels, 4, 4) if n_dimensions == 2 else (1, in_channels, 4, 4, 4)
)
self.starting_cte = nn.Parameter(torch.ones(cte_shape))
self.map = MappingNetwork(z_dim, w_dim)
self.initial_adain1 = AdaptiveInstanceNormalization(
norm_layer, in_channels, w_dim
)
self.initial_adain2 = AdaptiveInstanceNormalization(
norm_layer, in_channels, w_dim
)
self.initial_noise1 = LearnableNoiseInjector(in_channels)
self.initial_noise2 = LearnableNoiseInjector(in_channels)
self.initial_noise1 = LearnableNoiseInjector(n_dimensions, in_channels)
self.initial_noise2 = LearnableNoiseInjector(n_dimensions, in_channels)
self.initial_conv = conv_layer(
in_channels, in_channels, kernel_size=3, stride=1, padding=1
)
@@ -268,7 +294,9 @@ def __init__(
conv_in_c = int(in_channels * progressive_layers_scaling_factors[i])
conv_out_c = int(in_channels * progressive_layers_scaling_factors[i + 1])
self.prog_blocks.append(
GeneratorBlock(conv_layer, norm_layer, conv_in_c, conv_out_c, w_dim)
GeneratorBlock(
n_dimensions, conv_layer, norm_layer, conv_in_c, conv_out_c, w_dim
)
)
self.rgb_layers.append(
WeightScaledConv(
@@ -295,12 +323,17 @@ def fade_in(
return torch.tanh(alpha * generated + (1 - alpha) * upscaled)

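# Editor's note on fade_in (a behavioural sketch, not new repo code):
# during progressive growing, alpha ramps from 0 to 1 while a new
# resolution block is introduced, so at alpha=0 only the upscaled
# previous-stage output passes through and at alpha=1 only the new
# block's output does, with tanh bounding the result to [-1, 1]:
#
#     up, gen = torch.randn(1, 3, 8, 8), torch.randn(1, 3, 8, 8)
#     assert torch.allclose(fade_in(0.0, up, gen), torch.tanh(up))
#     assert torch.allclose(fade_in(1.0, up, gen), torch.tanh(gen))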
def forward(self, noise: torch.Tensor, alpha: float, steps: int) -> torch.Tensor:
w = self.map(noise)
x = self.initial_adain1(self.initial_noise1(self.starting_cte), w)
x = self.initial_conv(x)
out = self.initial_adain2(self.leaky(self.initial_noise2(x)), w)

if steps == 0:
    return self.initial_rgb(x)

for step in range(steps):
@@ -418,6 +451,7 @@ class StyleGan(ModelBase):
def __init__(self, model_config: Type[AbstractModelConfig]):
ModelBase.__init__(self, model_config)
self.generator = StyleGanGenerator(
self.n_dimensions,
self.Conv,
self.InstanceNorm,
model_config.architecture["latent_vector_size"],
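Taken together, the generator starts from a learned 4×4 (or 4×4×4) constant and doubles the spatial size once per progressive step. A short sketch of the implied resolution schedule, plus a hypothetical 3D instantiation (the constructor's trailing arguments are cut off above, so the call is indicative only):

import torch
import torch.nn as nn

for steps in range(5):
    print(steps, 4 * 2 ** steps)  # 0 -> 4, 1 -> 8, 2 -> 16, 3 -> 32, 4 -> 64

# Hypothetical 3D setup, mirroring the arguments visible in the diff:
# generator = StyleGanGenerator(
#     n_dimensions=3,
#     conv_layer=nn.Conv3d,
#     norm_layer=nn.InstanceNorm3d,
#     z_dim=512,
#     w_dim=512,
#     in_channels=512,
#     ...,  # remaining arguments not shown in this hunk
# )
# noise = torch.randn(2, 1, 512)                  # (B, 1, z_dim) for the 3D path
# volumes = generator(noise, alpha=0.5, steps=2)  # 16x16x16 volumes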
21 changes: 10 additions & 11 deletions gandlf_synth/models/modules/stylegan_module.py
@@ -165,6 +165,15 @@ def _gradient_penalty(

return gradient_penalty

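The hunk header above references _gradient_penalty, whose body is collapsed. For orientation, the textbook WGAN-GP penalty looks like the following; this is a generic sketch that may differ from the hidden implementation (the critic is assumed to take (x, alpha, step), matching the discriminator calls below):

import torch

def wgan_gradient_penalty(critic, real, fake, alpha, step, device):
    """Generic WGAN-GP penalty sketch (not necessarily this repo's code)."""
    batch_size = real.shape[0]
    # One interpolation factor per sample, broadcast over all other dims.
    eps = torch.rand((batch_size,) + (1,) * (real.ndim - 1), device=device)
    mixed = (eps * real + (1 - eps) * fake).requires_grad_(True)
    scores = critic(mixed, alpha, step)
    grads = torch.autograd.grad(
        outputs=scores,
        inputs=mixed,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,
        retain_graph=True,
    )[0].view(batch_size, -1)
    # Penalize deviation of the per-sample gradient norm from 1.
    return ((grads.norm(2, dim=1) - 1) ** 2).mean()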
def _generate_latent_vector(self, batch_size: int) -> torch.Tensor:
latent_vector = torch.randn(
(batch_size, self.model_config.architecture["latent_vector_size"]),
device=self.device,
)
if self.model_config.n_dimensions == 3:
latent_vector = latent_vector.unsqueeze(1)
return latent_vector

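# Editor's note: the helper above yields (B, latent_vector_size) for 2D
# models and (B, 1, latent_vector_size) for 3D models; the extra singleton
# dimension is what lets the mapping network emit w of shape (B, 1, w_dim),
# matching the 5D AdaIN branch in the generator. For example:
#
#     torch.randn(4, 512)               ->  shape (4, 512)     # 2D
#     torch.randn(4, 512).unsqueeze(1)  ->  shape (4, 1, 512)  # 3D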
def training_step(self, batch: object, batch_idx: int) -> torch.Tensor:
real_images: torch.Tensor = batch
real_images = self._resize_to_current_step_demands(real_images)
@@ -173,19 +182,9 @@ def training_step(self, batch: object, batch_idx: int) -> torch.Tensor:
gradient_clip_algorithm = self.model_config.gradient_clip_algorithm

batch_size = real_images.shape[0]
latent_vector = (
generate_latent_vector(
batch_size,
self.model_config.architecture["latent_vector_size"],
self.model_config.n_dimensions,
self.device,
)
.type_as(real_images)
.squeeze(2, 3)
)
latent_vector = self._generate_latent_vector(batch_size)
loss_disc, loss_gen = self.losses["disc_loss"], self.losses["gen_loss"]
optimizer_disc, optimizer_gen = self.optimizers()

fake_images = self.model(latent_vector, self.alpha, self.current_step)
disc_preds_on_real = self.model.discriminator(
real_images, self.alpha, self.current_step
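The rest of the training step is cut off in this view. Based on the pieces visible above (WGAN-style losses, a gradient-penalty helper, and separate optimizers), the continuation presumably follows the usual recipe, sketched here for orientation only (lambda_gp is a hypothetical coefficient name, not confirmed by the diff):

# disc_preds_on_fake = self.model.discriminator(
#     fake_images.detach(), self.alpha, self.current_step
# )
# gp = self._gradient_penalty(real_images, fake_images)
# d_loss = loss_disc(disc_preds_on_real, disc_preds_on_fake) + lambda_gp * gp
# ...discriminator optimizer step, then the generator update via loss_gen...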