Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions docs/features/hidiffusion.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
---
title: HiDiffusion
---

# HiDiffusion

HiDiffusion is an optional denoising enhancement that can improve detail and structure at higher resolutions for SD 1.5 and SDXL. It modifies the UNet during denoising and is most noticeable at 1536px and above.

Learn more: https://github.com/megvii-research/HiDiffusion

## Where to find the switches

1. Open the **Canvas** tab.
2. Expand **Advanced Settings**.
3. In the **Advanced** grid, enable **HiDiffusion** and optionally adjust its two sub‑toggles and two ratios:
- **HiDiffusion: RAU‑Net**
- **HiDiffusion: Window Attention**
- **HiDiffusion: T1 Ratio**
- **HiDiffusion: T2 Ratio**

## What the switches do

- **HiDiffusion**
Enables the HiDiffusion patch for denoising. Use this for high‑resolution generations; the effect is subtle at lower sizes.

- **HiDiffusion: RAU‑Net**
Enables RAU‑Net blocks. This typically improves structure and mid‑frequency detail, especially at larger resolutions.

- **HiDiffusion: Window Attention**
Enables windowed attention blocks. This can boost local texture/detail, but may slightly affect global coherence in some prompts.

- **HiDiffusion: T1 Ratio**
Controls when HiDiffusion switches into its mid‑stage behavior. Lower values switch earlier; higher values preserve global structure longer.

- **HiDiffusion: T2 Ratio**
Controls when HiDiffusion switches into its late‑stage behavior. Higher values keep window attention active longer and can sharpen local detail.

## Tips

- Try **1536–2048 px** for the clearest benefits (SDXL).
- If results look worse, disable **Window Attention** first, then RAU‑Net.
- Effects vary by scheduler and model; compare with the same seed for a fair test.

---

TODO: Decide whether the HiDiffusion toggle and its sub‑settings (RAU‑Net, Window Attention, T1/T2 ratios) should always be written to image metadata, or only when HiDiffusion is enabled, so that generations can be recalled faithfully.
115 changes: 91 additions & 24 deletions invokeai/app/invocations/denoise_latents.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
import inspect
import os
from contextlib import ExitStack
from contextlib import ExitStack, nullcontext
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union

