@@ -162,7 +162,7 @@ def forward(self, x):
 
 # RESNETS
 
-# unet_glide.py & unet_ldm.py
+# unet_glide.py
 class ResBlock(TimestepBlock):
     """
     A residual block that can optionally change the number of channels.
@@ -188,6 +188,7 @@ def __init__(
         use_checkpoint=False,
         up=False,
         down=False,
+        overwrite=False,  # TODO(Patrick) - use for glide at later stage
     ):
         super().__init__()
         self.channels = channels
@@ -236,13 +237,76 @@ def __init__(
         else:
             self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
 
+        self.overwrite = overwrite
+        self.is_overwritten = False
+        if self.overwrite:
+            in_channels = channels
+            out_channels = self.out_channels
+            conv_shortcut = False
+            dropout = 0.0
+            temb_channels = emb_channels
+            groups = 32
+            pre_norm = True
+            eps = 1e-5
+            non_linearity = "silu"
+            self.pre_norm = pre_norm
+            self.in_channels = in_channels
+            out_channels = in_channels if out_channels is None else out_channels
+            self.out_channels = out_channels
+            self.use_conv_shortcut = conv_shortcut
+
+            if self.pre_norm:
+                self.norm1 = Normalize(in_channels, num_groups=groups, eps=eps)
+            else:
+                self.norm1 = Normalize(out_channels, num_groups=groups, eps=eps)
+
+            self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
+            self.norm2 = Normalize(out_channels, num_groups=groups, eps=eps)
+            self.dropout = torch.nn.Dropout(dropout)
+            self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
+            if non_linearity == "swish":
+                self.nonlinearity = nonlinearity
+            elif non_linearity == "mish":
+                self.nonlinearity = Mish()
+            elif non_linearity == "silu":
+                self.nonlinearity = nn.SiLU()
+
+            if self.in_channels != self.out_channels:
+                self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+
+    def set_weights(self):
+        # TODO(Patrick): use for glide at later stage
+        self.norm1.weight.data = self.in_layers[0].weight.data
+        self.norm1.bias.data = self.in_layers[0].bias.data
+
+        self.conv1.weight.data = self.in_layers[-1].weight.data
+        self.conv1.bias.data = self.in_layers[-1].bias.data
+
+        self.temb_proj.weight.data = self.emb_layers[-1].weight.data
+        self.temb_proj.bias.data = self.emb_layers[-1].bias.data
+
+        self.norm2.weight.data = self.out_layers[0].weight.data
+        self.norm2.bias.data = self.out_layers[0].bias.data
+
+        self.conv2.weight.data = self.out_layers[-1].weight.data
+        self.conv2.bias.data = self.out_layers[-1].bias.data
+
+        if self.in_channels != self.out_channels:
+            self.nin_shortcut.weight.data = self.skip_connection.weight.data
+            self.nin_shortcut.bias.data = self.skip_connection.bias.data
+
     def forward(self, x, emb):
         """
         Apply the block to a Tensor, conditioned on a timestep embedding.
 
         :param x: an [N x C x ...] Tensor of features. :param emb: an [N x emb_channels] Tensor of timestep embeddings.
         :return: an [N x C x ...] Tensor of outputs.
         """
+        if self.overwrite:
+            # TODO(Patrick): use for glide at later stage
+            self.set_weights()
+
         if self.updown:
             in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
             h = in_rest(x)
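The set_weights method added above is the core of the port: the block keeps the GLIDE-style in_layers/emb_layers/out_layers containers and the new flat norm1/conv1/temb_proj/norm2/conv2 attributes side by side, then aliases the old tensors into the new layout via .data assignment. A minimal, self-contained sketch of that weight-porting pattern (SeqBlock/FlatBlock are illustrative names, not part of this diff):

import torch
import torch.nn as nn
import torch.nn.functional as F


class SeqBlock(nn.Module):
    # Stand-in for the GLIDE-style layout: layers hidden inside nn.Sequential.
    def __init__(self, c):
        super().__init__()
        self.in_layers = nn.Sequential(nn.GroupNorm(8, c), nn.SiLU(), nn.Conv2d(c, c, 3, padding=1))

    def forward(self, x):
        return self.in_layers(x)


class FlatBlock(nn.Module):
    # Stand-in for the target layout: the same layers as named attributes.
    def __init__(self, c):
        super().__init__()
        self.norm1 = nn.GroupNorm(8, c)
        self.conv1 = nn.Conv2d(c, c, 3, padding=1)

    def set_weights(self, src):
        # Mirror the diff's convention: the first container entry is the norm,
        # the last entry is the conv; .data assignment shares the same storage.
        self.norm1.weight.data = src.in_layers[0].weight.data
        self.norm1.bias.data = src.in_layers[0].bias.data
        self.conv1.weight.data = src.in_layers[-1].weight.data
        self.conv1.bias.data = src.in_layers[-1].bias.data

    def forward(self, x):
        return self.conv1(F.silu(self.norm1(x)))


src, dst = SeqBlock(16), FlatBlock(16)
dst.set_weights(src)
x = torch.randn(1, 16, 8, 8)
assert torch.allclose(src(x), dst(x))  # identical math in both layouts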
@@ -251,6 +315,7 @@ def forward(self, x, emb):
             h = in_conv(h)
         else:
             h = self.in_layers(x)
+
         emb_out = self.emb_layers(emb).type(h.dtype)
         while len(emb_out.shape) < len(h.shape):
             emb_out = emb_out[..., None]
@@ -262,7 +327,50 @@ def forward(self, x, emb):
         else:
             h = h + emb_out
             h = self.out_layers(h)
-        return self.skip_connection(x) + h
+
+        result = self.skip_connection(x) + h
+
+        # TODO(Patrick) Use for glide at later stage
+        # result = self.forward_2(x, emb)
+
+        return result
+
+    def forward_2(self, x, temb, mask=1.0):
+        if self.overwrite and not self.is_overwritten:
+            self.set_weights()
+            self.is_overwritten = True
+
+        h = x
+        if self.pre_norm:
+            h = self.norm1(h)
+            h = self.nonlinearity(h)
+
+        h = self.conv1(h)
+
+        if not self.pre_norm:
+            h = self.norm1(h)
+            h = self.nonlinearity(h)
+
+        h = h + self.temb_proj(self.nonlinearity(temb))[:, :, None, None]
+
+        if self.pre_norm:
+            h = self.norm2(h)
+            h = self.nonlinearity(h)
+
+        h = self.dropout(h)
+        h = self.conv2(h)
+
+        if not self.pre_norm:
+            h = self.norm2(h)
+            h = self.nonlinearity(h)
+
+        if self.in_channels != self.out_channels:
+            if self.use_conv_shortcut:
+                x = self.conv_shortcut(x)
+            else:
+                x = self.nin_shortcut(x)
+
+        return x + h
 
 
 # unet.py and unet_grad_tts.py
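For reference, the computation forward_2 performs in the pre_norm configuration is the standard DDPM-style residual block: GroupNorm -> SiLU -> conv twice, with the projected timestep embedding added in between and a residual skip at the end. A hedged, dropout-free sketch with illustrative names (not the diffusers API itself):

import torch
import torch.nn as nn


class TinyResBlock(nn.Module):
    def __init__(self, c, temb_dim):
        super().__init__()
        self.norm1 = nn.GroupNorm(8, c)
        self.conv1 = nn.Conv2d(c, c, 3, padding=1)
        self.temb_proj = nn.Linear(temb_dim, c)
        self.norm2 = nn.GroupNorm(8, c)
        self.conv2 = nn.Conv2d(c, c, 3, padding=1)
        self.act = nn.SiLU()

    def forward(self, x, temb):
        h = self.conv1(self.act(self.norm1(x)))
        # Broadcast the [N, C] embedding over the spatial dims, as forward_2
        # does with temb_proj(...)[:, :, None, None].
        h = h + self.temb_proj(self.act(temb))[:, :, None, None]
        h = self.conv2(self.act(self.norm2(h)))
        return x + h  # a 1x1 nin_shortcut would adjust x here if channels changed


out = TinyResBlock(16, 32)(torch.randn(2, 16, 8, 8), torch.randn(2, 32))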
@@ -280,6 +388,7 @@ def __init__(
         eps=1e-6,
         non_linearity="swish",
         overwrite_for_grad_tts=False,
+        overwrite_for_ldm=False,
     ):
         super().__init__()
         self.pre_norm = pre_norm
@@ -302,15 +411,19 @@ def __init__(
             self.nonlinearity = nonlinearity
         elif non_linearity == "mish":
             self.nonlinearity = Mish()
+        elif non_linearity == "silu":
+            self.nonlinearity = nn.SiLU()
 
         if self.in_channels != self.out_channels:
             if self.use_conv_shortcut:
+                # TODO(Patrick) - this branch is never used I think => can be deleted!
                 self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
             else:
                 self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
 
         self.is_overwritten = False
         self.overwrite_for_grad_tts = overwrite_for_grad_tts
+        self.overwrite_for_ldm = overwrite_for_ldm
         if self.overwrite_for_grad_tts:
             dim = in_channels
             dim_out = out_channels
@@ -324,6 +437,39 @@ def __init__(
                 self.res_conv = torch.nn.Conv2d(dim, dim_out, 1)
             else:
                 self.res_conv = torch.nn.Identity()
+        elif self.overwrite_for_ldm:
+            dims = 2
+            # eps = 1e-5
+            # non_linearity = "silu"
+            # overwrite_for_ldm
+            channels = in_channels
+            emb_channels = temb_channels
+            use_scale_shift_norm = False
+
+            self.in_layers = nn.Sequential(
+                normalization(channels, swish=1.0),
+                nn.Identity(),
+                conv_nd(dims, channels, self.out_channels, 3, padding=1),
+            )
+            self.emb_layers = nn.Sequential(
+                nn.SiLU(),
+                linear(
+                    emb_channels,
+                    2 * self.out_channels if use_scale_shift_norm else self.out_channels,
+                ),
+            )
+            self.out_layers = nn.Sequential(
+                normalization(self.out_channels, swish=0.0 if use_scale_shift_norm else 1.0),
+                nn.SiLU() if use_scale_shift_norm else nn.Identity(),
+                nn.Dropout(p=dropout),
+                zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)),
+            )
+            if self.out_channels == in_channels:
+                self.skip_connection = nn.Identity()
+            # elif use_conv:
+            #     self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1)
+            else:
+                self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
 
     def set_weights_grad_tts(self):
         self.conv1.weight.data = self.block1.block[0].weight.data
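One detail worth flagging in the new overwrite_for_ldm branch: out_layers ends in zero_module(conv_nd(...)). Assuming zero_module matches the usual guided-diffusion helper (sketched below; its definition is not shown in this diff), the final conv starts at zero, so the residual branch contributes nothing at initialization and skip_connection(x) + h begins as (roughly) the identity map:

import torch.nn as nn


def zero_module(module: nn.Module) -> nn.Module:
    # Zero every parameter in place and return the module itself, so the
    # wrapped conv outputs 0 until training updates its weights.
    for p in module.parameters():
        p.detach().zero_()
    return module


conv = zero_module(nn.Conv2d(64, 64, 3, padding=1))
assert all((p == 0).all() for p in conv.parameters())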
@@ -343,13 +489,36 @@ def set_weights_grad_tts(self):
             self.nin_shortcut.weight.data = self.res_conv.weight.data
             self.nin_shortcut.bias.data = self.res_conv.bias.data
 
-    def forward(self, x, temb, mask=None):
+    def set_weights_ldm(self):
+        self.norm1.weight.data = self.in_layers[0].weight.data
+        self.norm1.bias.data = self.in_layers[0].bias.data
+
+        self.conv1.weight.data = self.in_layers[-1].weight.data
+        self.conv1.bias.data = self.in_layers[-1].bias.data
+
+        self.temb_proj.weight.data = self.emb_layers[-1].weight.data
+        self.temb_proj.bias.data = self.emb_layers[-1].bias.data
+
+        self.norm2.weight.data = self.out_layers[0].weight.data
+        self.norm2.bias.data = self.out_layers[0].bias.data
+
+        self.conv2.weight.data = self.out_layers[-1].weight.data
+        self.conv2.bias.data = self.out_layers[-1].bias.data
+
+        if self.in_channels != self.out_channels:
+            self.nin_shortcut.weight.data = self.skip_connection.weight.data
+            self.nin_shortcut.bias.data = self.skip_connection.bias.data
+
+    def forward(self, x, temb, mask=1.0):
         if self.overwrite_for_grad_tts and not self.is_overwritten:
             self.set_weights_grad_tts()
             self.is_overwritten = True
+        elif self.overwrite_for_ldm and not self.is_overwritten:
+            self.set_weights_ldm()
+            self.is_overwritten = True
 
         h = x
-        h = h * mask if mask is not None else h
+        h = h * mask
         if self.pre_norm:
             h = self.norm1(h)
             h = self.nonlinearity(h)
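The forward dispatch above uses a lazy "overwrite on first call" pattern: weight porting is deferred until the first forward and then latched with the is_overwritten flag so it runs exactly once, after the checkpoint has been loaded into the old-layout submodules. A small illustrative sketch of the same idea (names are hypothetical):

import torch
import torch.nn as nn


class LazyPort(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin_old = nn.Linear(4, 4)  # stands in for the checkpoint layout
        self.lin_new = nn.Linear(4, 4)  # stands in for the refactored layout
        self.is_overwritten = False

    def forward(self, x):
        if not self.is_overwritten:  # runs once, like set_weights_ldm()
            self.lin_new.weight.data = self.lin_old.weight.data
            self.lin_new.bias.data = self.lin_old.bias.data
            self.is_overwritten = True
        return self.lin_new(x)


y = LazyPort()(torch.randn(3, 4))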
@@ -359,11 +528,11 @@ def forward(self, x, temb, mask=None):
         if not self.pre_norm:
             h = self.norm1(h)
             h = self.nonlinearity(h)
-        h = h * mask if mask is not None else h
+        h = h * mask
 
         h = h + self.temb_proj(self.nonlinearity(temb))[:, :, None, None]
 
-        h = h * mask if mask is not None else h
+        h = h * mask
         if self.pre_norm:
             h = self.norm2(h)
             h = self.nonlinearity(h)
@@ -374,9 +543,9 @@ def forward(self, x, temb, mask=None):
         if not self.pre_norm:
             h = self.norm2(h)
             h = self.nonlinearity(h)
-        h = h * mask if mask is not None else h
+        h = h * mask
 
-        x = x * mask if mask is not None else x
+        x = x * mask
         if self.in_channels != self.out_channels:
             if self.use_conv_shortcut:
                 x = self.conv_shortcut(x)
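A note on the mask=None -> mask=1.0 change applied throughout this forward: multiplying by the Python float 1.0 broadcasts over any tensor as a no-op, so the single line h = h * mask now covers both the grad-tts case (a binary padding mask) and the unmasked case without the previous `if mask is not None` branches. A quick check:

import torch

h = torch.randn(2, 4, 8)
assert torch.equal(h * 1.0, h)  # the scalar default is a no-op

mask = torch.ones(2, 1, 8)
mask[:, :, 4:] = 0  # zero out padded time steps, grad-tts style
assert (h * mask)[:, :, 4:].abs().sum() == 0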