@@ -47,7 +47,7 @@ def __init__(
         if groups_out is None:
             groups_out = groups
 
-        # there will always be at least one resenet
+        # there will always be at least one resnet
         resnets = [ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=temb_channels)]
 
         for _ in range(num_layers):
@@ -111,7 +111,7 @@ def __init__(
         if groups_out is None:
             groups_out = groups
 
-        # there will always be at least one resenet
+        # there will always be at least one resnet
         resnets = [ResidualTemporalBlock1D(2 * in_channels, out_channels, embed_dim=temb_channels)]
 
         for _ in range(num_layers):
@@ -174,22 +174,60 @@ class UpBlock1DNoSkip(nn.Module):
 
 
 class MidResTemporalBlock1D(nn.Module):
-    def __init__(self, in_channels, out_channels, embed_dim, add_downsample):
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        embed_dim,
+        num_layers: int = 1,
+        add_downsample: bool = False,
+        add_upsample: bool = False,
+        non_linearity=None,
+    ):
         super().__init__()
         self.in_channels = in_channels
         self.out_channels = out_channels
         self.add_downsample = add_downsample
-        self.resnet = ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=embed_dim)
 
+        # there will always be at least one resnet
+        resnets = [ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=embed_dim)]
+
+        for _ in range(num_layers):
+            resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=embed_dim))
+
+        self.resnets = nn.ModuleList(resnets)
+
+        if non_linearity == "swish":
+            self.nonlinearity = lambda x: F.silu(x)
+        elif non_linearity == "mish":
+            self.nonlinearity = nn.Mish()
+        elif non_linearity == "silu":
+            self.nonlinearity = nn.SiLU()
+        else:
+            self.nonlinearity = None
+
+        self.upsample = None
+        if add_upsample:
+            self.upsample = Upsample1D(out_channels, use_conv=True)
+
+        self.downsample = None
         if add_downsample:
             self.downsample = Downsample1D(out_channels, use_conv=True)
-        else:
-            self.downsample = nn.Identity()
 
-    def forward(self, sample, temb):
-        sample = self.resnet(sample, temb)
-        sample = self.downsample(sample)
-        return sample
+        if self.upsample and self.downsample:
+            raise ValueError("Block cannot downsample and upsample")
+
+    def forward(self, hidden_states, temb):
+        hidden_states = self.resnets[0](hidden_states, temb)
+        for resnet in self.resnets[1:]:
+            hidden_states = resnet(hidden_states, temb)
+
+        if self.upsample:
+            hidden_states = self.upsample(hidden_states)
+        if self.downsample:
+            hidden_states = self.downsample(hidden_states)
+
+        return hidden_states
 
 
 class OutConv1DBlock(nn.Module):
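For reference, a minimal sketch of how the refactored mid block might be exercised after this change. The shapes and parameter values below are illustrative, and the import path is assumed from this file's location in the repo:

```python
import torch
from diffusers.models.unet_1d_blocks import MidResTemporalBlock1D

# Illustrative sizes: batch of 4, 32 channels, horizon of 16 steps.
block = MidResTemporalBlock1D(in_channels=32, out_channels=32, embed_dim=128, num_layers=2)

sample = torch.randn(4, 32, 16)  # (batch, channels, sequence)
temb = torch.randn(4, 128)       # time embedding, width must match embed_dim
out = block(sample, temb)        # no up/downsampling requested, so the length stays 16
```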
@@ -203,14 +241,14 @@ def __init__(self, num_groups_out, out_channels, embed_dim, act_fn):
         self.final_conv1d_act = nn.Mish()
         self.final_conv1d_2 = nn.Conv1d(embed_dim, out_channels, 1)
 
-    def forward(self, sample, t):
-        sample = self.final_conv1d_1(sample)
-        sample = rearrange_dims(sample)
-        sample = self.final_conv1d_gn(sample)
-        sample = rearrange_dims(sample)
-        sample = self.final_conv1d_act(sample)
-        sample = self.final_conv1d_2(sample)
-        return sample
+    def forward(self, hidden_states, temb=None):
+        hidden_states = self.final_conv1d_1(hidden_states)
+        hidden_states = rearrange_dims(hidden_states)
+        hidden_states = self.final_conv1d_gn(hidden_states)
+        hidden_states = rearrange_dims(hidden_states)
+        hidden_states = self.final_conv1d_act(hidden_states)
+        hidden_states = self.final_conv1d_2(hidden_states)
+        return hidden_states
 
 
 class OutValueFunctionBlock(nn.Module):
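The output conv block now takes `hidden_states` plus an optional `temb` that it ignores, so it can be called with the same signature as the other blocks. A small usage sketch with assumed, illustrative sizes (`embed_dim` has to match the incoming channel count and be divisible by `num_groups_out`):

```python
import torch
from diffusers.models.unet_1d_blocks import OutConv1DBlock

out_block = OutConv1DBlock(num_groups_out=8, out_channels=14, embed_dim=32, act_fn="mish")

hidden_states = torch.randn(4, 32, 16)  # (batch, embed_dim channels, sequence)
pred = out_block(hidden_states)         # temb defaults to None and is unused here
assert pred.shape == (4, 14, 16)
```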
@@ -224,13 +262,13 @@ def __init__(self, fc_dim, embed_dim):
             ]
         )
 
-    def forward(self, sample, t):
-        sample = sample.view(sample.shape[0], -1)
-        sample = torch.cat((sample, t), dim=-1)
+    def forward(self, hidden_states, temb):
+        hidden_states = hidden_states.view(hidden_states.shape[0], -1)
+        hidden_states = torch.cat((hidden_states, temb), dim=-1)
         for layer in self.final_block:
-            sample = layer(sample)
+            hidden_states = layer(hidden_states)
 
-        return sample
+        return hidden_states
 
 
 def get_down_block(down_block_type, num_layers, in_channels, out_channels, temb_channels, add_downsample):
@@ -260,9 +298,15 @@ def get_up_block(up_block_type, num_layers, in_channels, out_channels, temb_chan
     raise ValueError(f"{up_block_type} does not exist.")
 
 
-def get_mid_block(mid_block_type, in_channels, out_channels, embed_dim, add_downsample):
+def get_mid_block(mid_block_type, num_layers, in_channels, out_channels, embed_dim, add_downsample):
     if mid_block_type == "MidResTemporalBlock1D":
-        return MidResTemporalBlock1D(in_channels, out_channels, embed_dim, add_downsample)
+        return MidResTemporalBlock1D(
+            num_layers=num_layers,
+            in_channels=in_channels,
+            out_channels=out_channels,
+            embed_dim=embed_dim,
+            add_downsample=add_downsample,
+        )
     raise ValueError(f"{mid_block_type} does not exist.")
 
 
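Finally, a hypothetical call to the updated helper that mirrors the new `num_layers` argument; the channel and embedding sizes are placeholders:

```python
from diffusers.models.unet_1d_blocks import get_mid_block

mid_block = get_mid_block(
    "MidResTemporalBlock1D",
    num_layers=1,
    in_channels=64,
    out_channels=64,
    embed_dim=128,
    add_downsample=False,
)
```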