import torch
Expand Down Expand Up @@ -65,6 +65,7 @@
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
from invokeai.backend.stable_diffusion.extensions.controlnet import ControlNetExt
from invokeai.backend.stable_diffusion.extensions.freeu import FreeUExt
from invokeai.backend.stable_diffusion.extensions.hidiffusion import HiDiffusionExt
from invokeai.backend.stable_diffusion.extensions.inpaint import InpaintExt
from invokeai.backend.stable_diffusion.extensions.inpaint_model import InpaintModelExt
from invokeai.backend.stable_diffusion.extensions.lora import LoRAExt
Expand All @@ -73,6 +74,7 @@
from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt
from invokeai.backend.stable_diffusion.extensions.t2i_adapter import T2IAdapterExt
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
from invokeai.backend.stable_diffusion.hidiffusion_utils import hidiffusion_patch
from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
from invokeai.backend.util.devices import TorchDevice
Expand Down Expand Up @@ -191,6 +193,35 @@ class DenoiseLatentsInvocation(BaseInvocation):
cfg_rescale_multiplier: float = InputField(
title="CFG Rescale Multiplier", default=0, ge=0, lt=1, description=FieldDescriptions.cfg_rescale_multiplier
)
hidiffusion: bool = InputField(
default=False,
description=FieldDescriptions.hidiffusion,
title="HiDiffusion",
)
hidiffusion_raunet: bool = InputField(
default=True,
description=FieldDescriptions.hidiffusion_raunet,
title="HiDiffusion: RAU-Net",
)
hidiffusion_window_attn: bool = InputField(
default=True,
description=FieldDescriptions.hidiffusion_window_attn,
title="HiDiffusion: Window Attention",
)
hidiffusion_t1_ratio: float = InputField(
default=0.4,
ge=0,
le=1,
description=FieldDescriptions.hidiffusion_t1_ratio,
title="HiDiffusion: T1 Ratio",
)
hidiffusion_t2_ratio: float = InputField(
default=0.0,
ge=0,
le=1,
description=FieldDescriptions.hidiffusion_t2_ratio,
title="HiDiffusion: T2 Ratio",
)
latents: Optional[LatentsField] = InputField(
default=None,
description=FieldDescriptions.latents,
Expand Down Expand Up @@ -486,6 +517,14 @@ def prep_control_data(

return controlnet_data

@staticmethod
def _get_hidiffusion_name_or_path(unet_config: AnyModelConfig) -> Optional[str]:
return (
getattr(unet_config, "source", None)
or getattr(unet_config, "path", None)
or getattr(unet_config, "name", None)
)

@staticmethod
def parse_controlnet_field(
exit_stack: ExitStack,
Expand Down Expand Up @@ -837,6 +876,7 @@ def _new_invoke(self, context: InvocationContext) -> LatentsOutput:

# get the unet's config so that we can pass the base to sd_step_callback()
unet_config = context.models.get_config(self.unet.unet.key)
hidiffusion_name_or_path = self._get_hidiffusion_name_or_path(unet_config)

conditioning_data = self.get_conditioning_data(
context=context,
Expand Down Expand Up @@ -874,6 +914,16 @@ def step_callback(state: PipelineIntermediateState) -> None:
context.util.sd_step_callback(state, unet_config.base)

ext_manager.add_extension(PreviewExt(step_callback))
if self.hidiffusion:
ext_manager.add_extension(
HiDiffusionExt(
name_or_path=hidiffusion_name_or_path,
apply_raunet=self.hidiffusion_raunet,
apply_window_attn=self.hidiffusion_window_attn,
t1_ratio=self.hidiffusion_t1_ratio,
t2_ratio=self.hidiffusion_t2_ratio,
)
)

### cfg rescale
if self.cfg_rescale_multiplier > 0:
Expand Down Expand Up @@ -940,14 +990,17 @@ def step_callback(state: PipelineIntermediateState) -> None:
# ext: t2i/ip adapter
ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx)

with (
context.models.load(self.unet.unet).model_on_device() as (cached_weights, unet),
ModelPatcher.patch_unet_attention_processor(unet, denoise_ctx.inputs.attention_processor_cls),
with ExitStack() as unet_stack:
cached_weights, unet = unet_stack.enter_context(context.models.load(self.unet.unet).model_on_device())
unet._num_timesteps = timesteps.shape[0]
unet_stack.enter_context(
ModelPatcher.patch_unet_attention_processor(unet, denoise_ctx.inputs.attention_processor_cls)
)
# ext: controlnet
ext_manager.patch_extensions(denoise_ctx),
# ext: freeu, seamless, ip adapter, lora
ext_manager.patch_unet(unet, cached_weights),
):
unet_stack.enter_context(ext_manager.patch_extensions(denoise_ctx))
# ext: freeu, seamless, ip adapter, lora, hidiffusion
unet_stack.enter_context(ext_manager.patch_unet(unet, cached_weights))

sd_backend = StableDiffusionBackend(unet, scheduler)
denoise_ctx.unet = unet
result_latents = sd_backend.latents_from_embeddings(denoise_ctx, ext_manager)
Expand Down Expand Up @@ -997,6 +1050,7 @@ def _old_invoke(self, context: InvocationContext) -> LatentsOutput:

# get the unet's config so that we can pass the base to sd_step_callback()
unet_config = context.models.get_config(self.unet.unet.key)
hidiffusion_name_or_path = self._get_hidiffusion_name_or_path(unet_config)

def step_callback(state: PipelineIntermediateState) -> None:
context.util.sd_step_callback(state, unet_config.base)
Expand Down Expand Up @@ -1084,23 +1138,36 @@ def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:
denoising_end=self.denoising_end,
seed=seed,
)
pipeline._num_timesteps = timesteps.shape[0]

result_latents = pipeline.latents_from_embeddings(
latents=latents,
timesteps=timesteps,
init_timestep=init_timestep,
noise=noise,
seed=seed,
mask=mask,
masked_latents=masked_latents,
is_gradient_mask=gradient_mask,
scheduler_step_kwargs=scheduler_step_kwargs,
conditioning_data=conditioning_data,
control_data=controlnet_data,
ip_adapter_data=ip_adapter_data,
t2i_adapter_data=t2i_adapter_data,
callback=step_callback,
)
with (
hidiffusion_patch(
pipeline,
name_or_path=hidiffusion_name_or_path,
apply_raunet=self.hidiffusion_raunet,
apply_window_attn=self.hidiffusion_window_attn,
t1_ratio=self.hidiffusion_t1_ratio,
t2_ratio=self.hidiffusion_t2_ratio,
)
if self.hidiffusion
else nullcontext()
):
result_latents = pipeline.latents_from_embeddings(
latents=latents,
timesteps=timesteps,
init_timestep=init_timestep,
noise=noise,
seed=seed,
mask=mask,
masked_latents=masked_latents,
is_gradient_mask=gradient_mask,
scheduler_step_kwargs=scheduler_step_kwargs,
conditioning_data=conditioning_data,
control_data=controlnet_data,
ip_adapter_data=ip_adapter_data,
t2i_adapter_data=t2i_adapter_data,
callback=step_callback,
)

# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
result_latents = result_latents.to("cpu")
Expand Down
5 changes: 5 additions & 0 deletions invokeai/app/invocations/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,11 @@ class FieldDescriptions:
denoising_end = "When to stop denoising, expressed a percentage of total steps"
cfg_scale = "Classifier-Free Guidance scale"
cfg_rescale_multiplier = "Rescale multiplier for CFG guidance, used for models trained with zero-terminal SNR"
hidiffusion = "Apply HiDiffusion (RAU-Net + MSW-MSA) for higher-resolution denoising"
hidiffusion_raunet = "Apply HiDiffusion RAU-Net blocks"
hidiffusion_window_attn = "Apply HiDiffusion window attention blocks"
hidiffusion_t1_ratio = "Override HiDiffusion early switch threshold (T1 ratio)"
hidiffusion_t2_ratio = "Override HiDiffusion late switch threshold (T2 ratio)"
scheduler = "Scheduler to use during inference"
positive_cond = "Positive conditioning tensor"
negative_cond = "Negative conditioning tensor"
Expand Down
6 changes: 6 additions & 0 deletions invokeai/app/invocations/metadata_linked.py
Original file line number Diff line number Diff line change
Expand Up @@ -661,6 +661,12 @@ def _loras_to_json(obj: Union[Any, list[Any]]):
md.update({"denoising_end": self.denoising_end})
md.update({"scheduler": self.scheduler})
md.update({"model": self.unet.unet})
if self.hidiffusion:
md.update({"hidiffusion": self.hidiffusion})
md.update({"hidiffusion_raunet": self.hidiffusion_raunet})
md.update({"hidiffusion_window_attn": self.hidiffusion_window_attn})
md.update({"hidiffusion_t1_ratio": self.hidiffusion_t1_ratio})
md.update({"hidiffusion_t2_ratio": self.hidiffusion_t2_ratio})
if isinstance(self.control, ControlField) or (isinstance(self.control, list) and len(self.control) > 0):
md.update({"controlnets": _to_json(self.control)})
if isinstance(self.ip_adapter, IPAdapterField) or (
Expand Down
Loading
Loading