Skip to content

Commit d698d81

Browse files
committed
update
1 parent 6a62c3e commit d698d81

File tree

2 files changed

+4
-2
lines changed

2 files changed

+4
-2
lines changed

src/diffusers/models/controlnets/controlnet_sana.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ class SanaControlNetOutput(BaseOutput):
4040

4141
class SanaControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
4242
_supports_gradient_checkpointing = True
43+
_no_split_modules = ["SanaTransformerBlock", "PatchEmbed"]
44+
_skip_layerwise_casting_patterns = ["patch_embed", "norm"]
4345

4446
@register_to_config
4547
def __init__(

src/diffusers/pipelines/sana/pipeline_sana_controlnet.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -207,8 +207,8 @@ class SanaControlNetPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
207207
bad_punct_regex = re.compile(r"[" + "#®•©™&@·º½¾¿¡§~" + r"\)" + r"\(" + r"\]" + r"\[" + r"\}" + r"\{" + r"\|" + "\\" + r"\/" + r"\*" + r"]{1,}")
208208
# fmt: on
209209

210-
model_cpu_offload_seq = "text_encoder->transformer->vae"
211-
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
210+
model_cpu_offload_seq = "text_encoder->controlnet->transformer->vae"
211+
_callback_tensor_inputs = ["latents", "control_image", "prompt_embeds", "negative_prompt_embeds"]
212212

213213
def __init__(
214214
self,

0 commit comments

Comments
 (0)