Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add masked loss implementation #589

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions fine_tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
get_weighted_text_embeddings,
prepare_scheduler_for_custom_training,
scale_v_prediction_loss_like_noise_prediction,
get_latent_masks
)


Expand Down Expand Up @@ -339,6 +340,11 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
else:
target = noise

if args.masked_loss and batch['masks'] is not None:
mask = get_latent_masks(batch['masks'], noise_pred.shape, noise_pred.device)
noise_pred = noise_pred * mask
target = target * mask

if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred:
# do not mean over batch dimension for snr weight or scale v-pred loss
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
Expand Down
21 changes: 21 additions & 0 deletions library/custom_train_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,27 @@ def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale):
return noise


def get_latent_masks(image_masks, latent_shape, device):
# given that masks lower the average loss this will counteract the effect
factor = torch.sqrt(image_masks.mean([1, 2]))
factor = torch.where(factor != 0.0, factor, 1.0)
factor = factor.reshape(factor.shape + (1,) * 2)
image_masks = image_masks / factor

masks = (
image_masks
.to(device)
.reshape(latent_shape[0], 1, latent_shape[2] * 8, latent_shape[3] * 8)
)
# resize to match latent
masks = torch.nn.functional.interpolate(
masks.float(),
size=latent_shape[-2:],
mode="nearest"
)
return masks


