Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 57 additions & 14 deletions scripts/extract_mtp_weight_from_safetensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import json
import os
import re
from typing import Dict, List

import numpy as np
import paddle
Expand Down Expand Up @@ -48,9 +49,10 @@ def parse_args():
return parser.parse_args()


def dtype_byte_size(dtype) -> int:
    """
    Return the size in bytes occupied by one element of type `dtype`.

    `dtype` may be a paddle dtype object or a plain string such as
    "float32", "paddle.bfloat16", "bool" or "float8_e4m3fn".

    NOTE: This returns an integer number of bytes for determinism in
    metadata (bool counts as 1 byte rather than the fractional 1/8).

    Example:

    ```py
    >>> dtype_byte_size("paddle.float32")
    4
    ```

    Raises:
        ValueError: if no trailing bit width can be parsed from `dtype`,
            or the parsed bit width is not a multiple of 8.
    """
    s = str(dtype)

    # bool is stored as 1 byte in most tensor formats; for metadata determinism, use 1.
    if s in {"paddle.bool", "bool"}:
        return 1

    # float8 variants need special-casing: "float8_e5m2" ends in "2" and
    # "float8_e4m3fn" ends in letters, so the trailing-digits regex below
    # would mis-parse or reject them.
    if s in {"paddle.float8_e4m3fn", "paddle.float8_e5m2", "float8_e4m3fn", "float8_e5m2"}:
        return 1

    # Generic case: the trailing digit run is the bit width,
    # e.g. "paddle.float32" -> 32, "bfloat16" -> 16.
    bit_search = re.search(r"[^\d](\d+)$", s)
    if bit_search is None:
        raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
    bit_size = int(bit_search.groups()[0])
    if bit_size % 8 != 0:
        raise ValueError(f"Unsupported dtype with non-byte-aligned bits: {dtype}.")
    return bit_size // 8


def extract_mtp_weights(input_dir: str) -> dict:
def _sorted_required_files(weight_map: Dict[str, str]) -> list[str]:
"""
Collect shard file names that contain 'mtp' weights and return a stable sorted list.
"""
# Use set to unique then sort to make iteration deterministic across runs.
required_files = sorted({v for k, v in weight_map.items() if "mtp" in k})
return required_files


def extract_mtp_weights(input_dir: str) -> Dict[str, np.ndarray]:
"""
Load all MTP-related weights from safetensors files in input_dir.

Determinism:
- iterate shards in sorted order
- iterate tensor keys in sorted order
"""
index_path = os.path.join(input_dir, SAFE_WEIGHTS_INDEX_NAME)
if not os.path.isfile(index_path):
Expand All @@ -82,28 +105,37 @@ def extract_mtp_weights(input_dir: str) -> dict:
index = json.load(f)

weight_map = index.get("weight_map", {})
required_files = {v for k, v in weight_map.items() if "mtp" in k}
required_files = _sorted_required_files(weight_map)
logger.info(f"Found {len(required_files)} shards with MTP weights.")

state_dict = {}
state_dict: Dict[str, np.ndarray] = {}
for file_name in required_files:
file_path = os.path.join(input_dir, file_name)
if not os.path.isfile(file_path):
logger.warning(f"Shard not found: {file_path}")
continue

logger.info(f"Loading shard: {file_path}")
with safe_open(file_path, framework="np", device="cpu") as f:
for k in f.keys():
# Sort keys for determinism
for k in sorted(f.keys()):
if "mtp" in k:
state_dict[k] = f.get_tensor(k)

# Final sort of state_dict by key to make sharding deterministic.
state_dict = dict(sorted(state_dict.items(), key=lambda kv: kv[0]))

logger.info(f"Loaded {len(state_dict)} MTP weights.")
return state_dict


def save_safetensors(state_dict: dict, output_dir: str):
def save_safetensors(state_dict: Dict[str, object], output_dir: str):
"""
Save state_dict as safetensors shards into output_dir.

Determinism:
- ensure state_dict is ordered by key before sharding
- when generating single-shard index, sort keys for stable weight_map ordering
"""
os.makedirs(output_dir, exist_ok=True)

Expand All @@ -114,6 +146,9 @@ def save_safetensors(state_dict: dict, output_dir: str):
array = tensor.cpu().numpy()
state_dict[k] = array

# Ensure deterministic order before sharding
state_dict = dict(sorted(state_dict.items(), key=lambda kv: kv[0]))

logger.info("Sharding and saving safetensors.")
shards, index = shard_checkpoint(
state_dict,
Expand All @@ -122,21 +157,29 @@ def save_safetensors(state_dict: dict, output_dir: str):
shard_format="naive",
)

for shard_file, shard in shards.items():
# Save shards in stable order (by filename)
for shard_file in sorted(shards.keys()):
shard = shards[shard_file]
save_path = os.path.join(output_dir, shard_file)
logger.info(f"Saving shard: {save_path}")
safe_save_file(shard, save_path, metadata={"format": "np"})

# If only one shard is returned, SAFE_WEIGHTS_INDEX_NAME will be null
if len(shards) == 1:
logger.info("Generate index file for single shard")

# Be robust: infer the only shard file name
only_shard_file = next(iter(shards.keys()))
only_shard = shards[only_shard_file]

weight_size = 0
for key, weight in shards["model.safetensors"].items():
weight_size += np.prod(weight.shape) * dtype_byte_size(weight.dtype)
for key in sorted(only_shard.keys()):
weight = only_shard[key]
weight_size += int(np.prod(weight.shape)) * int(dtype_byte_size(weight.dtype))

index = {
"metadata": {"total_size": int(weight_size)},
"weight_map": {k: "model.safetensors" for k in shards["model.safetensors"].keys()},
"weight_map": {k: only_shard_file for k in sorted(only_shard.keys())},
}

index_path = os.path.join(output_dir, SAFE_WEIGHTS_INDEX_NAME)
Expand Down
Loading