5
5
import random
6
6
import time
7
7
from multiprocessing import Value
8
- from types import SimpleNamespace
8
+ from omegaconf import OmegaConf
9
9
import toml
10
10
11
11
from tqdm import tqdm
@@ -148,8 +148,10 @@ def train(args):
148
148
"in_channels" : 4 ,
149
149
"layers_per_block" : 2 ,
150
150
"mid_block_scale_factor" : 1 ,
151
+ "mid_block_type" : "UNetMidBlock2DCrossAttn" ,
151
152
"norm_eps" : 1e-05 ,
152
153
"norm_num_groups" : 32 ,
154
+ "num_attention_heads" : [5 , 10 , 20 , 20 ],
153
155
"num_class_embeds" : None ,
154
156
"only_cross_attention" : False ,
155
157
"out_channels" : 4 ,
@@ -179,8 +181,10 @@ def train(args):
179
181
"in_channels" : 4 ,
180
182
"layers_per_block" : 2 ,
181
183
"mid_block_scale_factor" : 1 ,
184
+ "mid_block_type" : "UNetMidBlock2DCrossAttn" ,
182
185
"norm_eps" : 1e-05 ,
183
186
"norm_num_groups" : 32 ,
187
+ "num_attention_heads" : 8 ,
184
188
"out_channels" : 4 ,
185
189
"sample_size" : 64 ,
186
190
"up_block_types" : ["UpBlock2D" , "CrossAttnUpBlock2D" , "CrossAttnUpBlock2D" , "CrossAttnUpBlock2D" ],
@@ -193,7 +197,7 @@ def train(args):
193
197
"resnet_time_scale_shift" : "default" ,
194
198
"projection_class_embeddings_input_dim" : None ,
195
199
}
196
- unet .config = SimpleNamespace ( ** unet .config )
200
+ unet .config = OmegaConf . create ( unet .config )
197
201
198
202
controlnet = ControlNetModel .from_unet (unet )
199
203
0 commit comments