Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix SDXL Inpaint #15976

Merged
Merged 2 commits on Jun 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions modules/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def txt2img_image_conditioning(sd_model, x, width, height):
return x.new_zeros(x.shape[0], 2*sd_model.noise_augmentor.time_embed.dim, dtype=x.dtype, device=x.device)

else:
if getattr(sd_model.model, "is_sdxl_inpaint", False):
if sd_model.is_sdxl_inpaint:
# The "masked-image" in this case will just be all 0.5 since the entire image is masked.
image_conditioning = torch.ones(x.shape[0], 3, height, width, device=x.device) * 0.5
image_conditioning = images_tensor_to_samples(image_conditioning,
Expand Down Expand Up @@ -389,7 +389,7 @@ def img2img_image_conditioning(self, source_image, latent_image, image_mask=None
if self.sampler.conditioning_key == "crossattn-adm":
return self.unclip_image_conditioning(source_image)

if getattr(self.sampler.model_wrap.inner_model.model, "is_sdxl_inpaint", False):
if self.sampler.model_wrap.inner_model.is_sdxl_inpaint:
return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask)

# Dummy zero conditioning if we're not using inpainting or depth model.
Expand Down
20 changes: 13 additions & 7 deletions modules/sd_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,13 +386,6 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
model.is_sd2 = not model.is_sdxl and hasattr(model.cond_stage_model, 'model')
model.is_sd1 = not model.is_sdxl and not model.is_sd2
model.is_ssd = model.is_sdxl and 'model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight' not in state_dict.keys()
# Set is_sdxl_inpaint flag.
diffusion_model_input = state_dict.get('diffusion_model.input_blocks.0.0.weight', None)
model.is_sdxl_inpaint = (
model.is_sdxl and
diffusion_model_input is not None and
diffusion_model_input.shape[1] == 9
)
if model.is_sdxl:
sd_models_xl.extend_sdxl(model)

Expand All @@ -408,6 +401,19 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer

del state_dict

# Set is_sdxl_inpaint flag.
# Checks Unet structure to detect inpaint model. The inpaint model's
# checkpoint state_dict does not contain the key
# 'diffusion_model.input_blocks.0.0.weight'.
diffusion_model_input = model.model.state_dict().get(
'diffusion_model.input_blocks.0.0.weight'
)
model.is_sdxl_inpaint = (
model.is_sdxl and
diffusion_model_input is not None and
diffusion_model_input.shape[1] == 9
)

if shared.cmd_opts.opt_channelslast:
model.to(memory_format=torch.channels_last)
timer.record("apply channels_last")
Expand Down