diff --git a/sampling.py b/sampling.py
index a4647a2..8d8ea7a 100644
--- a/sampling.py
+++ b/sampling.py
@@ -252,22 +252,17 @@ def denoise_controlnet(
     #init_latents = rearrange(init_latents, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
     if image2image_strength is not None and orig_image is not None:
-        t_idx = np.clip(
-            int((np.clip(image2image_strength, 0.0, 1.0)) * len(timesteps)), 0, 1
-        )
+        t_idx = int((1 - np.clip(image2image_strength, 0.0, 1.0)) * len(timesteps))
         t = timesteps[t_idx]
-        try:
-            timesteps = timesteps[t_idx:]
-        except:
-            pass
+        timesteps = timesteps[t_idx:]
         orig_image = rearrange(orig_image, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2).to(img.device, dtype = img.dtype)
         img = t * img + (1.0 - t) * orig_image
-    controlnet.to(img.device, dtype=img.dtype)
     img_ids=img_ids.to(img.device, dtype=img.dtype)
-    controlnet_cond=controlnet_cond.to(img.device, dtype=img.dtype)
     txt=txt.to(img.device, dtype=img.dtype)
     txt_ids=txt_ids.to(img.device, dtype=img.dtype)
     vec=vec.to(img.device, dtype=img.dtype)
+    controlnet.to(img.device, dtype=img.dtype)
+    controlnet_cond=controlnet_cond.to(img.device, dtype=img.dtype)
     if hasattr(model, "guidance_in"):
         guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
     else:
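
For context on the `t_idx` fix: the old code clamped the index with `np.clip(..., 0, 1)`, so any nonzero strength started denoising at timestep 0 or 1 regardless of the value requested, and it also used `strength` directly rather than `1 - strength`. The corrected line maps strength onto the schedule so that higher strength starts at a noisier timestep and runs more denoising steps. A minimal standalone sketch of that mapping, assuming a toy descending schedule from 1.0 to 0.0 (the real `timesteps` comes from the sampler's schedule):

```python
import numpy as np

# Toy descending schedule standing in for the sampler's timesteps;
# 1.0 = pure noise, 0.0 = fully denoised.
timesteps = np.linspace(1.0, 0.0, 11)

for strength in (0.1, 0.5, 0.9):
    # Higher strength -> smaller t_idx -> noisier start, more steps run.
    t_idx = int((1 - np.clip(strength, 0.0, 1.0)) * len(timesteps))
    t = timesteps[t_idx]
    # As in the patch: img = t * noise + (1.0 - t) * orig_image,
    # then denoising continues over timesteps[t_idx:].
    print(f"strength={strength}: start t={t:.2f}, {len(timesteps[t_idx:])} steps remain")
```

With strength 0.1 this starts near the clean end of the schedule (the output stays close to `orig_image`), while strength 0.9 starts almost from pure noise, which matches the usual image-to-image semantics.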