Skip to content

Commit

Permalink
fix latent upscale not working if bs>1
Browse files Browse the repository at this point in the history
  • Loading branch information
kohya-ss committed Apr 24, 2023
1 parent 1890535 commit a85fcfe
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 2 deletions.
2 changes: 1 addition & 1 deletion gen_img_diffusers.py
Original file line number Diff line number Diff line change
Expand Up @@ -945,7 +945,7 @@ def __call__(

# encode the init image into latents and scale the latents
init_image = init_image.to(device=self.device, dtype=latents_dtype)
if init_image.size()[1:] == (height // 8, width // 8):
if init_image.size()[-2:] == (height // 8, width // 8):
init_latents = init_image
else:
if vae_batch_size >= batch_size:
Expand Down
8 changes: 7 additions & 1 deletion tools/latent_upscaler.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,13 @@ def create_upscaler(**kwargs):
model = Upscaler()

print(f"Loading weights from {weights}...")
model.load_state_dict(torch.load(weights, map_location=torch.device("cpu")))
if os.path.splitext(weights)[1] == ".safetensors":
from safetensors.torch import load_file

sd = load_file(weights)
else:
sd = torch.load(weights, map_location=torch.device("cpu"))
model.load_state_dict(sd)
return model


Expand Down

0 comments on commit a85fcfe

Please sign in to comment.