"""
##########################################
# Perlin Noise
Expand Down
64 changes: 60 additions & 4 deletions library/train_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,9 @@ def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool,
self.text_encoder_outputs1: Optional[torch.Tensor] = None
self.text_encoder_outputs2: Optional[torch.Tensor] = None
self.text_encoder_pool2: Optional[torch.Tensor] = None
# Masked Loss
self.mask: np.ndarray = None
self.mask_flipped: np.ndarray = None


class BucketManager:
Expand Down Expand Up @@ -1050,6 +1053,7 @@ def __getitem__(self, index):
input_ids2_list = []
latents_list = []
images = []
masks = []
original_sizes_hw = []
crop_top_lefts = []
target_sizes_hw = []
Expand All @@ -1071,14 +1075,18 @@ def __getitem__(self, index):
crop_ltrb = image_info.latents_crop_ltrb # calc values later if flipped
if not flipped:
latents = image_info.latents
mask = image_info.mask
else:
latents = image_info.latents_flipped
mask = image_info.mask_flipped

image = None
elif image_info.latents_npz is not None: # FineTuningDatasetまたはcache_latents_to_disk=Trueの場合
latents, original_size, crop_ltrb, flipped_latents = load_latents_from_disk(image_info.latents_npz)
mask = load_mask(image_info.absolute_path, image_info.resized_size) / 225
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should probably be 255?

Of note: loading masks breaks when loading a latent that was cached to disk if the original image was resized and cropped, because the loaded mask doesn't undergo the same load/crop routine, and ends up mismatched to the latent. When loading latents from disk, the original image is never considered. I think that the current implementation of trim_and_resize_if_required prevents replication of the original crop if random_crop is specified, since the coordinates of the crop are not passed back; crop_ltrb is something other than the coords actually used for cropping the image.

I've solved this in my local copy by just persisting the mask in the .npz, but that obviously blows up the cache size.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Whoops! nice catch.

I guess right now it is a limitation - I don't see a reliable way of preserving the crop parameters other than serializing the mask like you did (or a latent sized crop), or at least embed the crop parameters into the npz.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Keeping, returning, persisting, and retrieving a second set of crop parameters from trim_and_resize_if_required would be the much lighter option. Either way, a change to the npz format is likely necessary, or latent disk caching should fail loudly when using masked loss and images which need cropping.

if flipped:
latents = flipped_latents
mask = mask.flip(mask, dims=[3])
recris marked this conversation as resolved.
Show resolved Hide resolved
del flipped_latents
latents = torch.FloatTensor(latents)

Expand Down Expand Up @@ -1122,11 +1130,16 @@ def __getitem__(self, index):
if flipped:
img = img[:, ::-1, :].copy() # copy to avoid negative stride problem

# loss mask is alpha channel, separate it
mask = img[:, :, -1] / 255
img = img[:, :, :3]

latents = None
image = self.image_transforms(img) # -1.0~1.0のtorch.Tensorになる

images.append(image)
latents_list.append(latents)
masks.append(torch.tensor(mask))

target_size = (image.shape[2], image.shape[1]) if image is not None else (latents.shape[2] * 8, latents.shape[1] * 8)

Expand Down Expand Up @@ -1218,7 +1231,7 @@ def __getitem__(self, index):
else:
images = None
example["images"] = images

example["masks"] = torch.stack(masks) if masks[0] is not None else None
example["latents"] = torch.stack(latents_list) if latents_list[0] is not None else None
example["captions"] = captions

Expand Down Expand Up @@ -2132,12 +2145,44 @@ def load_arbitrary_dataset(args, tokenizer) -> MinimalDataset:

def load_image(image_path):
image = Image.open(image_path)
if not image.mode == "RGB":
image = image.convert("RGB")
if not image.mode == "RGBA":
image = image.convert("RGBA")
img = np.array(image, np.uint8)
img[..., -1] = load_mask(image_path, img.shape[:2])
return img


def load_mask(image_path, target_shape):
p = pathlib.Path(image_path)
mask_path = os.path.join(p.parent, 'mask', p.stem + '.png')
result = None

if os.path.exists(mask_path):
try:
mask_img = Image.open(mask_path)
mask = np.array(mask_img)
if len(mask.shape) > 2 and mask.max() <= 255:
result = np.array(mask_img.convert("L"))
elif len(mask.shape) == 2 and mask.max() > 255:
result = mask // (((2 ** 16) - 1) // 255)
elif len(mask.shape) == 2 and mask.max() <= 255:
result = mask
else:
print(f"{mask_path} has invalid mask format: using default mask")
except:
print(f"failed to load mask: {mask_path}")

# use default when mask file is unavailable
if result is None:
result = np.full(target_shape, 255, np.uint8)

# stretch mask to image shape
if result.shape != target_shape:
result = cv2.resize(result, dsize=target_shape, interpolation=cv2.INTER_LINEAR)

return result


# 画像を読み込む。戻り値はnumpy.ndarray,(original width, original height),(crop left, crop top, crop right, crop bottom)
def trim_and_resize_if_required(
random_crop: bool, image: Image.Image, reso, resized_size: Tuple[int, int]
Expand Down Expand Up @@ -2184,12 +2229,17 @@ def cache_batch_latents(
latents_original_size and latents_crop_ltrb are also set
"""
images = []
masks = []
for info in image_infos:
image = load_image(info.absolute_path) if info.image is None else np.array(info.image, np.uint8)
# TODO 画像のメタデータが壊れていて、メタデータから割り当てたbucketと実際の画像サイズが一致しない場合があるのでチェック追加要
image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size)
# alpha channel contains loss mask, separate it
mask = image[:, :, -1] / 255
image = image[:, :, :3]
image = IMAGE_TRANSFORMS(image)
images.append(image)
masks.append(mask)

