Commit 5c8c2d4

[Speculative Decoding][MTP] Update extract_mtp_weight script and optimize config (#5183)
* update extract_mtp_model
* modify config usage
1 parent edf0d09 · commit 5c8c2d4

File tree: 4 files changed (+43, −6 lines)

fastdeploy/config.py

Lines changed: 0 additions & 4 deletions
```diff
@@ -1628,10 +1628,6 @@ def postprocess(self):
         else:
             self.scheduler_config.max_num_batched_tokens = self.model_config.max_model_len
 
-        self.scheduler_config.max_chunk_len = (
-            self.scheduler_config.max_num_batched_tokens + self.scheduler_config.max_extra_num_batched_tokens
-        )
-
         if self.long_prefill_token_threshold == 0:
             self.long_prefill_token_threshold = int(self.model_config.max_model_len * 0.04)
 
```

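Dropping the cached `max_chunk_len` means the chunk budget can no longer drift out of sync with the two fields it is derived from. A minimal sketch of the failure mode this avoids, using a hypothetical stand-in class rather than FastDeploy's real config:

```python
class SchedulerConfig:
    """Hypothetical stand-in for fastdeploy's scheduler config."""

    def __init__(self):
        self.max_num_batched_tokens = 2048
        self.max_extra_num_batched_tokens = 16384


cfg = SchedulerConfig()

# Old approach: cache the sum once, as postprocess() used to do.
cached_chunk_len = cfg.max_num_batched_tokens + cfg.max_extra_num_batched_tokens

# If a later step overrides one of the budgets (postprocess() itself rewrites
# max_num_batched_tokens from max_model_len), the cached sum goes stale...
cfg.max_num_batched_tokens = 8192
assert cached_chunk_len == 18432  # stale value

# ...whereas computing the sum at the point of use always reflects live config.
assert cfg.max_num_batched_tokens + cfg.max_extra_num_batched_tokens == 24576
```
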
fastdeploy/scheduler/config.py

Lines changed: 0 additions & 1 deletion
```diff
@@ -270,7 +270,6 @@ def __init__(self, args):
         self.name = "local"  # "local" for LocalScheduler or "global" for GlobalScheduler
         self.max_num_batched_tokens = 2048  # base token_num for text inputs
         self.max_extra_num_batched_tokens = 16384  # extra token_num for multimodal inputs
-        self.max_chunk_len = 18432  # max supported token_num = max_num_batched_tokens + max_extra_num_batched_tokens
         self.max_num_seqs = 34
         self.splitwise_role = "mixed"
         self.config = None
```

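Nothing is lost by deleting the default: as the removed inline comment itself notes, 18432 was just the sum of the two surviving defaults, so any code that adds them up reproduces it exactly:

```python
max_num_batched_tokens = 2048         # base token_num for text inputs
max_extra_num_batched_tokens = 16384  # extra token_num for multimodal inputs

# The removed max_chunk_len default was precisely this sum.
assert max_num_batched_tokens + max_extra_num_batched_tokens == 18432
```
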
fastdeploy/spec_decode/mtp.py

Lines changed: 7 additions & 1 deletion
```diff
@@ -355,7 +355,13 @@ def _init_model_inputs(self):
             self.target_model_inputs["decoder_tile_ids_per_batch"]
         )
         self.model_inputs["target_hidden_states"] = paddle.full(
-            [self.fd_config.scheduler_config.max_chunk_len, self.model_config.hidden_size], 0, dtype="bfloat16"
+            [
+                self.fd_config.scheduler_config.max_num_batched_tokens
+                + self.fd_config.scheduler_config.max_extra_num_batched_tokens,
+                self.model_config.hidden_size,
+            ],
+            0,
+            dtype="bfloat16",
         )
 
         tmp_position_ids = paddle.arange(self.model_config.max_model_len).reshape((1, -1))
```

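With `max_chunk_len` gone, the MTP proposer sizes its `target_hidden_states` scratch buffer from the two budgets directly. A rough, self-contained sketch of the allocation, using the default budgets and an illustrative hidden size (the real value comes from `model_config`):

```python
import paddle

max_num_batched_tokens = 2048
max_extra_num_batched_tokens = 16384
hidden_size = 4096  # illustrative; not a value taken from this commit

# Preallocate one hidden-state row per schedulable token, zero-filled in bf16.
target_hidden_states = paddle.full(
    [max_num_batched_tokens + max_extra_num_batched_tokens, hidden_size],
    0,
    dtype="bfloat16",
)

print(target_hidden_states.shape)  # [18432, 4096]
# 18432 rows x 4096 dims x 2 bytes (bfloat16) = 144 MiB of scratch space
```
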
scripts/extract_mtp_weight_from_safetensor.py

Lines changed: 36 additions & 0 deletions
```diff
@@ -17,7 +17,9 @@
 import argparse
 import json
 import os
+import re
 
+import numpy as np
 import paddle
 from paddleformers.transformers.model_utils import shard_checkpoint
 from paddleformers.utils.env import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME
@@ -46,6 +48,28 @@ def parse_args():
     return parser.parse_args()
 
 
+def dtype_byte_size(dtype):
+    """
+    Returns the size (in bytes) occupied by one parameter of type `dtype`.
+
+    Example:
+
+    ```py
+    >>> dtype_byte_size(paddle.float32)
+    4
+    ```
+    """
+    if str(dtype) in {"paddle.bool", "bool"}:
+        return 1 / 8
+    if str(dtype) in {"paddle.float8_e4m3fn", "paddle.float8_e5m2", "float8_e4m3fn", "float8_e5m2"}:
+        return 1
+    bit_search = re.search(r"[^\d](\d+)$", str(dtype))
+    if bit_search is None:
+        raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
+    bit_size = int(bit_search.groups()[0])
+    return bit_size // 8
+
+
 def extract_mtp_weights(input_dir: str) -> dict:
     """
     Load all MTP-related weights from safetensors files in input_dir.
@@ -103,6 +127,18 @@ def save_safetensors(state_dict: dict, output_dir: str):
         logger.info(f"Saving shard: {save_path}")
         safe_save_file(shard, save_path, metadata={"format": "np"})
 
+    # If only one shard is returned, SAFE_WEIGHTS_INDEX_NAME will be null
+    if len(shards) == 1:
+        logger.info("Generate index file for single shard")
+        weight_size = 0
+        for key, weight in shards["model.safetensors"].items():
+            weight_size += np.prod(weight.shape) * dtype_byte_size(weight.dtype)
+
+        index = {
+            "metadata": {"total_size": int(weight_size)},
+            "weight_map": {k: "model.safetensors" for k in shards["model.safetensors"].keys()},
+        }
+
     index_path = os.path.join(output_dir, SAFE_WEIGHTS_INDEX_NAME)
     with open(index_path, "w", encoding="utf-8") as f:
         json.dump(index, f, indent=2)
```

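The new single-shard branch reconstructs what `shard_checkpoint` would have produced for a multi-shard checkpoint: a byte-accurate `total_size` plus a map from every weight name to its file. A toy reproduction of the index math with numpy arrays (names and shapes are made up for illustration):

```python
import numpy as np

# Toy single-shard state dict; names and shapes are illustrative only.
shard = {
    "mtp.embed_tokens.weight": np.zeros((100, 64), dtype="float32"),
    "mtp.lm_head.weight": np.zeros((64, 100), dtype="float16"),
}

# For numpy arrays, dtype.itemsize gives the same bytes-per-element answer as
# the dtype_byte_size() helper added in this commit (that helper parses dtype
# strings instead, which also covers paddle dtypes).
total_size = sum(w.size * w.dtype.itemsize for w in shard.values())

index = {
    "metadata": {"total_size": int(total_size)},
    "weight_map": {k: "model.safetensors" for k in shard},
}
print(index["metadata"])  # {'total_size': 38400}
```
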