Merge pull request labmlai#168 from jakehsiao/patch-3
Add activation for time embedding and dropout for Residual block in DDPM UNet
vpj authored Feb 17, 2023
2 parents bdaa667 + 8053f3f commit 4db39b5
Showing 1 changed file with 6 additions and 3 deletions.
labml_nn/diffusion/ddpm/unet.py: 6 additions & 3 deletions
@@ -26,6 +26,7 @@

 import torch
 from torch import nn
+import torch.nn.functional as F

 from labml_helpers.module import Module

@@ -91,12 +92,13 @@ class ResidualBlock(Module):
     Each resolution is processed with two residual blocks.
     """

-    def __init__(self, in_channels: int, out_channels: int, time_channels: int, n_groups: int = 32):
+    def __init__(self, in_channels: int, out_channels: int, time_channels: int, n_groups: int = 32, dropout_rate: float = 0.1):
         """
         * `in_channels` is the number of input channels
         * `out_channels` is the number of output channels
         * `time_channels` is the number of channels in the time step ($t$) embeddings
         * `n_groups` is the number of groups for [group normalization](../../normalization/group_norm/index.html)
+        * `dropout_rate` is the dropout rate
         """
         super().__init__()
         # Group normalization and the first convolution layer
@@ -118,6 +120,7 @@ def __init__(self, in_channels: int, out_channels: int, time_channels: int, n_groups: int = 32):

         # Linear layer for time embeddings
         self.time_emb = nn.Linear(time_channels, out_channels)
+        self.time_act = Swish()

     def forward(self, x: torch.Tensor, t: torch.Tensor):
         """
@@ -127,9 +130,9 @@ def forward(self, x: torch.Tensor, t: torch.Tensor):
         # First convolution layer
         h = self.conv1(self.act1(self.norm1(x)))
         # Add time embeddings
-        h += self.time_emb(t)[:, :, None, None]
+        h += self.time_emb(self.time_act(t))[:, :, None, None]
         # Second convolution layer
-        h = self.conv2(self.act2(self.norm2(h)))
+        h = self.conv2(F.dropout(self.act2(self.norm2(h)), self.dropout_rate))

         # Add the shortcut connection and return
         return h + self.shortcut(x)
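
For context, here is a minimal, self-contained sketch of how ResidualBlock reads once this commit is applied. Everything outside the hunks shown above (the two norm/activation/convolution stages, the shortcut projection, and the Swish definition) is reconstructed from the surrounding unet.py and should be treated as an assumption; nn.Module also stands in for the repository's labml_helpers Module. Two corrections are marked in comments: the diff as shown never assigns self.dropout_rate in __init__, and F.dropout defaults to training=True, so the sketch stores the rate and passes training=self.training to keep dropout off at eval time.

import torch
from torch import nn
import torch.nn.functional as F


class Swish(nn.Module):
    # x * sigmoid(x); unet.py defines this activation alongside ResidualBlock
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * torch.sigmoid(x)


class ResidualBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int,
                 time_channels: int, n_groups: int = 32,
                 dropout_rate: float = 0.1):
        super().__init__()
        # Group normalization and the first convolution layer (reconstructed)
        self.norm1 = nn.GroupNorm(n_groups, in_channels)
        self.act1 = Swish()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        # Group normalization and the second convolution layer (reconstructed)
        self.norm2 = nn.GroupNorm(n_groups, out_channels)
        self.act2 = Swish()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        # 1x1 shortcut projection when the channel counts differ (reconstructed)
        if in_channels != out_channels:
            self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        else:
            self.shortcut = nn.Identity()
        # Linear layer for time embeddings, now preceded by an activation
        self.time_emb = nn.Linear(time_channels, out_channels)
        self.time_act = Swish()
        # Correction (assumption): the diff never stores dropout_rate, so
        # self.dropout_rate in forward() would raise AttributeError
        self.dropout_rate = dropout_rate

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        # First convolution layer
        h = self.conv1(self.act1(self.norm1(x)))
        # Add time embeddings, activated before the projection (first change)
        h += self.time_emb(self.time_act(t))[:, :, None, None]
        # Dropout between the second norm/activation and convolution (second
        # change). Correction (assumption): pass training=self.training so
        # dropout is disabled in eval mode.
        h = self.conv2(F.dropout(self.act2(self.norm2(h)),
                                 self.dropout_rate, training=self.training))
        # Add the shortcut connection and return
        return h + self.shortcut(x)


# Quick shape check
block = ResidualBlock(in_channels=64, out_channels=128, time_channels=256)
x = torch.randn(2, 64, 32, 32)
t = torch.randn(2, 256)
assert block(x, t).shape == (2, 128, 32, 32)

Applying the activation to the time embedding before its linear projection, and dropout just before the second convolution, mirrors where the original DDPM reference implementation places them in its residual blocks.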
