fix: bring positional args back, add recursive to blip etc
Linaqruf committed Apr 11, 2023
1 parent bf8088e commit c316c63
Showing 4 changed files with 35 additions and 16 deletions.
7 changes: 5 additions & 2 deletions finetune/make_captions.py

@@ -4,6 +4,7 @@
 import json
 import random
 
+from pathlib import Path
 from PIL import Image
 from tqdm import tqdm
 import numpy as np
@@ -72,7 +73,8 @@ def main(args):
         os.chdir('finetune')
 
     print(f"load images from {args.train_data_dir}")
-    image_paths = train_util.glob_images(args.train_data_dir)
+    train_data_dir_path = Path(args.train_data_dir)
+    image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
    print(f"found {len(image_paths)} images.")
 
     print(f"loading BLIP caption: {args.caption_weights}")
@@ -152,7 +154,8 @@ def setup_parser() -> argparse.ArgumentParser:
     parser.add_argument("--min_length", type=int, default=5, help="min length of caption / captionの最小長")
     parser.add_argument('--seed', default=42, type=int, help='seed for reproducibility / 再現性を確保するための乱数seed')
     parser.add_argument("--debug", action="store_true", help="debug mode")
-
+    parser.add_argument("--recursive", action="store_true", help="search for images in subfolders recursively")
+
     return parser

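Both caption scripts (this file and make_captions_by_git.py below) now route image discovery through train_util.glob_images_pathlib, whose definition lives in library/train_util.py and is not part of this diff. A minimal sketch of the assumed behavior, for readers following the change — the extension list and sort order here are assumptions, not taken from this commit:

from pathlib import Path

# Assumed extension list; the real one is defined in library/train_util.py.
IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp"]

def glob_images_pathlib(dir_path: Path, recursive: bool):
    # "**/*" descends into subfolders when --recursive is set; "*" stays flat.
    pattern = "**/*" if recursive else "*"
    paths = [p for p in dir_path.glob(pattern) if p.suffix.lower() in IMAGE_EXTENSIONS]
    return sorted(set(paths))  # deduplicate and keep a stable order

This is also why the callers now build a Path object first: the helper works on pathlib paths rather than the plain directory string the old train_util.glob_images took.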
7 changes: 5 additions & 2 deletions finetune/make_captions_by_git.py

@@ -2,6 +2,7 @@
 import os
 import re
 
+from pathlib import Path
 from PIL import Image
 from tqdm import tqdm
 import torch
@@ -65,7 +66,8 @@ def _prepare_input_ids_for_generation_patch(self, bos_token_id, encoder_outputs)
     GenerationMixin._prepare_input_ids_for_generation = _prepare_input_ids_for_generation_patch
 
     print(f"load images from {args.train_data_dir}")
-    image_paths = train_util.glob_images(args.train_data_dir)
+    train_data_dir_path = Path(args.train_data_dir)
+    image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
     print(f"found {len(image_paths)} images.")
 
     # できればcacheに依存せず明示的にダウンロードしたい
@@ -140,7 +142,8 @@ def setup_parser() -> argparse.ArgumentParser:
     parser.add_argument("--remove_words", action="store_true",
                         help="remove like `with the words xxx` from caption / `with the words xxx`のような部分をキャプションから削除する")
     parser.add_argument("--debug", action="store_true", help="debug mode")
-
+    parser.add_argument("--recursive", action="store_true", help="search for images in subfolders recursively")
+
     return parser

29 changes: 21 additions & 8 deletions finetune/prepare_buckets_latents.py

@@ -2,6 +2,8 @@
 import os
 import json
 
+from pathlib import Path
+from typing import List
 from tqdm import tqdm
 import numpy as np
 from PIL import Image
@@ -41,22 +43,31 @@ def get_latents(vae, images, weight_dtype):
     return latents
 
 
-def get_npz_filename_wo_ext(data_dir, image_key, is_full_path, flip):
+def get_npz_filename_wo_ext(data_dir, image_key, is_full_path, flip, recursive):
     if is_full_path:
         base_name = os.path.splitext(os.path.basename(image_key))[0]
+        relative_path = os.path.relpath(os.path.dirname(image_key), data_dir)
     else:
         base_name = image_key
+        relative_path = ""
+
     if flip:
         base_name += '_flip'
-    return os.path.join(data_dir, base_name)
+
+    if recursive and relative_path:
+        return os.path.join(data_dir, relative_path, base_name)
+    else:
+        return os.path.join(data_dir, base_name)
+
 
 
 def main(args):
     # assert args.bucket_reso_steps % 8 == 0, f"bucket_reso_steps must be divisible by 8 / bucket_reso_stepは8で割り切れる必要があります"
     if args.bucket_reso_steps % 8 > 0:
         print(f"resolution of buckets in training time is a multiple of 8 / 学習時の各bucketの解像度は8単位になります")
 
-    image_paths = train_util.glob_images(args.train_data_dir)
+    train_data_dir_path = Path(args.train_data_dir)
+    image_paths: List[str] = [str(p) for p in train_util.glob_images_pathlib(train_data_dir_path, args.recursive)]
     print(f"found {len(image_paths)} images.")
 
     if os.path.exists(args.in_json):
@@ -99,20 +110,20 @@ def process_batch(is_last):
                 f"latent shape {latents.shape}, {bucket[0][1].shape}"
 
             for (image_key, _), latent in zip(bucket, latents):
-                npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False)
+                npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False, args.recursive)
                 np.savez(npz_file_name, latent)
 
             # flip
             if args.flip_aug:
                 latents = get_latents(vae, [img[:, ::-1].copy() for _, img in bucket], weight_dtype)  # copyがないとTensor変換できない
 
                 for (image_key, _), latent in zip(bucket, latents):
-                    npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True)
+                    npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True, args.recursive)
                     np.savez(npz_file_name, latent)
             else:
                 # remove existing flipped npz
                 for image_key, _ in bucket:
-                    npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True) + ".npz"
+                    npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True, args.recursive) + ".npz"
                     if os.path.isfile(npz_file_name):
                         print(f"remove existing flipped npz / 既存のflipされたnpzファイルを削除します: {npz_file_name}")
                         os.remove(npz_file_name)
@@ -169,9 +180,9 @@ def process_batch(is_last):
 
         # 既に存在するファイルがあればshapeを確認して同じならskipする
         if args.skip_existing:
-            npz_files = [get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False) + ".npz"]
+            npz_files = [get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False, args.recursive) + ".npz"]
             if args.flip_aug:
-                npz_files.append(get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True) + ".npz")
+                npz_files.append(get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True, args.recursive) + ".npz")
 
             found = True
             for npz_file in npz_files:
@@ -256,6 +267,8 @@ def setup_parser() -> argparse.ArgumentParser:
                         help="flip augmentation, save latents for flipped images / 左右反転した画像もlatentを取得、保存する")
     parser.add_argument("--skip_existing", action="store_true",
                         help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップする(flip_aug有効時は通常、反転の両方が存在する画像をスキップ)")
+    parser.add_argument("--recursive", action="store_true",
+                        help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す")
 
     return parser

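The recursive parameter threaded through get_npz_filename_wo_ext makes the cached latents mirror the subfolder layout of train_data_dir instead of all being flattened into its top level. A worked example of the new path logic, using a condensed copy of the function above (the directory layout is hypothetical):

import os

def npz_path(data_dir, image_key, is_full_path, flip, recursive):
    # Condensed from get_npz_filename_wo_ext above, for illustration only.
    if is_full_path:
        base_name = os.path.splitext(os.path.basename(image_key))[0]
        relative_path = os.path.relpath(os.path.dirname(image_key), data_dir)
    else:
        base_name = image_key
        relative_path = ""
    if flip:
        base_name += '_flip'
    if recursive and relative_path:
        return os.path.join(data_dir, relative_path, base_name)
    return os.path.join(data_dir, base_name)

print(npz_path("train", "train/sub/img1.png", True, False, recursive=True))
# -> train/sub/img1  (saved as train/sub/img1.npz, next to the image)
print(npz_path("train", "train/sub/img1.png", True, False, recursive=False))
# -> train/img1      (old behavior: everything lands in the top-level folder)

Note that a subfolder is only honored when the image key is a full path (args.full_path), since otherwise no relative_path can be computed.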
8 changes: 4 additions & 4 deletions finetune/tag_images_by_wd14_tagger.py

@@ -10,7 +10,7 @@
 from tensorflow.keras.models import load_model
 from huggingface_hub import hf_hub_download
 import torch
-import pathlib
+from pathlib import Path
 
 import library.train_util as train_util
 
@@ -103,8 +103,8 @@ def main(args):
 
     # 画像を読み込む
 
-    train_data_dir = pathlib.Path(args.train_data_dir)
-    image_paths = train_util.glob_images_pathlib(train_data_dir, args.recursive)
+    train_data_dir_path = Path(args.train_data_dir)
+    image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
     print(f"found {len(image_paths)} images.")
 
     tag_freq = {}
@@ -205,7 +205,7 @@ def run_batch(path_imgs):
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument("--train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
+    parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
     parser.add_argument("--repo_id", type=str, default=DEFAULT_WD14_TAGGER_REPO,
                         help="repo id for wd14 tagger on Hugging Face / Hugging Faceのwd14 taggerのリポジトリID")
     parser.add_argument("--model_dir", type=str, default="wd14_tagger_model",
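The "bring positional args back" part of the commit title is this last hunk: train_data_dir is once again a required positional argument rather than a --train_data_dir flag, so the tagger is invoked with the directory as a bare argument. A minimal sketch of the difference in argparse terms (the directory name is hypothetical):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("train_data_dir", type=str)  # positional: required, passed without a flag

args = parser.parse_args(["my_images"])  # e.g. tag_images_by_wd14_tagger.py my_images
print(args.train_data_dir)  # -> my_images

# The previous form, parser.add_argument("--train_data_dir", ...), made the
# directory an optional flag that silently defaulted to None when omitted.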
