Skip to content

Commit

Permalink
fix bucketing
Browse files Browse the repository at this point in the history
  • Loading branch information
ddPn08 committed Jun 1, 2023
1 parent 3bd00b8 commit 1e3daa2
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 63 deletions.
55 changes: 24 additions & 31 deletions library/train_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -754,28 +754,41 @@ def load_image(self, image_path):
img = np.array(image, np.uint8)
return img

def trim_and_resize_if_required(self, subset: BaseSubset, image, reso, resized_size, cond_img=None):
    """Resize ``image`` to ``resized_size`` and crop it to the bucket resolution.

    When ``cond_img`` is given it goes through the same resize and is cropped
    with the *same* offsets as ``image``, so both arrays stay pixel-aligned.

    Args:
        subset: object whose ``random_crop`` flag selects a random crop
            offset (truthy) or a centered crop (falsy).
        image: array whose first two axes are (height, width).
        reso: target ``(width, height)`` after cropping.
        resized_size: ``(width, height)`` to resize to before cropping.
        cond_img: optional conditioning image processed alongside ``image``.

    Returns:
        The processed ``image`` alone when ``cond_img`` is ``None``;
        otherwise the tuple ``(image, cond_img)``.

    Raises:
        AssertionError: if a processed array does not match ``reso``.
    """

    def _resize_if_needed(arr):
        # Resize with cv2 because INTER_AREA interpolation is wanted here.
        height, width = arr.shape[0:2]
        if width != resized_size[0] or height != resized_size[1]:
            return cv2.resize(arr, resized_size, interpolation=cv2.INTER_AREA)
        return arr

    # NOTE(review): the original used an `exists()` helper defined elsewhere;
    # it is presumed to be a plain "is not None" check - confirm.
    has_cond = cond_img is not None

    # Each array is checked against resized_size on its own: a conditioning
    # image whose source size differs from the training image must still be
    # resized even when the training image already has the right size.
    image = _resize_if_needed(image)
    if has_cond:
        cond_img = _resize_if_needed(cond_img)

    image_height, image_width = image.shape[0:2]

    if image_width > reso[0]:
        trim_size = image_width - reso[0]
        p = random.randint(0, trim_size) if subset.random_crop else trim_size // 2
        image = image[:, p : p + reso[0]]
        if has_cond:
            # Reuse the same offset so the pair stays aligned.
            cond_img = cond_img[:, p : p + reso[0]]

    if image_height > reso[1]:
        trim_size = image_height - reso[1]
        p = random.randint(0, trim_size) if subset.random_crop else trim_size // 2
        image = image[p : p + reso[1]]
        if has_cond:
            cond_img = cond_img[p : p + reso[1]]

    assert (
        image.shape[0] == reso[1] and image.shape[1] == reso[0]
    ), f"internal error, illegal trimmed size: {image.shape}, {reso}"

    if not has_cond:
        return image

    assert (
        cond_img.shape[0] == reso[1] and cond_img.shape[1] == reso[0]
    ), f"internal error, illegal trimmed size: {cond_img.shape}, {reso}"
    return image, cond_img

def is_latent_cacheable(self):
Expand Down Expand Up @@ -1617,6 +1630,8 @@ def __getitem__(self, index):
subset = self.image_to_subset[image_key]
loss_weights.append(1.0)

assert hasattr(image_info, "cond_img_path"), f"conditioning image path is not found: {image_info.absolute_path}"

# image/latentsを処理する
if image_info.latents is not None: # cache_latents=Trueの場合
latents = image_info.latents if not subset.flip_aug or random.random() < 0.5 else image_info.latents_flipped
Expand All @@ -1628,10 +1643,11 @@ def __getitem__(self, index):
else:
# 画像を読み込み、必要ならcropする
img = self.load_image(image_info.absolute_path)
cond_img = self.load_image(image_info.cond_img_path)
im_h, im_w = img.shape[0:2]

if self.enable_bucket:
img = self.trim_and_resize_if_required(subset, img, image_info.bucket_reso, image_info.resized_size)
img, cond_img = self.trim_and_resize_if_required(subset, img, image_info.bucket_reso, image_info.resized_size, cond_img=cond_img)
else:
im_h, im_w = img.shape[0:2]
assert (
Expand All @@ -1649,41 +1665,18 @@ def __getitem__(self, index):
images.append(image)
latents_list.append(latents)

caption = self.process_caption(subset, image_info.caption)
if self.XTI_layers:
caption_layer = []
for layer in self.XTI_layers:
token_strings_from = " ".join(self.token_strings)
token_strings_to = " ".join([f"{x}_{layer}" for x in self.token_strings])
caption_ = caption.replace(token_strings_from, token_strings_to)
caption_layer.append(caption_)
captions.append(caption_layer)
else:
captions.append(caption)
if not self.token_padding_disabled: # this option might be omitted in future
if self.XTI_layers:
token_caption = self.get_input_ids(caption_layer)
else:
token_caption = self.get_input_ids(caption)
input_ids_list.append(token_caption)

assert hasattr(image_info, "cond_img_path"), f"conditioning image path is not found: {image_info.absolute_path}"

cond_img = self.load_image(image_info.cond_img_path)
if self.enable_bucket:
cond_img = self.trim_and_resize_if_required(subset, cond_img, image_info.bucket_reso, image_info.resized_size)
cond_img = self.conditioning_image_transforms(cond_img)
conditioning_images.append(cond_img)

caption = self.process_caption(subset, image_info.caption)
captions.append(caption)
token_caption = self.get_input_ids(caption)
input_ids_list.append(token_caption)

example = {}
example["loss_weights"] = torch.FloatTensor(loss_weights)

if self.token_padding_disabled:
# padding=True means pad in the batch
example["input_ids"] = self.tokenizer(captions, padding=True, truncation=True, return_tensors="pt").input_ids
else:
# batch processing seems to be good
example["input_ids"] = torch.stack(input_ids_list)
example["input_ids"] = torch.stack(input_ids_list)

if images[0] is not None:
images = torch.stack(images)
Expand Down
60 changes: 28 additions & 32 deletions train_controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,6 @@ def train(args):
controlnet = ControlNetModel.from_pretrained(filename)



# モデルに xformers とか memory efficient attention を組み込む
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)

