
Commit b65eb37

Merge pull request #46 from huggingface/merge_ldm_resnet
[ResNet Refactor] Merge ldm into resnet
2 parents 66ee73e + 26ce60c commit b65eb37

File tree: 2 files changed, +271 / -92 lines changed


src/diffusers/models/resnet.py

Lines changed: 177 additions & 8 deletions
@@ -162,7 +162,7 @@ def forward(self, x):
 
 # RESNETS
 
-# unet_glide.py & unet_ldm.py
+# unet_glide.py
 class ResBlock(TimestepBlock):
     """
     A residual block that can optionally change the number of channels.
@@ -188,6 +188,7 @@ def __init__(
         use_checkpoint=False,
         up=False,
         down=False,
+        overwrite=False,  # TODO(Patrick) - use for glide at later stage
     ):
         super().__init__()
         self.channels = channels
@@ -236,13 +237,76 @@ def __init__(
         else:
             self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
 
+        self.overwrite = overwrite
+        self.is_overwritten = False
+        if self.overwrite:
+            in_channels = channels
+            out_channels = self.out_channels
+            conv_shortcut = False
+            dropout = 0.0
+            temb_channels = emb_channels
+            groups = 32
+            pre_norm = True
+            eps = 1e-5
+            non_linearity = "silu"
+            self.pre_norm = pre_norm
+            self.in_channels = in_channels
+            out_channels = in_channels if out_channels is None else out_channels
+            self.out_channels = out_channels
+            self.use_conv_shortcut = conv_shortcut
+
+            if self.pre_norm:
+                self.norm1 = Normalize(in_channels, num_groups=groups, eps=eps)
+            else:
+                self.norm1 = Normalize(out_channels, num_groups=groups, eps=eps)
+
+            self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
+            self.norm2 = Normalize(out_channels, num_groups=groups, eps=eps)
+            self.dropout = torch.nn.Dropout(dropout)
+            self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
+            if non_linearity == "swish":
+                self.nonlinearity = nonlinearity
+            elif non_linearity == "mish":
+                self.nonlinearity = Mish()
+            elif non_linearity == "silu":
+                self.nonlinearity = nn.SiLU()
+
+            if self.in_channels != self.out_channels:
+                self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+
+    def set_weights(self):
+        # TODO(Patrick): use for glide at later stage
+        self.norm1.weight.data = self.in_layers[0].weight.data
+        self.norm1.bias.data = self.in_layers[0].bias.data
+
+        self.conv1.weight.data = self.in_layers[-1].weight.data
+        self.conv1.bias.data = self.in_layers[-1].bias.data
+
+        self.temb_proj.weight.data = self.emb_layers[-1].weight.data
+        self.temb_proj.bias.data = self.emb_layers[-1].bias.data
+
+        self.norm2.weight.data = self.out_layers[0].weight.data
+        self.norm2.bias.data = self.out_layers[0].bias.data
+
+        self.conv2.weight.data = self.out_layers[-1].weight.data
+        self.conv2.bias.data = self.out_layers[-1].bias.data
+
+        if self.in_channels != self.out_channels:
+            self.nin_shortcut.weight.data = self.skip_connection.weight.data
+            self.nin_shortcut.bias.data = self.skip_connection.bias.data
+
     def forward(self, x, emb):
         """
         Apply the block to a Tensor, conditioned on a timestep embedding.
 
         :param x: an [N x C x ...] Tensor of features. :param emb: an [N x emb_channels] Tensor of timestep embeddings.
         :return: an [N x C x ...] Tensor of outputs.
         """
+        if self.overwrite:
+            # TODO(Patrick): use for glide at later stage
+            self.set_weights()
+
         if self.updown:
             in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
             h = in_rest(x)
@@ -251,6 +315,7 @@ def forward(self, x, emb):
             h = in_conv(h)
         else:
             h = self.in_layers(x)
+
         emb_out = self.emb_layers(emb).type(h.dtype)
         while len(emb_out.shape) < len(h.shape):
             emb_out = emb_out[..., None]
@@ -262,7 +327,50 @@ def forward(self, x, emb):
         else:
             h = h + emb_out
         h = self.out_layers(h)
-        return self.skip_connection(x) + h
+
+        result = self.skip_connection(x) + h
+
+        # TODO(Patrick) Use for glide at later stage
+        # result = self.forward_2(x, emb)
+
+        return result
+
+    def forward_2(self, x, temb, mask=1.0):
+        if self.overwrite and not self.is_overwritten:
+            self.set_weights()
+            self.is_overwritten = True
+
+        h = x
+        if self.pre_norm:
+            h = self.norm1(h)
+            h = self.nonlinearity(h)
+
+        h = self.conv1(h)
+
+        if not self.pre_norm:
+            h = self.norm1(h)
+            h = self.nonlinearity(h)
+
+        h = h + self.temb_proj(self.nonlinearity(temb))[:, :, None, None]
+
+        if self.pre_norm:
+            h = self.norm2(h)
+            h = self.nonlinearity(h)
+
+        h = self.dropout(h)
+        h = self.conv2(h)
+
+        if not self.pre_norm:
+            h = self.norm2(h)
+            h = self.nonlinearity(h)
+
+        if self.in_channels != self.out_channels:
+            if self.use_conv_shortcut:
+                x = self.conv_shortcut(x)
+            else:
+                x = self.nin_shortcut(x)
+
+        return x + h
 
 
 # unet.py and unet_grad_tts.py
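
Note on the weight-overwrite pattern above: set_weights() mirrors the parameters of the existing in_layers/emb_layers/out_layers nn.Sequential stacks into the standalone norm1/conv1/temb_proj/norm2/conv2 modules, so the two execution paths (forward vs. forward_2) compute the same result. A minimal, self-contained sketch of that pattern with plain PyTorch modules, hypothetical shapes, and the fused-swish GroupNorm from normalization(..., swish=1.0) simplified to GroupNorm plus a separate nn.SiLU:

import torch
import torch.nn as nn

channels = 64
# Sequential stack standing in for ResBlock.in_layers: norm -> nonlinearity -> conv.
in_layers = nn.Sequential(
    nn.GroupNorm(32, channels),
    nn.SiLU(),
    nn.Conv2d(channels, channels, kernel_size=3, padding=1),
)

# Standalone modules that should reproduce the Sequential's computation.
norm1 = nn.GroupNorm(32, channels)
conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

# Copy parameters exactly as set_weights() does (index 0 = norm, index -1 = conv).
norm1.weight.data = in_layers[0].weight.data
norm1.bias.data = in_layers[0].bias.data
conv1.weight.data = in_layers[-1].weight.data
conv1.bias.data = in_layers[-1].bias.data

x = torch.randn(1, channels, 8, 8)
# Both paths now agree.
assert torch.allclose(in_layers(x), conv1(nn.SiLU()(norm1(x))), atol=1e-6)
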
@@ -280,6 +388,7 @@ def __init__(
         eps=1e-6,
         non_linearity="swish",
         overwrite_for_grad_tts=False,
+        overwrite_for_ldm=False,
     ):
         super().__init__()
         self.pre_norm = pre_norm
@@ -302,15 +411,19 @@ def __init__(
             self.nonlinearity = nonlinearity
         elif non_linearity == "mish":
             self.nonlinearity = Mish()
+        elif non_linearity == "silu":
+            self.nonlinearity = nn.SiLU()
 
         if self.in_channels != self.out_channels:
             if self.use_conv_shortcut:
+                # TODO(Patrick) - this branch is never used I think => can be deleted!
                 self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
             else:
                 self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
 
         self.is_overwritten = False
         self.overwrite_for_grad_tts = overwrite_for_grad_tts
+        self.overwrite_for_ldm = overwrite_for_ldm
         if self.overwrite_for_grad_tts:
             dim = in_channels
             dim_out = out_channels
@@ -324,6 +437,39 @@ def __init__(
                 self.res_conv = torch.nn.Conv2d(dim, dim_out, 1)
             else:
                 self.res_conv = torch.nn.Identity()
+        elif self.overwrite_for_ldm:
+            dims = 2
+            # eps = 1e-5
+            # non_linearity = "silu"
+            # overwrite_for_ldm
+            channels = in_channels
+            emb_channels = temb_channels
+            use_scale_shift_norm = False
+
+            self.in_layers = nn.Sequential(
+                normalization(channels, swish=1.0),
+                nn.Identity(),
+                conv_nd(dims, channels, self.out_channels, 3, padding=1),
+            )
+            self.emb_layers = nn.Sequential(
+                nn.SiLU(),
+                linear(
+                    emb_channels,
+                    2 * self.out_channels if use_scale_shift_norm else self.out_channels,
+                ),
+            )
+            self.out_layers = nn.Sequential(
+                normalization(self.out_channels, swish=0.0 if use_scale_shift_norm else 1.0),
+                nn.SiLU() if use_scale_shift_norm else nn.Identity(),
+                nn.Dropout(p=dropout),
+                zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)),
+            )
+            if self.out_channels == in_channels:
+                self.skip_connection = nn.Identity()
+            # elif use_conv:
+            # self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1)
+            else:
+                self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
 
     def set_weights_grad_tts(self):
         self.conv1.weight.data = self.block1.block[0].weight.data
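
Note on zero_module(conv_nd(...)) in the out_layers above: zeroing the final convolution makes the residual branch output zero at initialization, so the block starts out close to an identity mapping. A short sketch of that idea with a locally defined zero_module helper; the real helper lives upstream in this code base, so its exact definition here is an assumption:

import torch
import torch.nn as nn

def zero_module(module: nn.Module) -> nn.Module:
    # Zero every parameter in place and return the module (assumed behavior).
    for p in module.parameters():
        p.detach().zero_()
    return module

out_conv = zero_module(nn.Conv2d(64, 64, kernel_size=3, padding=1))
x = torch.randn(1, 64, 8, 8)
# With zeroed weight and bias the residual branch contributes nothing at init.
assert torch.equal(out_conv(x), torch.zeros(1, 64, 8, 8))
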
@@ -343,13 +489,36 @@ def set_weights_grad_tts(self):
             self.nin_shortcut.weight.data = self.res_conv.weight.data
             self.nin_shortcut.bias.data = self.res_conv.bias.data
 
-    def forward(self, x, temb, mask=None):
+    def set_weights_ldm(self):
+        self.norm1.weight.data = self.in_layers[0].weight.data
+        self.norm1.bias.data = self.in_layers[0].bias.data
+
+        self.conv1.weight.data = self.in_layers[-1].weight.data
+        self.conv1.bias.data = self.in_layers[-1].bias.data
+
+        self.temb_proj.weight.data = self.emb_layers[-1].weight.data
+        self.temb_proj.bias.data = self.emb_layers[-1].bias.data
+
+        self.norm2.weight.data = self.out_layers[0].weight.data
+        self.norm2.bias.data = self.out_layers[0].bias.data
+
+        self.conv2.weight.data = self.out_layers[-1].weight.data
+        self.conv2.bias.data = self.out_layers[-1].bias.data
+
+        if self.in_channels != self.out_channels:
+            self.nin_shortcut.weight.data = self.skip_connection.weight.data
+            self.nin_shortcut.bias.data = self.skip_connection.bias.data
+
+    def forward(self, x, temb, mask=1.0):
         if self.overwrite_for_grad_tts and not self.is_overwritten:
             self.set_weights_grad_tts()
             self.is_overwritten = True
+        elif self.overwrite_for_ldm and not self.is_overwritten:
+            self.set_weights_ldm()
+            self.is_overwritten = True
 
         h = x
-        h = h * mask if mask is not None else h
+        h = h * mask
         if self.pre_norm:
             h = self.norm1(h)
             h = self.nonlinearity(h)
@@ -359,11 +528,11 @@ def forward(self, x, temb, mask=None):
         if not self.pre_norm:
             h = self.norm1(h)
             h = self.nonlinearity(h)
-        h = h * mask if mask is not None else h
+        h = h * mask
 
         h = h + self.temb_proj(self.nonlinearity(temb))[:, :, None, None]
 
-        h = h * mask if mask is not None else h
+        h = h * mask
         if self.pre_norm:
             h = self.norm2(h)
             h = self.nonlinearity(h)
@@ -374,9 +543,9 @@ def forward(self, x, temb, mask=None):
         if not self.pre_norm:
             h = self.norm2(h)
             h = self.nonlinearity(h)
-        h = h * mask if mask is not None else h
+        h = h * mask
 
-        x = x * mask if mask is not None else x
+        x = x * mask
         if self.in_channels != self.out_channels:
             if self.use_conv_shortcut:
                 x = self.conv_shortcut(x)
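
Note on the mask change above: the default moves from mask=None to the scalar 1.0, so the repeated `h = h * mask if mask is not None else h` branches collapse to a plain multiplication, while a tensor mask (as passed by the grad-tts path) still broadcasts as before. A quick sketch with hypothetical tensors:

import torch

h = torch.randn(2, 4, 8, 8)

# Old pattern: only multiply when a mask is given.
mask = None
h_old = h * mask if mask is not None else h

# New pattern: the scalar default 1.0 makes the multiplication a no-op.
mask = 1.0
h_new = h * mask

assert torch.equal(h_old, h_new)

# A tensor mask still broadcasts exactly as before.
mask = torch.ones(2, 1, 8, 8)
assert torch.equal(h * mask, h)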

0 commit comments
