Skip to content

Commit 75933d7

Browse files
authored
Merge pull request #1960 from kohya-ss/sd3_safetensors_merge
Sd3 safetensors merge
2 parents 3d79239 + aa2bde7 commit 75933d7

File tree

3 files changed

+177
-6
lines changed

3 files changed

+177
-6
lines changed

README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@ The command to install PyTorch is as follows:
1414

1515
### Recent Updates
1616

17+
Mar 6, 2025:
18+
19+
- Added a utility script to merge the weights of SD3's DiT, VAE (optional), CLIP-L, CLIP-G, and T5XXL into a single .safetensors file. Run `tools/merge_sd3_safetensors.py`. See `--help` for usage. PR [#1960](https://github.com/kohya-ss/sd-scripts/pull/1960)
20+
1721
Feb 26, 2025:
1822

1923
- Improve the validation loss calculation in `train_network.py`, `sdxl_train_network.py`, `flux_train_network.py`, and `sd3_train_network.py`. PR [#1903](https://github.com/kohya-ss/sd-scripts/pull/1903)

library/utils.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -261,11 +261,10 @@ def validate_metadata(metadata: Dict[str, Any]) -> Dict[str, str]:
261261

262262

263263
class MemoryEfficientSafeOpen:
264-
# does not support metadata loading
265264
def __init__(self, filename):
266265
self.filename = filename
267-
self.header, self.header_size = self._read_header()
268266
self.file = open(filename, "rb")
267+
self.header, self.header_size = self._read_header()
269268

270269
def __enter__(self):
271270
return self
@@ -276,6 +275,9 @@ def __exit__(self, exc_type, exc_val, exc_tb):
276275
def keys(self):
277276
return [k for k in self.header.keys() if k != "__metadata__"]
278277

278+
def metadata(self) -> Dict[str, str]:
279+
return self.header.get("__metadata__", {})
280+
279281
def get_tensor(self, key):
280282
if key not in self.header:
281283
raise KeyError(f"Tensor '{key}' not found in the file")
@@ -293,10 +295,9 @@ def get_tensor(self, key):
293295
return self._deserialize_tensor(tensor_bytes, metadata)
294296

295297
def _read_header(self):
296-
with open(self.filename, "rb") as f:
297-
header_size = struct.unpack("<Q", f.read(8))[0]
298-
header_json = f.read(header_size).decode("utf-8")
299-
return json.loads(header_json), header_size
298+
header_size = struct.unpack("<Q", self.file.read(8))[0]
299+
header_json = self.file.read(header_size).decode("utf-8")
300+
return json.loads(header_json), header_size
300301

301302
def _deserialize_tensor(self, tensor_bytes, metadata):
302303
dtype = self._get_torch_dtype(metadata["dtype"])

tools/merge_sd3_safetensors.py

Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
import argparse
2+
import os
3+
import gc
4+
from typing import Dict, Optional, Union
5+
import torch
6+
from safetensors.torch import safe_open
7+
8+
from library.utils import setup_logging
9+
from library.utils import load_safetensors, mem_eff_save_file, str_to_dtype
10+
11+
setup_logging()
12+
import logging
13+
14+
logger = logging.getLogger(__name__)
15+
16+
17+
def _merge_component(
    merged: Dict[str, torch.Tensor],
    name: str,
    path: str,
    prefix: str,
    device: str,
    dtype: Optional[torch.dtype],
    passthrough_prefixes: tuple = (),
) -> None:
    """Load one component .safetensors file and fold its tensors into *merged*.

    Keys that already start with *prefix* (or any of *passthrough_prefixes*,
    e.g. VAE keys bundled inside a DiT checkpoint) are copied unchanged; every
    other key gets *prefix* prepended. The loaded state dict is dropped and
    garbage-collected immediately so only one component is resident at a time.

    Args:
        merged: Destination state dict, mutated in place.
        name: Human-readable component name used only for logging.
        path: Path to the component's .safetensors file.
        prefix: Key prefix this component must carry in the merged file.
        device: Device to load tensors to.
        dtype: Optional target dtype applied while loading.
        passthrough_prefixes: Additional prefixes accepted as already-correct.
    """
    logger.info(f"Loading {name} model from {path}")
    state_dict = load_safetensors(path, device=device, disable_mmap=True, dtype=dtype)
    logger.info(f"Adding {name} model with {len(state_dict)} keys")
    accepted = (prefix,) + tuple(passthrough_prefixes)
    for key, value in state_dict.items():
        # str.startswith accepts a tuple: keep keys already under an accepted
        # namespace, otherwise move them under this component's prefix.
        if key.startswith(accepted):
            merged[key] = value
        else:
            merged[f"{prefix}{key}"] = value
    # Free memory before loading the next (potentially multi-GB) component.
    del state_dict
    gc.collect()


def merge_safetensors(
    dit_path: str,
    vae_path: Optional[str] = None,
    clip_l_path: Optional[str] = None,
    clip_g_path: Optional[str] = None,
    t5xxl_path: Optional[str] = None,
    output_path: str = "merged_model.safetensors",
    device: str = "cpu",
    save_precision: Optional[str] = None,
):
    """
    Merge multiple safetensors files into a single file

    Args:
        dit_path: Path to the DiT/MMDiT model
        vae_path: Path to the VAE model
        clip_l_path: Path to the CLIP-L model
        clip_g_path: Path to the CLIP-G model
        t5xxl_path: Path to the T5-XXL model
        output_path: Path to save the merged model
        device: Device to load tensors to
        save_precision: Target dtype for model weights (e.g. 'fp16', 'bf16')
    """
    logger.info("Starting to merge safetensors files...")

    # Convert save_precision string to torch dtype if specified
    target_dtype = str_to_dtype(save_precision) if save_precision else None

    # 1. Get DiT metadata if available; metadata is best-effort, so a read
    # failure only warns and the merge proceeds without it.
    metadata = None
    try:
        with safe_open(dit_path, framework="pt") as f:
            metadata = f.metadata()  # may be None
        if metadata:
            logger.info(f"Found metadata in DiT model: {metadata}")
    except Exception as e:
        logger.warning(f"Failed to read metadata from DiT model: {e}")

    # 2. Create empty merged state dict
    merged_state_dict: Dict[str, torch.Tensor] = {}

    # 3. Load and merge each component with memory management.

    # DiT/MMDiT - prefix: model.diffusion_model.
    # This state dict may also carry VAE keys (first_stage_model.*), which are
    # passed through unchanged.
    _merge_component(
        merged_state_dict,
        "DiT",
        dit_path,
        "model.diffusion_model.",
        device,
        target_dtype,
        passthrough_prefixes=("first_stage_model.",),
    )

    # VAE - prefix: first_stage_model.
    # May be omitted if VAE is already included in DiT model.
    if vae_path:
        _merge_component(merged_state_dict, "VAE", vae_path, "first_stage_model.", device, target_dtype)

    # CLIP-L - prefix: text_encoders.clip_l.
    if clip_l_path:
        _merge_component(merged_state_dict, "CLIP-L", clip_l_path, "text_encoders.clip_l.transformer.", device, target_dtype)

    # CLIP-G - prefix: text_encoders.clip_g.
    if clip_g_path:
        _merge_component(merged_state_dict, "CLIP-G", clip_g_path, "text_encoders.clip_g.transformer.", device, target_dtype)

    # T5-XXL - prefix: text_encoders.t5xxl.
    if t5xxl_path:
        _merge_component(merged_state_dict, "T5-XXL", t5xxl_path, "text_encoders.t5xxl.transformer.", device, target_dtype)

    # 4. Save merged state dict
    logger.info(f"Saving merged model to {output_path} with {len(merged_state_dict)} keys total")
    mem_eff_save_file(merged_state_dict, output_path, metadata)
    logger.info("Successfully merged safetensors files")
138+
139+
140+
def main():
    """Parse command-line arguments and merge the SD3.5 component files."""
    parser = argparse.ArgumentParser(description="Merge Stable Diffusion 3.5 model components into a single safetensors file")

    # Declarative option table: (flag, add_argument keyword arguments).
    # Registration order is preserved so --help output is unchanged.
    options = [
        ("--dit", {"required": True, "help": "Path to the DiT/MMDiT model"}),
        ("--vae", {"help": "Path to the VAE model. May be omitted if VAE is included in DiT model"}),
        ("--clip_l", {"help": "Path to the CLIP-L model"}),
        ("--clip_g", {"help": "Path to the CLIP-G model"}),
        ("--t5xxl", {"help": "Path to the T5-XXL model"}),
        ("--output", {"default": "merged_model.safetensors", "help": "Path to save the merged model"}),
        ("--device", {"default": "cpu", "help": "Device to load tensors to"}),
        ("--save_precision", {"type": str, "help": "Precision to save the model in (e.g., 'fp16', 'bf16', 'float16', etc.)"}),
    ]
    for flag, kwargs in options:
        parser.add_argument(flag, **kwargs)

    cli = parser.parse_args()

    merge_safetensors(
        dit_path=cli.dit,
        vae_path=cli.vae,
        clip_l_path=cli.clip_l,
        clip_g_path=cli.clip_g,
        t5xxl_path=cli.t5xxl,
        output_path=cli.output,
        device=cli.device,
        save_precision=cli.save_precision,
    )


if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)