Expand All @@ -168,11 +167,11 @@ def train(args):
controlnet.enable_gradient_checkpointing()

# 学習に必要なクラスを準備する
print("prepare optimizer, data loader etc.")
accelerator.print("prepare optimizer, data loader etc.")

trainable_params = controlnet.parameters()

optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(
_, _, optimizer = train_util.get_optimizer(
args, trainable_params
)

Expand All @@ -198,10 +197,9 @@ def train(args):
/ accelerator.num_processes
/ args.gradient_accumulation_steps
)
if is_main_process:
print(
f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
)
accelerator.print(
f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
)

# データセット側にも学習ステップを送信
train_dataset_group.set_max_train_steps(args.max_train_steps)
Expand All @@ -216,7 +214,7 @@ def train(args):
assert (
args.mixed_precision == "fp16"
), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
print("enable full fp16 training.")
accelerator.print("enable full fp16 training.")
controlnet.to(weight_dtype)

# acceleratorがなんかよろしくやってくれるらしい
Expand Down Expand Up @@ -258,23 +256,21 @@ def train(args):

# 学習する
# TODO: find a way to handle total batch size when there are multiple datasets

if is_main_process:
print("running training / 学習開始")
print(
f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}"
)
print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
print(f" num epochs / epoch数: {num_train_epochs}")
print(
f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}"
)
# print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
print(
f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}"
)
print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
accelerator.print("running training / 学習開始")
accelerator.print(
f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}"
)
accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
accelerator.print(
f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}"
)
# print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
accelerator.print(
f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}"
)
accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")

progress_bar = tqdm(
range(args.max_train_steps),
Expand Down Expand Up @@ -303,11 +299,11 @@ def train(args):
del train_dataset_group

# function for saving/removing
def save_model(ckpt_name, model, steps, epoch_no, force_sync_upload=False):
def save_model(ckpt_name, model, force_sync_upload=False):
os.makedirs(args.output_dir, exist_ok=True)
ckpt_file = os.path.join(args.output_dir, ckpt_name)

print(f"\nsaving checkpoint: {ckpt_file}")
accelerator.print(f"\nsaving checkpoint: {ckpt_file}")

state_dict = model_util.convert_controlnet_state_dict_to_sd(model.state_dict())

Expand All @@ -332,13 +328,13 @@ def save_model(ckpt_name, model, steps, epoch_no, force_sync_upload=False):
def remove_model(old_ckpt_name):
    """Delete the checkpoint ``old_ckpt_name`` from ``args.output_dir`` if it exists.

    Does nothing when the file is absent. The removal is reported once
    through ``accelerator.print``.
    """
    old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
    if os.path.exists(old_ckpt_file):
        accelerator.print(f"removing old checkpoint: {old_ckpt_file}")
        os.remove(old_ckpt_file)

# training loop
for epoch in range(num_train_epochs):
if is_main_process:
print(f"\nepoch {epoch+1}/{num_train_epochs}")
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1

for step, batch in enumerate(train_dataloader):
Expand Down Expand Up @@ -470,7 +466,7 @@ def remove_model(old_ckpt_name):
args, "." + args.save_model_as, global_step
)
save_model(
ckpt_name, unwrap_model(controlnet), global_step, epoch
ckpt_name, unwrap_model(controlnet),
)

if args.save_state:
Expand Down Expand Up @@ -520,7 +516,7 @@ def remove_model(old_ckpt_name):
ckpt_name = train_util.get_epoch_ckpt_name(
args, "." + args.save_model_as, epoch + 1
)
save_model(ckpt_name, unwrap_model(controlnet), global_step, epoch + 1)
save_model(ckpt_name, unwrap_model(controlnet))

remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1)
if remove_epoch_no is not None:
Expand Down Expand Up @@ -561,7 +557,7 @@ def remove_model(old_ckpt_name):
if is_main_process:
ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
save_model(
ckpt_name, controlnet, global_step, num_train_epochs, force_sync_upload=True
ckpt_name, controlnet, force_sync_upload=True
)

print("model saved.")
Expand Down

0 comments on commit 1e3daa2

Please sign in to comment.