Skip to content

Commit 26ce60c

Browse files
up
1 parent 358531b commit 26ce60c

File tree

2 files changed

+49
-48
lines changed

2 files changed

+49
-48
lines changed

src/diffusers/models/resnet.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -330,8 +330,8 @@ def forward(self, x, emb):
330330

331331
result = self.skip_connection(x) + h
332332

333-
# TODO(Patrick) Use for glide at later stage
334-
# result = self.forward_2(x, emb)
333+
# TODO(Patrick) Use for glide at later stage
334+
# result = self.forward_2(x, emb)
335335

336336
return result
337337

@@ -439,9 +439,9 @@ def __init__(
439439
self.res_conv = torch.nn.Identity()
440440
elif self.overwrite_for_ldm:
441441
dims = 2
442-
# eps = 1e-5
443-
# non_linearity = "silu"
444-
# overwrite_for_ldm
442+
# eps = 1e-5
443+
# non_linearity = "silu"
444+
# overwrite_for_ldm
445445
channels = in_channels
446446
emb_channels = temb_channels
447447
use_scale_shift_norm = False
@@ -466,8 +466,8 @@ def __init__(
466466
)
467467
if self.out_channels == in_channels:
468468
self.skip_connection = nn.Identity()
469-
# elif use_conv:
470-
# self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1)
469+
# elif use_conv:
470+
# self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1)
471471
else:
472472
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
473473

src/diffusers/models/unet_ldm.py

Lines changed: 42 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,10 @@
1010
from ..modeling_utils import ModelMixin
1111
from .attention import AttentionBlock
1212
from .embeddings import get_timestep_embedding
13-
from .resnet import Downsample, TimestepBlock, Upsample
14-
from .resnet import ResnetBlock
15-
#from .resnet import ResBlock
13+
from .resnet import Downsample, ResnetBlock, TimestepBlock, Upsample
14+
15+
16+
# from .resnet import ResBlock
1617

1718

1819
def exists(val):
@@ -561,14 +562,14 @@ def __init__(
561562
for level, mult in enumerate(channel_mult):
562563
for _ in range(num_res_blocks):
563564
layers = [
564-
ResnetBlock(
565-
in_channels=ch,
566-
out_channels=mult * model_channels,
567-
dropout=dropout,
568-
temb_channels=time_embed_dim,
569-
eps=1e-5,
570-
non_linearity="silu",
571-
overwrite_for_ldm=True,
565+
ResnetBlock(
566+
in_channels=ch,
567+
out_channels=mult * model_channels,
568+
dropout=dropout,
569+
temb_channels=time_embed_dim,
570+
eps=1e-5,
571+
non_linearity="silu",
572+
overwrite_for_ldm=True,
572573
)
573574
]
574575
ch = mult * model_channels
@@ -601,16 +602,16 @@ def __init__(
601602
out_ch = ch
602603
self.input_blocks.append(
603604
TimestepEmbedSequential(
604-
# ResBlock(
605-
# ch,
606-
# time_embed_dim,
607-
# dropout,
608-
# out_channels=out_ch,
609-
# dims=dims,
610-
# use_checkpoint=use_checkpoint,
611-
# use_scale_shift_norm=use_scale_shift_norm,
612-
# down=True,
613-
# )
605+
# ResBlock(
606+
# ch,
607+
# time_embed_dim,
608+
# dropout,
609+
# out_channels=out_ch,
610+
# dims=dims,
611+
# use_checkpoint=use_checkpoint,
612+
# use_scale_shift_norm=use_scale_shift_norm,
613+
# down=True,
614+
# )
614615
None
615616
if resblock_updown
616617
else Downsample(
@@ -703,16 +704,16 @@ def __init__(
703704
if level and i == num_res_blocks:
704705
out_ch = ch
705706
layers.append(
706-
# ResBlock(
707-
# ch,
708-
# time_embed_dim,
709-
# dropout,
710-
# out_channels=out_ch,
711-
# dims=dims,
712-
# use_checkpoint=use_checkpoint,
713-
# use_scale_shift_norm=use_scale_shift_norm,
714-
# up=True,
715-
# )
707+
# ResBlock(
708+
# ch,
709+
# time_embed_dim,
710+
# dropout,
711+
# out_channels=out_ch,
712+
# dims=dims,
713+
# use_checkpoint=use_checkpoint,
714+
# use_scale_shift_norm=use_scale_shift_norm,
715+
# up=True,
716+
# )
716717
None
717718
if resblock_updown
718719
else Upsample(ch, use_conv=conv_resample, dims=dims, out_channels=out_ch)
@@ -876,16 +877,16 @@ def __init__(
876877
out_ch = ch
877878
self.input_blocks.append(
878879
TimestepEmbedSequential(
879-
# ResBlock(
880-
# ch,
881-
# time_embed_dim,
882-
# dropout,
883-
# out_channels=out_ch,
884-
# dims=dims,
885-
# use_checkpoint=use_checkpoint,
886-
# use_scale_shift_norm=use_scale_shift_norm,
887-
# down=True,
888-
# )
880+
# ResBlock(
881+
# ch,
882+
# time_embed_dim,
883+
# dropout,
884+
# out_channels=out_ch,
885+
# dims=dims,
886+
# use_checkpoint=use_checkpoint,
887+
# use_scale_shift_norm=use_scale_shift_norm,
888+
# down=True,
889+
# )
889890
None
890891
if resblock_updown
891892
else Downsample(

0 commit comments

Comments
 (0)