Skip to content

Commit

Permalink
add pbar to unipc
Browse files Browse the repository at this point in the history
  • Loading branch information
vladmandic authored Mar 13, 2023
1 parent dfeee78 commit 03a80f1
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 3 deletions.
2 changes: 1 addition & 1 deletion modules/models/diffusion/uni_pc/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def sample(self,
# sampling
C, H, W = shape
size = (batch_size, C, H, W)
print(f'Data shape for UniPC sampling is {size}')
# print(f'Data shape for UniPC sampling is {size}')

device = self.model.betas.device
if x_T is None:
Expand Down
5 changes: 3 additions & 2 deletions modules/models/diffusion/uni_pc/uni_pc.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import torch
import torch.nn.functional as F
import math
from tqdm.auto import trange


class NoiseScheduleVP:
Expand Down Expand Up @@ -750,7 +751,7 @@ def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time
if method == 'multistep':
assert steps >= order, "UniPC order must be < sampling steps"
timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
print(f"Running UniPC Sampling with {timesteps.shape[0]} timesteps, order {order}")
#print(f"Running UniPC Sampling with {timesteps.shape[0]} timesteps, order {order}")
assert timesteps.shape[0] - 1 == steps
with torch.no_grad():
vec_t = timesteps[0].expand((x.shape[0]))
Expand All @@ -766,7 +767,7 @@ def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time
self.after_update(x, model_x)
model_prev_list.append(model_x)
t_prev_list.append(vec_t)
for step in range(order, steps + 1):
for step in trange(order, steps + 1):
vec_t = timesteps[step].expand(x.shape[0])
if lower_order_final:
step_order = min(order, steps + 1 - step)
Expand Down

0 comments on commit 03a80f1

Please sign in to comment.