4 changes: 0 additions & 4 deletions fastdeploy/config.py
@@ -1628,10 +1628,6 @@ def postprocess(self):
else:
self.scheduler_config.max_num_batched_tokens = self.model_config.max_model_len

self.scheduler_config.max_chunk_len = (
self.scheduler_config.max_num_batched_tokens + self.scheduler_config.max_extra_num_batched_tokens
)

if self.long_prefill_token_threshold == 0:
self.long_prefill_token_threshold = int(self.model_config.max_model_len * 0.04)

1 change: 0 additions & 1 deletion fastdeploy/scheduler/config.py
@@ -270,7 +270,6 @@ def __init__(self, args):
self.name = "local" # "local" for LocalScheduler or "global" for GlobalScheduler
self.max_num_batched_tokens = 2048 # base token_num for text inputs
self.max_extra_num_batched_tokens = 16384 # extra token_num for multimodal inputs
self.max_chunk_len = 18432 # max supported token_num = max_num_batched_tokens + max_extra_num_batched_tokens
self.max_num_seqs = 34
self.splitwise_role = "mixed"
self.config = None
8 changes: 7 additions & 1 deletion fastdeploy/spec_decode/mtp.py
@@ -355,7 +355,13 @@ def _init_model_inputs(self):
self.target_model_inputs["decoder_tile_ids_per_batch"]
)
self.model_inputs["target_hidden_states"] = paddle.full(
[self.fd_config.scheduler_config.max_chunk_len, self.model_config.hidden_size], 0, dtype="bfloat16"
[
self.fd_config.scheduler_config.max_num_batched_tokens
+ self.fd_config.scheduler_config.max_extra_num_batched_tokens,
self.model_config.hidden_size,
],
0,
dtype="bfloat16",
)

tmp_position_ids = paddle.arange(self.model_config.max_model_len).reshape((1, -1))
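Across the first three files, the precomputed `scheduler_config.max_chunk_len` is dropped and the MTP proposer sizes its `target_hidden_states` buffer directly from `max_num_batched_tokens + max_extra_num_batched_tokens`. A minimal sketch of the resulting allocation, using the scheduler defaults shown above and an illustrative `hidden_size` that is not taken from this PR:

```python
import paddle

max_num_batched_tokens = 2048         # base token budget for text inputs (default above)
max_extra_num_batched_tokens = 16384  # extra token budget for multimodal inputs (default above)
hidden_size = 4096                    # illustrative only, not from this PR

# Same capacity as the removed max_chunk_len (2048 + 16384 = 18432), now summed inline.
target_hidden_states = paddle.full(
    [max_num_batched_tokens + max_extra_num_batched_tokens, hidden_size],
    0,
    dtype="bfloat16",
)
print(target_hidden_states.shape)  # [18432, 4096]
```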
36 changes: 36 additions & 0 deletions scripts/extract_mtp_weight_from_safetensor.py
@@ -17,7 +17,9 @@
import argparse
import json
import os
import re

import numpy as np
import paddle
from paddleformers.transformers.model_utils import shard_checkpoint
from paddleformers.utils.env import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME
@@ -46,6 +48,28 @@ def parse_args():
return parser.parse_args()


def dtype_byte_size(dtype):
"""
Returns the size (in bytes) occupied by one parameter of type `dtype`.

Example:

```py
>>> dtype_byte_size(paddle.float32)
4
```
"""
if str(dtype) in {"paddle.bool", "bool"}:
return 1 / 8
if str(dtype) in {"paddle.float8_e4m3fn", "paddle.float8_e5m2", "float8_e4m3fn", "float8_e5m2"}:
return 1
bit_search = re.search(r"[^\d](\d+)$", str(dtype))
if bit_search is None:
raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
bit_size = int(bit_search.groups()[0])
return bit_size // 8
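A quick sanity check on the `dtype_byte_size` helper just added (not part of the PR; assumes the function above is in scope):

```python
print(dtype_byte_size("float32"))        # 4     -> regex pulls out 32, 32 // 8
print(dtype_byte_size("bfloat16"))       # 2     -> regex pulls out 16, 16 // 8
print(dtype_byte_size("float8_e4m3fn"))  # 1     -> matched by the float8 branch
print(dtype_byte_size("bool"))           # 0.125 -> booleans count as 1/8 of a byte
```

`paddle.float32` and the other `paddle.*` dtypes take the same paths, since only `str(dtype)` is inspected.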


def extract_mtp_weights(input_dir: str) -> dict:
"""
Load all MTP-related weights from safetensors files in input_dir.
@@ -103,6 +127,18 @@ def save_safetensors(state_dict: dict, output_dir: str):
logger.info(f"Saving shard: {save_path}")
safe_save_file(shard, save_path, metadata={"format": "np"})

# shard_checkpoint returns no index when there is only one shard, so write SAFE_WEIGHTS_INDEX_NAME manually
if len(shards) == 1:
logger.info("Generate index file for single shard")
weight_size = 0
for key, weight in shards["model.safetensors"].items():
weight_size += np.prod(weight.shape) * dtype_byte_size(weight.dtype)

index = {
"metadata": {"total_size": int(weight_size)},
"weight_map": {k: "model.safetensors" for k in shards["model.safetensors"].keys()},
}

index_path = os.path.join(output_dir, SAFE_WEIGHTS_INDEX_NAME)
with open(index_path, "w", encoding="utf-8") as f:
json.dump(index, f, indent=2)
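For reference, a minimal sketch (not part of the PR) of reading back the index written by the single-shard branch; `output_dir` here is a placeholder for the directory passed to `save_safetensors`:

```python
import json
import os

from paddleformers.utils.env import SAFE_WEIGHTS_INDEX_NAME

with open(os.path.join("output_dir", SAFE_WEIGHTS_INDEX_NAME), encoding="utf-8") as f:
    index = json.load(f)

print(index["metadata"]["total_size"])    # total parameter bytes, as computed above
print(set(index["weight_map"].values()))  # {'model.safetensors'} for a single shard
```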