-
Notifications
You must be signed in to change notification settings - Fork 16
/
Copy pathlmdb2wds.py
39 lines (32 loc) · 1.86 KB
/
lmdb2wds.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
'''
This file implements the direct conversion from the latent ImageNet dataset to WebDataset.
'''
import os
from argparse import ArgumentParser
from tqdm import tqdm
import numpy as np
import pickle
import webdataset as wds
from train_utils.datasets import ImageNetLatentDataset
def convert2wds(args):
    '''
    Write every sample of the latent ImageNet dataset into WebDataset tar shards.

    Reads the following attributes from ``args``:
        outdir:       directory that receives the shards (created if missing).
        datadir:      location handed to ``ImageNetLatentDataset``.
        resolution:   forwarded to ``ImageNetLatentDataset``.
        num_channels: forwarded to ``ImageNetLatentDataset``.
        split:        dataset split; also embedded in the shard file names.
        maxcount:     forwarded to ``wds.ShardWriter`` (entries per shard).
        maxsize:      forwarded to ``wds.ShardWriter`` (size limit per shard).

    Shards are named ``latent_imagenet_512_<split>-NNNN.tar`` inside ``outdir``.
    Each written sample carries:
        ``__key__``: the dataset index, zero-padded to 7 digits,
        ``latent``:  the first element of the dataset item, pickled,
        ``cls``:     ``np.argmax`` of the second element of the dataset item.
    '''
    os.makedirs(args.outdir, exist_ok=True)
    shard_pattern = os.path.join(args.outdir, f'latent_imagenet_512_{args.split}-%04d.tar')
    dataset = ImageNetLatentDataset(
        args.datadir,
        resolution=args.resolution,
        num_channels=args.num_channels,
        split=args.split,
    )

    with wds.ShardWriter(shard_pattern, maxcount=args.maxcount, maxsize=args.maxsize) as sink:
        progress = tqdm(range(len(dataset)), dynamic_ncols=True)
        for idx in progress:
            # NOTE: the announced shard number is derived from the sample index
            # and maxcount alone; it is not read back from the writer.
            shard_no, offset = divmod(idx, args.maxcount)
            if offset == 0:
                print(f'writing to the {shard_no}th shard')

            # Original author's note: the latent is laid out as (C, H, W).
            latent, label_vec = dataset[idx]
            # Collapse the label vector to the index of its largest entry
            # (presumably a one-hot class encoding -- confirm against the dataset).
            class_idx = np.argmax(label_vec)

            sample = {
                '__key__': f'{idx:07d}',
                'latent': pickle.dumps(latent),
                'cls': class_idx,
            }
            sink.write(sample)
if __name__ == "__main__":
    # Script entry point: parse the command-line options and run the conversion.
    parser = ArgumentParser('Convert the latent imagenet dataset to WebDataset')

    # One row per option: (flag, value type, default, help text).
    cli_options = (
        ('--maxcount', int, 10010, 'max number of entries per shard'),
        ('--maxsize', int, 10 ** 10, 'max size per shard'),
        ('--outdir', str, 'latent_imagenet_wds', 'path to save the converted dataset'),
        ('--datadir', str, 'latent_imagenet', 'path to the latent imagenet dataset'),
        ('--resolution', int, 64, 'image resolution'),
        ('--num_channels', int, 8, 'number of image channels'),
        ('--split', str, 'train', 'split of the dataset'),
    )
    for flag, value_type, default_value, help_text in cli_options:
        parser.add_argument(flag, type=value_type, default=default_value, help=help_text)

    args = parser.parse_args()
    convert2wds(args)