Skip to content

Commit

Permalink
fix i2i
Browse files Browse the repository at this point in the history
  • Loading branch information
Vovanm88 committed Aug 16, 2024
1 parent ef0f981 commit 6415e4a
Showing 1 changed file with 9 additions and 4 deletions.
13 changes: 9 additions & 4 deletions sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,16 +158,20 @@ def denoise(
):
i = 0

#init_latents = rearrange(init_latents, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
if image2image_strength is not None and orig_image is not None:

t_idx = int((1 - np.clip(image2image_strength, 0.0, 1.0)) * len(timesteps))
t = timesteps[t_idx]
timesteps = timesteps[t_idx:]
orig_image = rearrange(orig_image, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
img = t * img + (1.0 - t) * orig_image.to(img.dtype)
# this is ignored for schnell
orig_image.to(img.device)
img = float(0.0+t) * img + (1.0 - t) * orig_image.to(img.dtype)

if hasattr(model, "guidance_in"):
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
else:
# this is ignored for schnell
guidance_vec = None
for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
Expand Down Expand Up @@ -231,11 +235,12 @@ def denoise_controlnet(
timesteps = timesteps[t_idx:]
orig_image = rearrange(orig_image, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
orig_image.to(img.device)
img = t * img + (1.0 - t) * orig_image.to(img.dtype)
# this is ignored for schnell
img = float(0.0+t) * img + (1.0 - t) * orig_image.to(img.dtype)

if hasattr(model, "guidance_in"):
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
else:
# this is ignored for schnell
guidance_vec = None
for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
Expand Down

0 comments on commit 6415e4a

Please sign in to comment.