diff --git a/stylegan/prepare_data.py b/stylegan/prepare_data.py
old mode 100644
new mode 100755
index 739e2e5..6f7f829
--- a/stylegan/prepare_data.py
+++ b/stylegan/prepare_data.py
@@ -10,36 +10,40 @@
 from torchvision.transforms import functional as trans_fn
 
 
-def resize_and_convert(img, size, quality=100):
-    img = trans_fn.resize(img, size, Image.LANCZOS)
+def resize_and_convert(img, size, resample, quality=100):
+    img = trans_fn.resize(img, size, resample)
     img = trans_fn.center_crop(img, size)
     buffer = BytesIO()
-    img.save(buffer, format='jpeg', quality=quality)
+    img.save(buffer, format="jpeg", quality=quality)
     val = buffer.getvalue()
 
     return val
 
 
-def resize_multiple(img, sizes=(8, 16, 32, 64, 128, 256, 512, 1024), quality=100):
+def resize_multiple(
+    img, sizes=(128, 256, 512, 1024), resample=Image.LANCZOS, quality=100
+):
     imgs = []
 
     for size in sizes:
-        imgs.append(resize_and_convert(img, size, quality))
+        imgs.append(resize_and_convert(img, size, resample, quality))
 
     return imgs
 
 
-def resize_worker(img_file, sizes):
+def resize_worker(img_file, sizes, resample):
     i, file = img_file
     img = Image.open(file)
-    img = img.convert('RGB')
-    out = resize_multiple(img, sizes=sizes)
+    img = img.convert("RGB")
+    out = resize_multiple(img, sizes=sizes, resample=resample)
 
     return i, out
 
 
-def prepare(transaction, dataset, n_worker, sizes=(8, 16, 32, 64, 128, 256, 512, 1024)):
-    resize_fn = partial(resize_worker, sizes=sizes)
+def prepare(
+    env, dataset, n_worker, sizes=(128, 256, 512, 1024), resample=Image.LANCZOS
+):
+    resize_fn = partial(resize_worker, sizes=sizes, resample=resample)
 
     files = sorted(dataset.imgs, key=lambda x: x[0])
     files = [(i, file) for i, (file, label) in enumerate(files)]
@@ -48,23 +52,50 @@ def prepare(transaction, dataset, n_worker, sizes=(8, 16, 32, 64, 128, 256, 512,
     with multiprocessing.Pool(n_worker) as pool:
         for i, imgs in tqdm(pool.imap_unordered(resize_fn, files)):
             for size, img in zip(sizes, imgs):
-                key = f'{size}-{str(i).zfill(5)}'.encode('utf-8')
-                transaction.put(key, img)
+                key = f"{size}-{str(i).zfill(5)}".encode("utf-8")
+
+                with env.begin(write=True) as txn:
+                    txn.put(key, img)
 
             total += 1
 
-        transaction.put('length'.encode('utf-8'), str(total).encode('utf-8'))
+        with env.begin(write=True) as txn:
+            txn.put("length".encode("utf-8"), str(total).encode("utf-8"))
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Preprocess images for model training")
+    parser.add_argument("--out", type=str, help="filename of the result lmdb dataset")
+    parser.add_argument(
+        "--size",
+        type=str,
+        default="128,256,512,1024",
+        help="resolutions of images for the dataset",
+    )
+    parser.add_argument(
+        "--n_worker",
+        type=int,
+        default=8,
+        help="number of workers for preparing dataset",
+    )
+    parser.add_argument(
+        "--resample",
+        type=str,
+        default="lanczos",
+        help="resampling methods for resizing images",
+    )
+    parser.add_argument("path", type=str, help="path to the image dataset")
+
+    args = parser.parse_args()
+    resample_map = {"lanczos": Image.LANCZOS, "bilinear": Image.BILINEAR}
+    resample = resample_map[args.resample]
 
 
-if __name__ == '__main__':
-    parser = argparse.ArgumentParser()
-    parser.add_argument('--dataset', type=str, required=True, help='dataset name')
-    parser.add_argument('--n_worker', type=int, default=8)
+    sizes = [int(s.strip()) for s in args.size.split(",")]
 
-    args = parser.parse_args()
+    print(f"Make dataset of image sizes:", ", ".join(str(s) for s in sizes))
 
-    imgset = datasets.ImageFolder(f'./dataset/{args.dataset}')
+    imgset = datasets.ImageFolder(args.path)
 
-    with lmdb.open(f'./dataset/{args.dataset}_lmdb', map_size=1024 ** 4, readahead=False) as env:
-        with env.begin(write=True) as txn:
-            prepare(txn, imgset, args.n_worker)
+    with lmdb.open(args.out, map_size=1024 ** 4, readahead=False) as env:
+        prepare(env, imgset, args.n_worker, sizes=sizes, resample=resample)
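
Example invocation of the updated command-line interface (a sketch; `data/images.lmdb` and `dataset/images/` are placeholder paths, and the input directory must follow torchvision's ImageFolder layout, i.e. images grouped inside subdirectories):

    python stylegan/prepare_data.py --out data/images.lmdb --size 128,256,512,1024 --n_worker 8 --resample lanczos dataset/images/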