185 changes: 177 additions & 8 deletions src/diffusers/models/resnet.py
@@ -162,7 +162,7 @@ def forward(self, x):

# RESNETS

- # unet_glide.py & unet_ldm.py
+ # unet_glide.py
class ResBlock(TimestepBlock):
"""
A residual block that can optionally change the number of channels.
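Review note: `ResBlock` subclasses `TimestepBlock`, which is not shown in this diff. In the guided-diffusion lineage this file comes from, it is just an abstract marker for modules whose `forward` takes the timestep embedding as a second argument; a sketch, assuming the upstream definition:

```python
from abc import abstractmethod
import torch.nn as nn

class TimestepBlock(nn.Module):
    """Marker base class: forward() takes timestep embeddings as a second argument."""

    @abstractmethod
    def forward(self, x, emb):
        """Apply the module to `x` conditioned on the `emb` timestep embeddings."""
```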
@@ -188,6 +188,7 @@ def __init__(
use_checkpoint=False,
up=False,
down=False,
overwrite=False, # TODO(Patrick) - use for glide at later stage
):
super().__init__()
self.channels = channels
@@ -236,13 +237,76 @@ def __init__(
else:
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)

self.overwrite = overwrite
self.is_overwritten = False
if self.overwrite:
in_channels = channels
out_channels = self.out_channels
conv_shortcut = False
dropout = 0.0
temb_channels = emb_channels
groups = 32
pre_norm = True
eps = 1e-5
non_linearity = "silu"
self.pre_norm = pre_norm
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut

if self.pre_norm:
self.norm1 = Normalize(in_channels, num_groups=groups, eps=eps)
else:
self.norm1 = Normalize(out_channels, num_groups=groups, eps=eps)

self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
self.norm2 = Normalize(out_channels, num_groups=groups, eps=eps)
self.dropout = torch.nn.Dropout(dropout)
self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
if non_linearity == "swish":
self.nonlinearity = nonlinearity
elif non_linearity == "mish":
self.nonlinearity = Mish()
elif non_linearity == "silu":
self.nonlinearity = nn.SiLU()

if self.in_channels != self.out_channels:
self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

def set_weights(self):
# TODO(Patrick): use for glide at later stage
self.norm1.weight.data = self.in_layers[0].weight.data
self.norm1.bias.data = self.in_layers[0].bias.data

self.conv1.weight.data = self.in_layers[-1].weight.data
self.conv1.bias.data = self.in_layers[-1].bias.data

self.temb_proj.weight.data = self.emb_layers[-1].weight.data
self.temb_proj.bias.data = self.emb_layers[-1].bias.data

self.norm2.weight.data = self.out_layers[0].weight.data
self.norm2.bias.data = self.out_layers[0].bias.data

self.conv2.weight.data = self.out_layers[-1].weight.data
self.conv2.bias.data = self.out_layers[-1].bias.data

if self.in_channels != self.out_channels:
self.nin_shortcut.weight.data = self.skip_connection.weight.data
self.nin_shortcut.bias.data = self.skip_connection.bias.data

def forward(self, x, emb):
"""
Apply the block to a Tensor, conditioned on a timestep embedding.

:param x: an [N x C x ...] Tensor of features.
:param emb: an [N x emb_channels] Tensor of timestep embeddings.
:return: an [N x C x ...] Tensor of outputs.
"""
if self.overwrite:
# TODO(Patrick): use for glide at later stage
self.set_weights()

if self.updown:
in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
h = in_rest(x)
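Review note: the `overwrite` path mirrors the block as "flat" `norm1`/`conv1`/`temb_proj`/`norm2`/`conv2` modules, and `set_weights()` copies parameters out of the original `in_layers`/`emb_layers`/`out_layers` stacks, indexing `[0]` for the norm and `[-1]` for the conv. A minimal, self-contained sketch of that copy-and-verify pattern (shapes and module layout assumed for illustration):

```python
import torch
from torch import nn

# Original stack: norm at index 0, activation, conv at index -1.
in_layers = nn.Sequential(
    nn.GroupNorm(32, 64),
    nn.SiLU(),
    nn.Conv2d(64, 64, 3, padding=1),
)

# "Flat" duplicates, as the overwrite branch creates them.
norm1 = nn.GroupNorm(32, 64)
conv1 = nn.Conv2d(64, 64, 3, padding=1)

# Copy parameters exactly as set_weights() does: [0] is the norm, [-1] the conv.
norm1.weight.data = in_layers[0].weight.data
norm1.bias.data = in_layers[0].bias.data
conv1.weight.data = in_layers[-1].weight.data
conv1.bias.data = in_layers[-1].bias.data

# Both parameterizations now compute the same function.
x = torch.randn(1, 64, 8, 8)
assert torch.allclose(in_layers(x), conv1(nn.functional.silu(norm1(x))), atol=1e-6)
```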
@@ -251,6 +315,7 @@ def forward(self, x, emb):
h = in_conv(h)
else:
h = self.in_layers(x)

emb_out = self.emb_layers(emb).type(h.dtype)
while len(emb_out.shape) < len(h.shape):
emb_out = emb_out[..., None]
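The `while` loop right-pads the embedding with singleton dimensions so it broadcasts over the spatial axes; a quick sketch with assumed shapes:

```python
import torch

h = torch.randn(2, 64, 8, 8)  # [N, C, H, W] feature map
emb_out = torch.randn(2, 64)  # [N, C] projected timestep embedding

while len(emb_out.shape) < len(h.shape):
    emb_out = emb_out[..., None]  # [N, C] -> [N, C, 1] -> [N, C, 1, 1]

h = h + emb_out  # broadcasts over H and W
```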
@@ -262,7 +327,50 @@ def forward(self, x, emb):
else:
h = h + emb_out
h = self.out_layers(h)
- return self.skip_connection(x) + h

result = self.skip_connection(x) + h

# TODO(Patrick) Use for glide at later stage
# result = self.forward_2(x, emb)

return result

def forward_2(self, x, temb, mask=1.0):
if self.overwrite and not self.is_overwritten:
self.set_weights()
self.is_overwritten = True

h = x
if self.pre_norm:
h = self.norm1(h)
h = self.nonlinearity(h)

h = self.conv1(h)

if not self.pre_norm:
h = self.norm1(h)
h = self.nonlinearity(h)

h = h + self.temb_proj(self.nonlinearity(temb))[:, :, None, None]

if self.pre_norm:
h = self.norm2(h)
h = self.nonlinearity(h)

h = self.dropout(h)
h = self.conv2(h)

if not self.pre_norm:
h = self.norm2(h)
h = self.nonlinearity(h)

if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
x = self.conv_shortcut(x)
else:
x = self.nin_shortcut(x)

return x + h
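Review note: with the defaults the overwrite branch sets (`pre_norm=True`), `forward_2` is the standard pre-activation ResNet ordering: GroupNorm -> SiLU -> Conv, twice, with the timestep projection added in between. A condensed, hypothetical restatement for the `in_channels == out_channels` case:

```python
def resnet_prenorm(block, x, temb):
    # First pre-norm convolution.
    h = block.conv1(block.nonlinearity(block.norm1(x)))
    # Add the projected timestep embedding, broadcast over H and W.
    h = h + block.temb_proj(block.nonlinearity(temb))[:, :, None, None]
    # Second pre-norm convolution, with dropout.
    h = block.conv2(block.dropout(block.nonlinearity(block.norm2(h))))
    return x + h
```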


# unet.py and unet_grad_tts.py
@@ -280,6 +388,7 @@ def __init__(
eps=1e-6,
non_linearity="swish",
overwrite_for_grad_tts=False,
overwrite_for_ldm=False,
):
super().__init__()
self.pre_norm = pre_norm
@@ -302,15 +411,19 @@ def __init__(
self.nonlinearity = nonlinearity
elif non_linearity == "mish":
self.nonlinearity = Mish()
elif non_linearity == "silu":
self.nonlinearity = nn.SiLU()

if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
# TODO(Patrick) - this branch is never used I think => can be deleted!
self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
else:
self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

self.is_overwritten = False
self.overwrite_for_grad_tts = overwrite_for_grad_tts
self.overwrite_for_ldm = overwrite_for_ldm
if self.overwrite_for_grad_tts:
dim = in_channels
dim_out = out_channels
@@ -324,6 +437,39 @@ def __init__(
self.res_conv = torch.nn.Conv2d(dim, dim_out, 1)
else:
self.res_conv = torch.nn.Identity()
elif self.overwrite_for_ldm:
dims = 2
# eps = 1e-5
# non_linearity = "silu"
# overwrite_for_ldm
channels = in_channels
emb_channels = temb_channels
use_scale_shift_norm = False

self.in_layers = nn.Sequential(
normalization(channels, swish=1.0),
nn.Identity(),
conv_nd(dims, channels, self.out_channels, 3, padding=1),
)
self.emb_layers = nn.Sequential(
nn.SiLU(),
linear(
emb_channels,
2 * self.out_channels if use_scale_shift_norm else self.out_channels,
),
)
self.out_layers = nn.Sequential(
normalization(self.out_channels, swish=0.0 if use_scale_shift_norm else 1.0),
nn.SiLU() if use_scale_shift_norm else nn.Identity(),
nn.Dropout(p=dropout),
zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)),
)
if self.out_channels == in_channels:
self.skip_connection = nn.Identity()
# elif use_conv:
# self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1)
else:
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
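Review note: `use_scale_shift_norm` is hard-coded to `False` here, but the `2 * self.out_channels` projection hints at the FiLM-style branch in upstream guided-diffusion: the embedding is split into a scale and a shift applied between the norm and the rest of `out_layers`. That branch is elided from this diff; a hedged sketch of the standard implementation:

```python
import torch

def scale_shift(h, emb_out, out_norm, out_rest):
    """FiLM-style conditioning: h is [N, C, H, W], emb_out is [N, 2C]."""
    emb_out = emb_out[:, :, None, None]            # -> [N, 2C, 1, 1]
    scale, shift = torch.chunk(emb_out, 2, dim=1)  # two [N, C, 1, 1] halves
    h = out_norm(h) * (1 + scale) + shift
    return out_rest(h)                             # SiLU, dropout, zeroed conv
```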

def set_weights_grad_tts(self):
self.conv1.weight.data = self.block1.block[0].weight.data
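`zero_module` (wrapped around the final conv of `out_layers` above) zero-initializes a module's parameters so the residual branch contributes nothing at initialization and the block starts as an identity mapping. A sketch matching the upstream guided-diffusion helper:

```python
def zero_module(module):
    """Zero out a module's parameters so its branch starts as a no-op."""
    for p in module.parameters():
        p.detach().zero_()
    return module
```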
@@ -343,13 +489,36 @@ def set_weights_grad_tts(self):
self.nin_shortcut.weight.data = self.res_conv.weight.data
self.nin_shortcut.bias.data = self.res_conv.bias.data

- def forward(self, x, temb, mask=None):
def set_weights_ldm(self):
self.norm1.weight.data = self.in_layers[0].weight.data
self.norm1.bias.data = self.in_layers[0].bias.data

self.conv1.weight.data = self.in_layers[-1].weight.data
self.conv1.bias.data = self.in_layers[-1].bias.data

self.temb_proj.weight.data = self.emb_layers[-1].weight.data
self.temb_proj.bias.data = self.emb_layers[-1].bias.data

self.norm2.weight.data = self.out_layers[0].weight.data
self.norm2.bias.data = self.out_layers[0].bias.data

self.conv2.weight.data = self.out_layers[-1].weight.data
self.conv2.bias.data = self.out_layers[-1].bias.data

if self.in_channels != self.out_channels:
self.nin_shortcut.weight.data = self.skip_connection.weight.data
self.nin_shortcut.bias.data = self.skip_connection.bias.data

def forward(self, x, temb, mask=1.0):
if self.overwrite_for_grad_tts and not self.is_overwritten:
self.set_weights_grad_tts()
self.is_overwritten = True
elif self.overwrite_for_ldm and not self.is_overwritten:
self.set_weights_ldm()
self.is_overwritten = True

h = x
- h = h * mask if mask is not None else h
+ h = h * mask
if self.pre_norm:
h = self.norm1(h)
h = self.nonlinearity(h)
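Review note: switching the default from `mask=None` to `mask=1.0` lets every site multiply unconditionally; multiplying by the Python float `1.0` broadcasts as a no-op, so the repeated `if mask is not None` branches disappear. A quick sketch:

```python
import torch

h = torch.randn(2, 64, 8, 8)
assert torch.equal(h * 1.0, h)  # scalar default: identity, no branch needed

mask = torch.ones(2, 1, 8, 8)   # a real mask, e.g. a grad-tts padding mask
mask[:, :, 4:, :] = 0.0         # zero out padded positions
h = h * mask                    # broadcasts over the channel dimension
```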
@@ -359,11 +528,11 @@ def forward(self, x, temb, mask=None):
if not self.pre_norm:
h = self.norm1(h)
h = self.nonlinearity(h)
- h = h * mask if mask is not None else h
+ h = h * mask

h = h + self.temb_proj(self.nonlinearity(temb))[:, :, None, None]

- h = h * mask if mask is not None else h
+ h = h * mask
if self.pre_norm:
h = self.norm2(h)
h = self.nonlinearity(h)
Expand All @@ -374,9 +543,9 @@ def forward(self, x, temb, mask=None):
if not self.pre_norm:
h = self.norm2(h)
h = self.nonlinearity(h)
- h = h * mask if mask is not None else h
+ h = h * mask

- x = x * mask if mask is not None else x
+ x = x * mask
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
x = self.conv_shortcut(x)
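Putting it together, a hypothetical usage sketch of the rewritten block; the keyword names are inferred from the `__init__` body above and the full signature is elided from this diff, so treat them as assumptions:

```python
import torch

block = ResnetBlock(
    in_channels=64,
    out_channels=128,
    temb_channels=512,
    overwrite_for_ldm=True,  # first forward() copies weights via set_weights_ldm()
)
x = torch.randn(1, 64, 32, 32)
temb = torch.randn(1, 512)
out = block(x, temb)  # mask defaults to 1.0, a no-op multiply
print(out.shape)      # expected: torch.Size([1, 128, 32, 32])
```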