Skip to content

Commit de0e0b9

Browse files
authored
Merge pull request kohya-ss#1284 from sdbds/fix_traincontrolnet
Fix train controlnet
2 parents c68baae + 5cb145d commit de0e0b9

File tree

2 files changed

+7
-2
lines changed

2 files changed

+7
-2
lines changed

requirements.txt

+1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ easygui==0.98.3
1717
toml==0.10.2
1818
voluptuous==0.13.1
1919
huggingface-hub==0.20.1
20+
omegaconf==2.3.0
2021
# for Image utils
2122
imagesize==1.4.1
2223
# for BLIP captioning

train_controlnet.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import random
66
import time
77
from multiprocessing import Value
8-
from types import SimpleNamespace
8+
from omegaconf import OmegaConf
99
import toml
1010

1111
from tqdm import tqdm
@@ -148,8 +148,10 @@ def train(args):
148148
"in_channels": 4,
149149
"layers_per_block": 2,
150150
"mid_block_scale_factor": 1,
151+
"mid_block_type": "UNetMidBlock2DCrossAttn",
151152
"norm_eps": 1e-05,
152153
"norm_num_groups": 32,
154+
"num_attention_heads": [5, 10, 20, 20],
153155
"num_class_embeds": None,
154156
"only_cross_attention": False,
155157
"out_channels": 4,
@@ -179,8 +181,10 @@ def train(args):
179181
"in_channels": 4,
180182
"layers_per_block": 2,
181183
"mid_block_scale_factor": 1,
184+
"mid_block_type": "UNetMidBlock2DCrossAttn",
182185
"norm_eps": 1e-05,
183186
"norm_num_groups": 32,
187+
"num_attention_heads": 8,
184188
"out_channels": 4,
185189
"sample_size": 64,
186190
"up_block_types": ["UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"],
@@ -193,7 +197,7 @@ def train(args):
193197
"resnet_time_scale_shift": "default",
194198
"projection_class_embeddings_input_dim": None,
195199
}
196-
unet.config = SimpleNamespace(**unet.config)
200+
unet.config = OmegaConf.create(unet.config)
197201

198202
controlnet = ControlNetModel.from_unet(unet)
199203

0 commit comments

Comments
 (0)