-
Notifications
You must be signed in to change notification settings - Fork 1.1k
/
Copy pathinference.py
69 lines (60 loc) · 2.3 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import os
from argparse import ArgumentParser
import torch
from mmengine.registry import init_default_scope
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from tqdm import tqdm
from vton_dataset import AlignedDataset
from mmagic.apis.inferencers.inference_functions import init_model
from projects.flow_style_vton.models import FlowStyleVTON
# Make 'mmagic' the default registry scope for config-driven model building.
init_default_scope('mmagic')
# Model config; the pretrained checkpoint is set inside this config.
config = 'configs/flow_style_vton_PFAFN_epoch_101.py'


def _parse_bool(value):
    """Convert a command-line string such as 'true', 'False' or '0' to a bool.

    argparse hands the raw string to ``type``; without this conversion
    ``--isTrain False`` would produce the non-empty (truthy) string 'False'.

    Raises:
        ValueError: if ``value`` is not a recognised boolean spelling
            (argparse turns this into a normal usage error).
    """
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n'):
        return False
    raise ValueError(value)


parser = ArgumentParser()
parser.add_argument('--gpu_id', type=int, default=0)
parser.add_argument(
    '--loadSize', type=int, default=512, help='scale images to this size')
parser.add_argument(
    '--fineSize', type=int, default=512, help='then crop to this size')
parser.add_argument('--dataroot', type=str, default='VITON_test')
parser.add_argument('--test_pairs', type=str, default='test_pairs.txt')
parser.add_argument(
    '--resize_or_crop',
    type=str,
    default='scale_width',
    # Implicit concatenation: a backslash continuation inside the literal
    # would embed the next line's indentation into the help text.
    help='scaling and cropping of images at load time '
    '[resize_and_crop|crop|scale_width|scale_width_and_crop]')
parser.add_argument('--phase', type=str, default='test')
# Parsed to a real bool so that e.g. `--isTrain False` is actually falsy.
parser.add_argument('--isTrain', type=_parse_bool, default=False)
parser.add_argument(
    '--no_flip',
    action='store_true',
    help='if specified, do not flip the images for data augmentation')
parser.add_argument('--batch_size', type=int, default=1)
parser.add_argument('--num_workers', type=int, default=4)
parser.add_argument('--output_dir', type=str, default='inference_results')
opt = parser.parse_args()
# Test-time data pipeline; AlignedDataset is configured from the parsed `opt`.
dataset = AlignedDataset(opt)
dataloader = DataLoader(dataset, opt.batch_size, num_workers=opt.num_workers)

device = torch.device(
    f'cuda:{opt.gpu_id}' if torch.cuda.is_available() else 'cpu')

# The pretrained checkpoint is set inside the config. `device` is passed to
# init_model so it does not first move the model to its own default device
# (presumably 'cuda:0', which fails on CPU-only machines) -- verify against
# the installed mmagic version.
model = init_model(config, device=device).to(device).eval()
# Explicit check instead of `assert`, which is stripped under `python -O`.
if not isinstance(model, FlowStyleVTON):
    raise TypeError(
        f'expected a FlowStyleVTON model from {config!r}, '
        f'got {type(model).__name__}')

# All results are written under --output_dir.
tryon_dir = os.path.join(opt.output_dir, 'our_t_results')
combine_dir = os.path.join(opt.output_dir, 'im_gar_flow_wg')
os.makedirs(tryon_dir, exist_ok=True)
os.makedirs(combine_dir, exist_ok=True)

# Pure inference: disable autograd so no computation graph is retained.
with torch.no_grad():
    for data in tqdm(dataloader):
        # NOTE(review): assumes model.infer moves `data` to the model's
        # device itself -- TODO confirm.
        p_tryon, combine = model.infer(data)
        # NOTE(review): only the first sample name of each batch is used, so
        # with --batch_size > 1 each batch is saved as one grid under that
        # name. Per-sample saving is not done here because `combine` may not
        # carry a batch dimension -- confirm its shape before changing this.
        name = data['p_name'][0]
        # normalize + value_range=(-1, 1) clamps to [-1, 1] and rescales to
        # [0, 1]. `value_range` replaces torchvision's removed `range` kwarg.
        save_image(
            p_tryon,
            os.path.join(tryon_dir, name),
            nrow=1,
            normalize=True,
            value_range=(-1, 1))
        save_image(
            combine,
            os.path.join(combine_dir, name),
            nrow=1,
            normalize=True,
            value_range=(-1, 1))