Skip to content

Commit de41a83

Browse files
author
Alex Damian
committed
Turned PULSE into a generator when saving intermediate steps
1 parent 5e7897a commit de41a83

File tree

2 files changed

+17
-23
lines changed

2 files changed

+17
-23
lines changed

PULSE.py

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -126,9 +126,6 @@ def forward(self, ref_im,
126126
min_loss = np.inf
127127
best_summary = ""
128128
start_t = time.time()
129-
if(save_intermediate):
130-
int_HR = []
131-
int_LR = []
132129

133130
if self.verbose: print("Optimizing")
134131
for j in range(steps):
@@ -150,18 +147,17 @@ def forward(self, ref_im,
150147
loss, loss_dict = loss_builder(latent_in, gen_im)
151148
loss_dict['TOTAL'] = loss
152149

153-
# Save intermediate HR and LR images
154-
if(save_intermediate):
155-
int_HR.append(gen_im.cpu().detach().clamp(0, 1))
156-
int_LR.append(loss_builder.D(gen_im).cpu().detach().clamp(0, 1))
157-
158150
# Save best summary for log
159151
if(loss < min_loss):
160152
min_loss = loss
161153
best_summary = f'BEST ({j+1}) | '+' | '.join(
162154
[f'{x}: {y:.4f}' for x, y in loss_dict.items()])
163155
best_im = gen_im.clone()
164156

157+
# Save intermediate HR and LR images
158+
if(save_intermediate):
159+
yield (best_im.cpu().detach().clamp(0, 1),loss_builder.D(best_im).cpu().detach().clamp(0, 1))
160+
165161
loss.backward()
166162
opt.step()
167163
scheduler.step()
@@ -170,7 +166,4 @@ def forward(self, ref_im,
170166
current_info = f' | time: {total_t:.1f} | it/s: {(j+1)/total_t:.2f} | batchsize: {batch_size}'
171167
if self.verbose: print(best_summary+current_info)
172168

173-
if(save_intermediate):
174-
return best_im.cpu().detach().clamp(0,1), int_HR, int_LR
175-
else:
176-
return best_im.cpu().detach().clamp(0,1)
169+
return best_im.cpu().detach().clamp(0,1)

run.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ def __getitem__(self, idx):
3131
parser.add_argument('-output_dir', type=str, default='runs', help='output data directory')
3232
parser.add_argument('-cache_dir', type=str, default='cache', help='cache directory for model weights')
3333
parser.add_argument('-duplicates', type=int, default=1, help='How many HR images to produce for every image in the input directory')
34+
parser.add_argument('-batch_size', type=int, default=1, help='Batch size to use during optimization')
3435

3536
#PULSE arguments
3637
parser.add_argument('-seed', type=int, help='manual seed to use')
@@ -47,12 +48,13 @@ def __getitem__(self, idx):
4748
parser.add_argument('-save_intermediate', action='store_true', help='Whether to store and save intermediate HR and LR images during optimization')
4849

4950
kwargs = vars(parser.parse_args())
51+
kwargs["save_intermediate"]=True
5052

5153
dataset = Images(kwargs["input_dir"], duplicates=kwargs["duplicates"])
5254
out_path = Path(kwargs["output_dir"])
5355
out_path.mkdir(parents=True, exist_ok=True)
5456

55-
dataloader = DataLoader(dataset, batch_size=1)
57+
dataloader = DataLoader(dataset, batch_size=kwargs["batch_size"])
5658

5759
model = PULSE(cache_dir=kwargs["cache_dir"])
5860
model = DataParallel(model)
@@ -61,21 +63,20 @@ def __getitem__(self, idx):
6163

6264
for ref_im, ref_im_name in dataloader:
6365
if(kwargs["save_intermediate"]):
64-
out_im, int_HR, int_LR = model(ref_im,**kwargs)
65-
else:
66-
out_im = model(ref_im,**kwargs)
67-
68-
for i in range(len(out_im)):
69-
toPIL(out_im[i].cpu().detach().clamp(0, 1)).save(
70-
out_path / f"{ref_im_name[i]}.png")
71-
if(kwargs["save_intermediate"]):
72-
padding = ceil(log10(100))
66+
padding = ceil(log10(100))
67+
for i in range(kwargs["batch_size"]):
7368
int_path_HR = Path(out_path / ref_im_name[i] / "HR")
7469
int_path_LR = Path(out_path / ref_im_name[i] / "LR")
7570
int_path_HR.mkdir(parents=True, exist_ok=True)
7671
int_path_LR.mkdir(parents=True, exist_ok=True)
77-
for j,(HR,LR) in enumerate(zip(int_HR,int_LR)):
72+
for j,(HR,LR) in enumerate(model(ref_im,**kwargs)):
73+
for i in range(kwargs["batch_size"]):
7874
toPIL(HR[i].cpu().detach().clamp(0, 1)).save(
7975
int_path_HR / f"{ref_im_name[i]}_{j:0{padding}}.png")
8076
toPIL(LR[i].cpu().detach().clamp(0, 1)).save(
8177
int_path_LR / f"{ref_im_name[i]}_{j:0{padding}}.png")
78+
else:
79+
out_im = model(ref_im,**kwargs)
80+
for i in range(kwargs["batch_size"]):
81+
toPIL(out_im[i].cpu().detach().clamp(0, 1)).save(
82+
out_path / f"{ref_im_name[i]}.png")

0 commit comments

Comments
 (0)