info.latents_original_size = original_size
info.latents_crop_ltrb = crop_ltrb
Expand All @@ -2207,7 +2257,7 @@ def cache_batch_latents(
else:
flipped_latents = [None] * len(latents)

for info, latent, flipped_latent in zip(image_infos, latents, flipped_latents):
for info, latent, flipped_latent, mask in zip(image_infos, latents, flipped_latents, masks):
# check NaN
if torch.isnan(latents).any() or (flipped_latent is not None and torch.isnan(flipped_latent).any()):
raise RuntimeError(f"NaN detected in latents: {info.absolute_path}")
Expand All @@ -2216,8 +2266,10 @@ def cache_batch_latents(
save_latents_to_disk(info.latents_npz, latent, info.latents_original_size, info.latents_crop_ltrb, flipped_latent)
else:
info.latents = latent
info.mask = mask
if flip_aug:
info.latents_flipped = flipped_latent
info.mask_flipped = mask.flip(mask, dims=[3])

# FIXME this slows down caching a lot, specify this as an option
if torch.cuda.is_available():
Expand Down Expand Up @@ -3159,6 +3211,10 @@ def add_dataset_arguments(
"--bucket_no_upscale", action="store_true", help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します"
)

parser.add_argument(
"--masked_loss", action="store_true", help="Enable masking of latent loss using grayscale mask images"
)

parser.add_argument(
"--token_warmup_min",
type=int,
Expand Down
6 changes: 6 additions & 0 deletions sdxl_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
prepare_scheduler_for_custom_training,
scale_v_prediction_loss_like_noise_prediction,
add_v_prediction_like_loss,
get_latent_masks
)
from library.sdxl_original_unet import SdxlUNet2DConditionModel

Expand Down Expand Up @@ -548,6 +549,11 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

target = noise

if args.masked_loss and batch['masks'] is not None:
mask = get_latent_masks(batch['masks'], noise_pred.shape, noise_pred.device)
noise_pred = noise_pred * mask
target = target * mask

if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.v_pred_like_loss:
# do not mean over batch dimension for snr weight or scale v-pred loss
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
Expand Down
6 changes: 6 additions & 0 deletions train_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
pyramid_noise_like,
apply_noise_offset,
scale_v_prediction_loss_like_noise_prediction,
get_latent_masks
)

# perlin_noise,
Expand Down Expand Up @@ -326,6 +327,11 @@ def train(args):
else:
target = noise

if args.masked_loss and batch['masks'] is not None:
mask = get_latent_masks(batch['masks'], noise_pred.shape, noise_pred.device)
noise_pred = noise_pred * mask
target = target * mask

loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = loss.mean([1, 2, 3])

Expand Down
6 changes: 6 additions & 0 deletions train_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
prepare_scheduler_for_custom_training,
scale_v_prediction_loss_like_noise_prediction,
add_v_prediction_like_loss,
get_latent_masks
)


Expand Down Expand Up @@ -796,6 +797,11 @@ def remove_model(old_ckpt_name):
else:
target = noise

if args.masked_loss and batch['masks'] is not None:
mask = get_latent_masks(batch['masks'], noise_pred.shape, noise_pred.device)
noise_pred = noise_pred * mask
target = target * mask

loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = loss.mean([1, 2, 3])

Expand Down
6 changes: 6 additions & 0 deletions train_textual_inversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
prepare_scheduler_for_custom_training,
scale_v_prediction_loss_like_noise_prediction,
add_v_prediction_like_loss,
get_latent_masks
)

imagenet_templates_small = [
Expand Down Expand Up @@ -570,6 +571,11 @@ def remove_model(old_ckpt_name):
else:
target = noise

if args.masked_loss and batch['masks'] is not None:
mask = get_latent_masks(batch['masks'], noise_pred.shape, noise_pred.device)
noise_pred = noise_pred * mask
target = target * mask

loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = loss.mean([1, 2, 3])

Expand Down
6 changes: 6 additions & 0 deletions train_textual_inversion_XTI.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
pyramid_noise_like,
apply_noise_offset,
scale_v_prediction_loss_like_noise_prediction,
get_latent_masks
)
import library.original_unet as original_unet
from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI
Expand Down Expand Up @@ -461,6 +462,11 @@ def remove_model(old_ckpt_name):
else:
target = noise

if args.masked_loss and batch['masks'] is not None:
mask = get_latent_masks(batch['masks'], noise_pred.shape, noise_pred.device)
noise_pred = noise_pred * mask
target = target * mask

loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = loss.mean([1, 2, 3])

Expand Down