Skip to content

Commit

Permalink
add conversion from SD3.5 diffusers to safetensors
Browse files Browse the repository at this point in the history
  • Loading branch information
Nerogar committed Oct 23, 2024
1 parent 43cac35 commit d7a4e73
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 7 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ Hardware for development sponsored by https://shakker.ai

## Features

- **Supported models**: FLUX.1, Stable Diffusion 1.5, 2.0, 2.1, 3.0, SDXL, Würstchen-v2, Stable Cascade, PixArt-Alpha,
PixArt-Sigma and inpainting models
- **Supported models**: FLUX.1, Stable Diffusion 1.5, 2.0, 2.1, 3.0, 3.5, SDXL, Würstchen-v2, Stable Cascade,
PixArt-Alpha, PixArt-Sigma and inpainting models
- **Model formats**: diffusers and ckpt models
- **Training methods**: Full fine-tuning, LoRA, embeddings
- **Masked Training**: Let the training focus on just certain parts of the samples.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ def __save_safetensors(
dtype: torch.dtype | None,
):
state_dict = convert_sd3_diffusers_to_ckpt(
model.model_type,
model.vae.state_dict(),
model.transformer.state_dict(),
model.text_encoder_1.state_dict() if model.text_encoder_1 is not None else None,
Expand Down
2 changes: 2 additions & 0 deletions modules/ui/ConvertModelUI.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ def main_frame(self, master):
("Stable Diffusion 2.0", ModelType.STABLE_DIFFUSION_20),
("Stable Diffusion 2.0 Inpainting", ModelType.STABLE_DIFFUSION_20_INPAINTING),
("Stable Diffusion 2.1", ModelType.STABLE_DIFFUSION_21),
("Stable Diffusion 3", ModelType.STABLE_DIFFUSION_3),
("Stable Diffusion 3.5", ModelType.STABLE_DIFFUSION_35),
("Stable Diffusion XL 1.0 Base", ModelType.STABLE_DIFFUSION_XL_10_BASE),
("Stable Diffusion XL 1.0 Base Inpainting", ModelType.STABLE_DIFFUSION_XL_10_BASE_INPAINTING),
("Wuerstchen v2", ModelType.WUERSTCHEN_2),
Expand Down
21 changes: 16 additions & 5 deletions modules/util/convert/convert_sd3_diffusers_to_ckpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@
import torch
from torch import Tensor

from modules.util.enum.ModelType import ModelType


def __swap_chunks(tensor: Tensor) -> Tensor:
    """Return *tensor* with its two chunks along dim 0 exchanged.

    The tensor is split into two chunks along dimension 0 and the chunks are
    concatenated back in reversed order (second chunk first).

    NOTE(review): the two-name unpack requires ``chunk`` to yield exactly two
    pieces, so a tensor with a single row along dim 0 would raise here.
    """
    first_chunk, second_chunk = tensor.chunk(2, dim=0)
    return torch.cat([second_chunk, first_chunk], dim=0)

def __map_transformer_block(in_states: dict, out_prefix: str, in_prefix: str, is_last:bool) -> dict:
def __map_transformer_block(model_type: ModelType, in_states: dict, out_prefix: str, in_prefix: str, is_last:bool) -> dict:
out_states = {}

out_states[util.combine(out_prefix, "x_block.attn.qkv.weight")] = torch.cat([
Expand Down Expand Up @@ -48,17 +50,25 @@ def __map_transformer_block(in_states: dict, out_prefix: str, in_prefix: str, is
out_states[util.combine(out_prefix, "context_block.adaLN_modulation.1.weight")] = __swap_chunks(in_states[util.combine(in_prefix, "norm1_context.linear.weight")])
out_states[util.combine(out_prefix, "context_block.adaLN_modulation.1.bias")] = __swap_chunks(in_states[util.combine(in_prefix, "norm1_context.linear.bias")])

if model_type.is_stable_diffusion_3_5():
out_states[util.combine(out_prefix, "context_block.attn.ln_k.weight")] = in_states[util.combine(in_prefix, "attn.norm_added_k.weight")]
out_states[util.combine(out_prefix, "context_block.attn.ln_q.weight")] = in_states[util.combine(in_prefix, "attn.norm_added_q.weight")]

out_states |= util.map_wb(in_states, util.combine(out_prefix, "x_block.mlp.fc1"), util.combine(in_prefix, "ff.net.0.proj"))
out_states |= util.map_wb(in_states, util.combine(out_prefix, "x_block.mlp.fc2"), util.combine(in_prefix, "ff.net.2"))

# For SD3.5 models, also copy the x_block attention norm weights. This mirrors
# the context_block mapping in the same function: norm_k -> ln_k, norm_q -> ln_q.
if model_type.is_stable_diffusion_3_5():
    out_states[util.combine(out_prefix, "x_block.attn.ln_k.weight")] = in_states[util.combine(in_prefix, "attn.norm_k.weight")]
    # The query norm must go to its own "ln_q" key; writing it to "ln_k" again
    # would overwrite the key norm above and leave "ln_q" missing from the output.
    out_states[util.combine(out_prefix, "x_block.attn.ln_q.weight")] = in_states[util.combine(in_prefix, "attn.norm_q.weight")]

if not is_last:
out_states |= util.map_wb(in_states, util.combine(out_prefix, "context_block.mlp.fc1"), util.combine(in_prefix, "ff_context.net.0.proj"))
out_states |= util.map_wb(in_states, util.combine(out_prefix, "context_block.mlp.fc2"), util.combine(in_prefix, "ff_context.net.2"))

return out_states


def __map_transformer(in_states: dict, out_prefix: str, in_prefix: str) -> dict:
def __map_transformer(model_type: ModelType, in_states: dict, out_prefix: str, in_prefix: str) -> dict:
out_states = {}

out_states[util.combine(out_prefix, "pos_embed")] = in_states[util.combine(in_prefix, "pos_embed.pos_embed")]
Expand All @@ -73,10 +83,10 @@ def __map_transformer(in_states: dict, out_prefix: str, in_prefix: str) -> dict:
out_states |= util.map_wb(in_states, util.combine(out_prefix, "y_embedder.mlp.0"), util.combine(in_prefix, "time_text_embed.text_embedder.linear_1"))
out_states |= util.map_wb(in_states, util.combine(out_prefix, "y_embedder.mlp.2"), util.combine(in_prefix, "time_text_embed.text_embedder.linear_2"))

num_layers = 24
num_layers = 38 if model_type.is_stable_diffusion_3_5() else 24
for i in range(num_layers):
is_last = i == (num_layers - 1)
out_states |= __map_transformer_block(in_states, util.combine(out_prefix, f"joint_blocks.{i}"), util.combine(in_prefix, f"transformer_blocks.{i}"), is_last)
out_states |= __map_transformer_block(model_type, in_states, util.combine(out_prefix, f"joint_blocks.{i}"), util.combine(in_prefix, f"transformer_blocks.{i}"), is_last)

return out_states

Expand Down Expand Up @@ -130,6 +140,7 @@ def __map_t5_text_encoder(in_states: dict, out_prefix: str, in_prefix: str) -> d


def convert_sd3_diffusers_to_ckpt(
model_type: ModelType,
vae_state_dict: dict,
transformer_state_dict: dict,
text_encoder_1_state_dict: dict,
Expand All @@ -139,7 +150,7 @@ def convert_sd3_diffusers_to_ckpt(
state_dict = {}

state_dict |= util.map_vae(vae_state_dict, "first_stage_model", "")
state_dict |= __map_transformer(transformer_state_dict, "model.diffusion_model", "")
state_dict |= __map_transformer(model_type, transformer_state_dict, "model.diffusion_model", "")
if text_encoder_1_state_dict is not None:
state_dict |= __map_clip_text_encoder(text_encoder_1_state_dict, "text_encoders.clip_l.transformer", "")
if text_encoder_2_state_dict is not None:
Expand Down

0 comments on commit d7a4e73

Please sign in to comment.