@@ -31,6 +31,7 @@ def __getitem__(self, idx):
31
31
parser .add_argument ('-output_dir' , type = str , default = 'runs' , help = 'output data directory' )
32
32
parser .add_argument ('-cache_dir' , type = str , default = 'cache' , help = 'cache directory for model weights' )
33
33
parser .add_argument ('-duplicates' , type = int , default = 1 , help = 'How many HR images to produce for every image in the input directory' )
34
+ parser .add_argument ('-batch_size' , type = int , default = 1 , help = 'Batch size to use during optimization' )
34
35
35
36
#PULSE arguments
36
37
parser .add_argument ('-seed' , type = int , help = 'manual seed to use' )
@@ -47,12 +48,13 @@ def __getitem__(self, idx):
47
48
parser .add_argument ('-save_intermediate' , action = 'store_true' , help = 'Whether to store and save intermediate HR and LR images during optimization' )
48
49
49
50
kwargs = vars (parser .parse_args ())
51
+ kwargs ["save_intermediate" ]= True
50
52
51
53
dataset = Images (kwargs ["input_dir" ], duplicates = kwargs ["duplicates" ])
52
54
out_path = Path (kwargs ["output_dir" ])
53
55
out_path .mkdir (parents = True , exist_ok = True )
54
56
55
- dataloader = DataLoader (dataset , batch_size = 1 )
57
+ dataloader = DataLoader (dataset , batch_size = kwargs [ "batch_size" ] )
56
58
57
59
model = PULSE (cache_dir = kwargs ["cache_dir" ])
58
60
model = DataParallel (model )
@@ -61,21 +63,20 @@ def __getitem__(self, idx):
61
63
62
64
for ref_im , ref_im_name in dataloader :
63
65
if (kwargs ["save_intermediate" ]):
64
- out_im , int_HR , int_LR = model (ref_im ,** kwargs )
65
- else :
66
- out_im = model (ref_im ,** kwargs )
67
-
68
- for i in range (len (out_im )):
69
- toPIL (out_im [i ].cpu ().detach ().clamp (0 , 1 )).save (
70
- out_path / f"{ ref_im_name [i ]} .png" )
71
- if (kwargs ["save_intermediate" ]):
72
- padding = ceil (log10 (100 ))
66
+ padding = ceil (log10 (100 ))
67
+ for i in range (kwargs ["batch_size" ]):
73
68
int_path_HR = Path (out_path / ref_im_name [i ] / "HR" )
74
69
int_path_LR = Path (out_path / ref_im_name [i ] / "LR" )
75
70
int_path_HR .mkdir (parents = True , exist_ok = True )
76
71
int_path_LR .mkdir (parents = True , exist_ok = True )
77
- for j ,(HR ,LR ) in enumerate (zip (int_HR ,int_LR )):
72
+ for j ,(HR ,LR ) in enumerate (model (ref_im ,** kwargs )):
73
+ for i in range (kwargs ["batch_size" ]):
78
74
toPIL (HR [i ].cpu ().detach ().clamp (0 , 1 )).save (
79
75
int_path_HR / f"{ ref_im_name [i ]} _{ j :0{padding }} .png" )
80
76
toPIL (LR [i ].cpu ().detach ().clamp (0 , 1 )).save (
81
77
int_path_LR / f"{ ref_im_name [i ]} _{ j :0{padding }} .png" )
78
+ else :
79
+ out_im = model (ref_im ,** kwargs )
80
+ for i in range (kwargs ["batch_size" ]):
81
+ toPIL (out_im [i ].cpu ().detach ().clamp (0 , 1 )).save (
82
+ out_path / f"{ ref_im_name [i ]} .png" )
0 commit comments