
Commit 12e787c

Add Minimal Implementation of Masked Weight Loss
1 parent f0ae7ee commit 12e787c

2 files changed (+42, -7 lines)

library/train_util.py

Lines changed: 23 additions & 7 deletions
@@ -68,6 +68,8 @@ def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool,
         self.latents_flipped: torch.Tensor = None
         self.latents_npz: str = None
         self.latents_npz_flipped: str = None
+        self.mask: np.ndarray = None
+        self.mask_flipped: np.ndarray = None


 class BucketManager():
@@ -467,9 +469,12 @@ def shuffle_buckets(self):

     def load_image(self, image_path):
         image = Image.open(image_path)
-        if not image.mode == "RGB":
-            image = image.convert("RGB")
+        # if not image.mode == "RGB":
+        #     image = image.convert("RGB")
+        if not image.mode == "RGBA":
+            image = image.convert("RGBA")
         img = np.array(image, np.uint8)
+        # alpha_channel = np.array(image, np.uint8)[:,:,-1]
         return img

     def trim_and_resize_if_required(self, image, reso, resized_size):
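
For reference, a minimal sketch (not part of the commit) of the Pillow behavior this hunk relies on: convert("RGBA") on an image that has no alpha channel fills the new channel with 255, so images without a painted mask keep full loss weight everywhere.

    import numpy as np
    from PIL import Image

    rgb = Image.new("RGB", (4, 4), (10, 20, 30))
    rgba = np.array(rgb.convert("RGBA"), np.uint8)
    assert rgba.shape == (4, 4, 4)
    assert (rgba[:, :, -1] == 255).all()  # missing alpha defaults to fully opaque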
@@ -508,16 +513,19 @@ def cache_latents(self, vae):

            image = self.load_image(info.absolute_path)
            image = self.trim_and_resize_if_required(image, info.bucket_reso, info.resized_size)
-
+            mask = image[:, :, -1]  # grab alpha channel
+            image = image[:, :, :3]  # drop alpha channel
            img_tensor = self.image_transforms(image)
            img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype)
            info.latents = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu")
+            info.mask = mask / 255

            if self.flip_aug:
                image = image[:, ::-1].copy()  # cannot convert to Tensor without copy
                img_tensor = self.image_transforms(image)
                img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype)
                info.latents_flipped = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu")
+                info.mask_flipped = mask[:, ::-1] / 255  # flip horizontally to match latents_flipped

    def get_image_size(self, image_path):
        image = Image.open(image_path)
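
The cached-mask convention, illustrated as a standalone sketch (not in the commit): the uint8 alpha in 0..255 becomes a float weight in 0..1, and the flipped copy must use the same axis-1 (horizontal) flip as the flipped latents.

    import numpy as np

    alpha = np.array([[0, 128, 255]], dtype=np.uint8)  # one-row stand-in for an alpha channel
    mask = alpha / 255                                 # [[0.0, 0.502, 1.0]]
    mask_flipped = alpha[:, ::-1] / 255                # [[1.0, 0.502, 0.0]], mirrored along axis 1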
@@ -606,14 +614,17 @@ def __getitem__(self, index):
        input_ids_list = []
        latents_list = []
        images = []
+        masks = []

        for image_key in bucket[image_index:image_index + bucket_batch_size]:
            image_info = self.image_data[image_key]
            loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0)

            # process the image/latents
            if image_info.latents is not None:
-                latents = image_info.latents if not self.flip_aug or random.random() < .5 else image_info.latents_flipped
+                rand_flip = random.random()
+                latents = image_info.latents if not self.flip_aug or rand_flip < .5 else image_info.latents_flipped
+                mask = image_info.mask if not self.flip_aug or rand_flip < .5 else image_info.mask_flipped
                image = None
            elif image_info.latents_npz is not None:
                latents = self.load_latents_from_npz(image_info, self.flip_aug and random.random() >= .5)
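
Why sharing a single draw matters, shown with a hypothetical simulation (not in the commit): two independent random.random() calls would pair a flipped latent with an unflipped mask roughly half the time under flip_aug.

    import random

    random.seed(0)
    trials = 10000
    agree = sum((random.random() < .5) == (random.random() < .5) for _ in range(trials))
    print(agree / trials)  # ~0.5: independent draws decorrelate latent and mask flips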
@@ -622,6 +633,8 @@ def __getitem__(self, index):
            else:
                # load the image and crop it if necessary
                img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(image_info.absolute_path)
+                mask = img[:, :, -1]  # grab alpha channel
+                img = img[:, :, :3]  # drop alpha channel
                im_h, im_w = img.shape[0:2]

                if self.enable_bucket:
@@ -647,7 +660,8 @@ def __getitem__(self, index):

                latents = None
                image = self.image_transforms(img)  # becomes a torch.Tensor in the range -1.0 to 1.0
-
+                mask = (self.image_transforms(mask) + 1) * .5
+            masks.append(torch.tensor(mask))
            images.append(image)
            latents_list.append(latents)

@@ -672,7 +686,7 @@ def __getitem__(self, index):
        else:
            images = None
        example['images'] = images
-
+        example['masks'] = torch.stack(masks) if masks[0] is not None else None
        example['latents'] = torch.stack(latents_list) if latents_list[0] is not None else None

        if self.debug_dataset:
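
A hypothetical shape check (not in the commit): along the cached-latents path each mask is an (H, W) numpy array, so the stack yields (batch, H, W); along the freshly-transformed path it is a (1, H, W) tensor, yielding (batch, 1, H, W). The later .reshape in train_network.py accepts either, since reshape only requires matching element counts. Note that the latents_npz branch assigns no mask at all, so this minimal implementation assumes latents are cached in memory or computed on the fly.

    import torch

    masks = [torch.rand(512, 512) for _ in range(2)]  # cached-alpha path: (H, W) each
    batch_masks = torch.stack(masks)                  # -> (2, 512, 512)
    assert batch_masks.reshape(2, 1, 512, 512).shape == (2, 1, 512, 512)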
@@ -1494,6 +1508,8 @@ def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: b
                        help="steps of resolution for buckets, divisible by 8 is recommended / bucketの解像度の単位、8で割り切れる値を推奨します")
    parser.add_argument("--bucket_no_upscale", action="store_true",
                        help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します")
+    parser.add_argument("--masked_loss", action="store_true",
+                        help="Enable Masked Loss from Alpha Channel")

    if support_caption_dropout:
        # Textual Inversion does not support caption dropout
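
A quick sanity check (not in the commit) that the new flag behaves as a plain boolean switch, defaulting to off:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--masked_loss", action="store_true",
                        help="Enable Masked Loss from Alpha Channel")
    assert parser.parse_args([]).masked_loss is False
    assert parser.parse_args(["--masked_loss"]).masked_loss is True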
@@ -2059,4 +2075,4 @@ def __getitem__(self, idx):
        return (tensor_pil, img_path)


-# endregion
\ No newline at end of file
+# endregion

train_network.py

Lines changed: 19 additions & 0 deletions
@@ -1,5 +1,6 @@
 from torch.cuda.amp import autocast
 from torch.nn.parallel import DistributedDataParallel as DDP
+import torch.nn.functional
 import importlib
 import argparse
 import gc
@@ -377,6 +378,24 @@ def train(args):
                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
                else:
                    target = noise
+
+                if args.masked_loss and batch['masks'] is not None:
+                    # masks arrive at pixel resolution: (batch, latent_h * 8, latent_w * 8)
+                    mask = (
+                        batch['masks']
+                        .to(noise_pred.device)
+                        .reshape(
+                            noise_pred.shape[0], 1, noise_pred.shape[2] * 8, noise_pred.shape[3] * 8
+                        )
+                    )
+                    # resize to match model_pred
+                    mask = torch.nn.functional.interpolate(
+                        mask.float(),
+                        size=noise_pred.shape[-2:],
+                        mode="nearest",
+                    )
+                    noise_pred = noise_pred * mask
+                    target = target * mask

                loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
                loss = loss.mean([1, 2, 3])
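
What masking both tensors does to the loss, as a standalone sketch (not part of the commit): the squared error is scaled by mask**2, and the subsequent mean over dims [1, 2, 3] still divides by the full latent-pixel count, so heavily masked samples contribute proportionally smaller losses rather than being renormalized by mask area.

    import torch
    import torch.nn.functional as F

    noise_pred = torch.randn(2, 4, 8, 8)
    target = torch.randn(2, 4, 8, 8)
    mask = torch.zeros(2, 1, 8, 8)
    mask[:, :, :, :4] = 1.0  # keep only the left half of each latent

    loss = F.mse_loss(noise_pred * mask, target * mask, reduction="none")
    per_sample = loss.mean([1, 2, 3])  # masked-out pixels count as zero error
    manual = ((mask * (noise_pred - target)) ** 2).mean([1, 2, 3])
    assert torch.allclose(per_sample, manual)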
