Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
78 commits
Select commit Hold shift + click to select a range
27da1cd
DeepSeek V4 support eager
grimoire Apr 24, 2026
7e0e304
add template
grimoire Apr 24, 2026
ed59863
remove compressed_cache_engine.py
grimoire Apr 25, 2026
ab63081
fix fused moe
grimoire Apr 26, 2026
1ef28b6
support cudagraph
grimoire Apr 26, 2026
daf655f
fix cudagraph phase1/2
grimoire Apr 26, 2026
973ebf1
agent suck
grimoire Apr 27, 2026
486002e
wtf
grimoire Apr 27, 2026
41bb9f9
fix
grimoire Apr 27, 2026
2d57ab8
remove quantlinear
grimoire Apr 27, 2026
016e1dc
remove debug
grimoire Apr 27, 2026
d24ca01
remove key
grimoire Apr 27, 2026
92b3008
statecache
grimoire Apr 28, 2026
40d157a
window as state
grimoire Apr 28, 2026
d864604
remove
grimoire Apr 28, 2026
5b6cd60
better fill sliding window
grimoire Apr 28, 2026
740eb3a
use flashmla
grimoire Apr 28, 2026
a1f98f6
new start
grimoire Apr 28, 2026
3a0d6c3
add kernels
grimoire Apr 29, 2026
da6092e
fix layout
grimoire Apr 30, 2026
b8ad3c3
fix
grimoire Apr 30, 2026
2129e11
fix
grimoire Apr 30, 2026
0e9b0c0
newnew
grimoire Apr 30, 2026
2877b8d
fix
grimoire May 1, 2026
dd36198
opt indexer
grimoire May 1, 2026
615bf1e
update compress kernel
grimoire May 1, 2026
45c4045
optimize attn forward
grimoire May 1, 2026
cd86685
sparse attn
grimoire May 2, 2026
e23d6d0
fp8 cache
grimoire May 2, 2026
e8746f5
mla
grimoire May 2, 2026
3a11ab6
remove batch loop
grimoire May 2, 2026
1161777
rotary
grimoire May 3, 2026
95286cb
enable cudagraph
grimoire May 3, 2026
374121a
no sync kernel
grimoire May 4, 2026
cd6d015
no bsz loop
grimoire May 4, 2026
c727d82
cudagraph fix
grimoire May 4, 2026
4bc3a14
fix indexer
grimoire May 6, 2026
c818f44
fp4 moe hopper
grimoire May 6, 2026
9fc223a
fix warmup
grimoire May 6, 2026
46eb94a
fix apply rotary
grimoire May 6, 2026
e319195
fix
grimoire May 6, 2026
89e8e49
package
grimoire May 6, 2026
e892eb6
opt
grimoire May 6, 2026
a51f75f
v4 indexer opt
grimoire May 6, 2026
61cf2ea
indexer
grimoire May 6, 2026
8e60e8b
wrap compressor
grimoire May 6, 2026
2bc1154
pack attn
grimoire May 7, 2026
4479987
add op;update state cache
grimoire May 7, 2026
5d36638
update attn meta once
grimoire May 8, 2026
4ed038f
refactor
grimoire May 8, 2026
4f64822
fuse kernel
grimoire May 8, 2026
54007b2
fix little bugs
grimoire May 8, 2026
0b0c1c8
add skip layers for debug
grimoire May 9, 2026
625356d
fix
grimoire May 9, 2026
f654c54
refactor v4
grimoire May 9, 2026
3cf0b07
fix kernel
grimoire May 10, 2026
a0ec686
opt indexer
grimoire May 11, 2026
89d89c2
merge main
grimoire May 11, 2026
43543d8
force bitonic topk
grimoire May 11, 2026
76056cb
ep
grimoire May 12, 2026
616849d
auto block size
grimoire May 12, 2026
7dcb18d
fix
grimoire May 12, 2026
c7ac9ab
opt
grimoire May 12, 2026
37f6ca1
opt moe
grimoire May 13, 2026
5f5f941
optimize topk
grimoire May 13, 2026
3fb9f0c
optimize
grimoire May 13, 2026
7abd076
opt kernel
grimoire May 13, 2026
a5dba47
opt compressor
grimoire May 15, 2026
6adc8c9
decode attn meta once
grimoire May 15, 2026
e0ca8a8
optimize prefill
grimoire May 15, 2026
fea204e
optimize
grimoire May 15, 2026
d7d47b2
fix
grimoire May 18, 2026
7746461
update template
grimoire May 19, 2026
1ade47f
no tp indexer
grimoire May 19, 2026
8a769b2
opt prefix pos
grimoire May 19, 2026
374a0ca
fix lint
grimoire May 19, 2026
4c9ed5d
Merge branch 'main' into dsv4
grimoire May 19, 2026
4735336
fix config_from_pretrained
grimoire May 20, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# macOS
.DS_Store
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
Expand Down
10 changes: 3 additions & 7 deletions lmdeploy/archs.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Literal

from transformers import AutoConfig

from .messages import PytorchEngineConfig, TurbomindEngineConfig
from .utils import get_logger

Expand Down Expand Up @@ -152,11 +150,9 @@ def get_model_arch(model_path: str, trust_remote_code: bool = False):
Args:
model_path(str): the model path
"""
try:
cfg = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code)
except Exception as e: # noqa
from transformers import PretrainedConfig
cfg = PretrainedConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code)
from lmdeploy.hf_configs import config_from_pretrained

cfg = config_from_pretrained(model_path, trust_remote_code=trust_remote_code)

_cfg = cfg.to_dict()
if _cfg.get('architectures', None):
Expand Down
39 changes: 39 additions & 0 deletions lmdeploy/hf_configs/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Copyright (c) OpenMMLab. All rights reserved.
from functools import lru_cache

from transformers import AutoConfig

from lmdeploy.utils import get_logger

logger = get_logger('lmdeploy')


@lru_cache
def register_config(model_type: str):
    """Register lmdeploy's bundled config class for ``model_type`` with AutoConfig.

    Only ``'deepseek_v32'`` and ``'deepseek_v4'`` have a bundled config class;
    any other value is logged at debug level and ignored. The function is
    memoized with ``lru_cache`` so every distinct ``model_type`` is processed
    at most once per process.

    Args:
        model_type (str): the ``model_type`` value read from the model config.
    """
    # The config modules are imported lazily so that only the requested one is loaded.
    if model_type == 'deepseek_v32':
        from .configuration_deepseek_v32 import DeepseekV32Config as config_cls
    elif model_type == 'deepseek_v4':
        from .configuration_deepseek_v4 import DeepseekV4Config as config_cls
    else:
        logger.debug(f'Can not register config for model_type: {model_type}')
        return

    try:
        AutoConfig.register(config_cls.model_type, config_cls)
    except ValueError as e:
        # AutoConfig.register rejects a model_type that is already registered.
        # This function runs on a fallback path, so a duplicate registration
        # must not abort the caller before it can try its remaining fallbacks.
        logger.debug(f'Skip registering config for model_type: {model_type}, reason: {e}')


def config_from_pretrained(pretrained_model_name_or_path: str, **kwargs):
    """Load a model config, registering lmdeploy's bundled config classes on demand.

    The lookup is attempted in three stages:

    1. ``AutoConfig.from_pretrained`` as is.
    2. If that fails, read ``model_type`` from the raw config dict, call
       ``register_config`` for it and retry ``AutoConfig.from_pretrained``.
    3. If the retry fails as well, fall back to ``PretrainedConfig.from_pretrained``.

    Args:
        pretrained_model_name_or_path (str): model name or path forwarded to
            the transformers loaders.
        **kwargs: forwarded to the transformers loaders. ``trust_remote_code``
            is withheld from ``PretrainedConfig.get_config_dict`` only.

    Returns:
        The config object produced by the first stage that succeeds.
    """
    try:
        return AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
    except Exception as e:
        logger.debug(f'AutoConfig.from_pretrained failed: {e}, try register config manually.')
        # some models do not provide auto map for config
        from transformers import PretrainedConfig

        # trust_remote_code is removed for the raw dict read and restored
        # afterwards (only when it was explicitly set) for the loader calls below.
        trust_remote_code = kwargs.pop('trust_remote_code', None)
        config_dict, _ = PretrainedConfig.get_config_dict(pretrained_model_name_or_path, **kwargs)
        model_type = config_dict.get('model_type', None)
        if trust_remote_code is not None:
            kwargs['trust_remote_code'] = trust_remote_code

        register_config(model_type)
        try:
            return AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
        except Exception as retry_err:
            # Keep the best-effort fallback, but leave a trace of why the
            # retry failed instead of discarding the error silently.
            logger.debug(f'AutoConfig.from_pretrained failed after registering config: {retry_err}, '
                         'fall back to PretrainedConfig.from_pretrained.')
            return PretrainedConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
13 changes: 13 additions & 0 deletions lmdeploy/hf_configs/configuration_deepseek_v32.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright (c) OpenMMLab. All rights reserved.

from transformers.models.deepseek_v3.configuration_deepseek_v3 import DeepseekV3Config


class DeepseekV32Config(DeepseekV3Config):
    """Config class for ``model_type == 'deepseek_v32'``.

    Extends ``DeepseekV3Config`` with three ``index_*`` attributes. Every
    other keyword argument is forwarded unchanged to ``DeepseekV3Config``.

    Args:
        index_head_dim: stored as ``self.index_head_dim`` (default 128).
        index_n_heads: stored as ``self.index_n_heads`` (default 64).
        index_topk: stored as ``self.index_topk`` (default 2048).
        **kwargs: passed to ``DeepseekV3Config.__init__``.
    """
    model_type = 'deepseek_v32'

    def __init__(self, index_head_dim=128, index_n_heads=64, index_topk=2048, **kwargs):
        super().__init__(**kwargs)
        # NOTE(review): presumably the sparse-attention indexer settings -- confirm against the model code.
        self.index_head_dim = index_head_dim
        self.index_n_heads = index_n_heads
        self.index_topk = index_topk
100 changes: 100 additions & 0 deletions lmdeploy/hf_configs/configuration_deepseek_v4.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
# Copyright (c) OpenMMLab. All rights reserved.
from transformers.configuration_utils import PretrainedConfig


class DeepseekV4Config(PretrainedConfig):
    """Config class for ``model_type == 'deepseek_v4'``.

    ``bos_token_id``, ``eos_token_id``, ``tie_word_embeddings``, ``torch_dtype``
    and any unrecognized keyword arguments are forwarded to
    ``PretrainedConfig.__init__``; every other argument is stored on the
    instance under the same name.

    ``architectures`` falls back to ``['DeepseekV4ForCausalLM']`` and
    ``compress_ratios`` falls back to ``[0, 0]`` when the given value is falsy
    (``None`` or empty).
    """
    model_type = 'deepseek_v4'

    def __init__(self,
                 architectures=None,
                 attention_bias=False,
                 attention_dropout=0.0,
                 bos_token_id=0,
                 eos_token_id=1,
                 hc_eps=1e-6,
                 hc_mult=4,
                 hc_sinkhorn_iters=20,
                 head_dim=512,
                 hidden_act='silu',
                 hidden_size=4096,
                 index_head_dim=128,
                 index_n_heads=64,
                 index_topk=512,
                 initializer_range=0.02,
                 max_position_embeddings=1048576,
                 moe_intermediate_size=2048,
                 n_routed_experts=256,
                 n_shared_experts=1,
                 norm_topk_prob=True,
                 num_attention_heads=64,
                 num_experts_per_tok=6,
                 num_hidden_layers=43,
                 num_hash_layers=3,
                 num_key_value_heads=1,
                 num_nextn_predict_layers=1,
                 o_groups=8,
                 o_lora_rank=1024,
                 q_lora_rank=1024,
                 qk_rope_head_dim=64,
                 quantization_config=None,
                 rms_norm_eps=1e-6,
                 rope_scaling=None,
                 rope_theta=10000,
                 routed_scaling_factor=1.5,
                 scoring_func='sqrtsoftplus',
                 sliding_window=128,
                 swiglu_limit=10.0,
                 tie_word_embeddings=False,
                 topk_method='noaux_tc',
                 torch_dtype='bfloat16',
                 use_cache=True,
                 vocab_size=129280,
                 compress_rope_theta=160000,
                 compress_ratios=None,
                 **kwargs):
        # Only these four named arguments (plus **kwargs) go to the base class.
        super().__init__(bos_token_id=bos_token_id,
                         eos_token_id=eos_token_id,
                         tie_word_embeddings=tie_word_embeddings,
                         torch_dtype=torch_dtype,
                         **kwargs)
        # Default architecture name used when none (or an empty list) is given.
        self.architectures = architectures or ['DeepseekV4ForCausalLM']
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.hc_eps = hc_eps
        self.hc_mult = hc_mult
        self.hc_sinkhorn_iters = hc_sinkhorn_iters
        self.head_dim = head_dim
        self.hidden_act = hidden_act
        self.hidden_size = hidden_size
        self.index_head_dim = index_head_dim
        self.index_n_heads = index_n_heads
        self.index_topk = index_topk
        self.initializer_range = initializer_range
        self.max_position_embeddings = max_position_embeddings
        self.moe_intermediate_size = moe_intermediate_size
        self.n_routed_experts = n_routed_experts
        self.n_shared_experts = n_shared_experts
        self.norm_topk_prob = norm_topk_prob
        self.num_attention_heads = num_attention_heads
        self.num_experts_per_tok = num_experts_per_tok
        self.num_hidden_layers = num_hidden_layers
        self.num_hash_layers = num_hash_layers
        self.num_key_value_heads = num_key_value_heads
        self.num_nextn_predict_layers = num_nextn_predict_layers
        self.o_groups = o_groups
        self.o_lora_rank = o_lora_rank
        self.q_lora_rank = q_lora_rank
        self.qk_rope_head_dim = qk_rope_head_dim
        self.quantization_config = quantization_config
        self.rms_norm_eps = rms_norm_eps
        self.rope_scaling = rope_scaling
        self.rope_theta = rope_theta
        self.routed_scaling_factor = routed_scaling_factor
        self.scoring_func = scoring_func
        self.sliding_window = sliding_window
        self.swiglu_limit = swiglu_limit
        self.topk_method = topk_method
        self.use_cache = use_cache
        self.vocab_size = vocab_size
        self.compress_rope_theta = compress_rope_theta
        # A None default plus an ``or`` fallback gives each instance its own list
        # instead of a shared mutable default argument.
        # NOTE(review): the meaning of the two ratio entries is not visible here -- confirm with the model code.
        self.compress_ratios = compress_ratios or [0, 0]
68 changes: 68 additions & 0 deletions lmdeploy/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,6 +642,74 @@ def match(cls, model_path: str, **kwargs) -> str | None:
return 'deepseek-vl2'


@MODELS.register_module(name=['deepseek-v4'])
class DeepseekV4ChatTemplate(BaseChatTemplate):
    """Chat template for DeepSeek-V4.

    :meth:`messages2prompt` lays the prompt out as follows:

    - ``<|begin▁of▁sentence|>`` once, at the very start
    - system messages: the content only, with no role marker
    - user messages: ``<|User|>`` + content + ``<|Assistant|>``; when the user
      message is the last message, ``<think>`` is appended after it
    - assistant messages: ``</think>`` + content + ``<|end▁of▁sentence|>``

    Messages with any other role are skipped.
    """

    def __init__(
            self,
            meta_instruction='',
            eosys='',
            user='<|User|>',
            eoh='',
            assistant='<|Assistant|>',
            eoa='<|end▁of▁sentence|>',
            **kwargs):
        super().__init__(meta_instruction=meta_instruction,
                         eosys=eosys,
                         user=user,
                         eoh=eoh,
                         assistant=assistant,
                         eoa=eoa,
                         **kwargs)

    def messages2prompt(self, messages, sequence_start=True, **kwargs):
        """Render ``messages`` into a single DeepSeek-V4 prompt string.

        Args:
            messages (str | list[dict]): a plain string is treated as one user
                message; otherwise a list of dicts with ``role`` and ``content``.
            sequence_start (bool): accepted for interface compatibility.
                NOTE(review): it is not consulted, the BOS token is always
                emitted -- confirm this is intended for follow-up turns.

        Returns:
            str: the rendered prompt.
        """
        if isinstance(messages, str):
            messages = [{'role': 'user', 'content': messages}]

        # Collect the pieces and join once rather than growing a string in the loop.
        pieces = ['<|begin▁of▁sentence|>']
        last_index = len(messages) - 1

        for i, msg in enumerate(messages):
            role = msg.get('role')
            # ``content`` may be present but None; treat that as empty text
            # instead of failing on str + None.
            content = msg.get('content') or ''

            if role == 'user':
                pieces += ['<|User|>', content, '<|Assistant|>']
                if i == last_index:
                    # Open the thinking block only for the turn being generated.
                    pieces.append('<think>')
            elif role == 'assistant':
                pieces += ['</think>', content, '<|end▁of▁sentence|>']
            elif role == 'system':
                pieces.append(content)

        return ''.join(pieces)

    def get_prompt(self, prompt, sequence_start=True, **kwargs):
        """Render a single user ``prompt`` through :meth:`messages2prompt`."""
        return self.messages2prompt([{'role': 'user', 'content': prompt}], sequence_start, **kwargs)

    @classmethod
    def match(cls, model_path: str, trust_remote_code: bool = False) -> str | None:
        """Return the model_name that was registered to MODELS.

        The match is a case-insensitive check that ``model_path`` contains both
        ``'deepseek'`` and ``'v4'``; ``None`` is returned otherwise.
        """
        path = model_path.lower()
        if 'deepseek' in path and 'v4' in path:
            return 'deepseek-v4'
        return None


@MODELS.register_module(name=['llava-chatml'])
class ChatmlDirect(BaseChatTemplate):

Expand Down
8 changes: 7 additions & 1 deletion lmdeploy/pytorch/backends/apply_rotary_emb.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,13 @@ class ApplyRotaryEmbImpl(ABC):
"""Apply rotary embedding implementation."""

@abstractmethod
def forward(self, query: Tensor, key: Tensor, cos: Tensor, sin: Tensor, inplace: bool = True):
def forward(self,
query: Tensor,
key: Tensor,
cos: Tensor,
sin: Tensor,
inplace: bool = True,
complex_mode: bool = False):
"""forward."""
raise NotImplementedError

Expand Down
55 changes: 55 additions & 0 deletions lmdeploy/pytorch/backends/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,61 @@ class AttentionMetadata:
quant_policy: QuantPolicy = QuantPolicy.NONE


@dataclass
class V4AttentionMetadata:
    """Base attention metadata for DeepSeek V4.

    An instance is created once per step out of ``attn_metadata`` and
    ``step_ctx`` and then handed to every V4 sub-module (Attention, Compressor,
    Indexer). A backend that needs extra pre-computed fields should subclass
    this and override ``from_step_context``.
    """

    is_decoding: bool
    # [bsz, 1, topk] logical compressed KV indices (converted to physical by V4IndicesUpdater)
    indices_in_kvcache: torch.Tensor = None
    # [bsz] int32
    topk_length: torch.Tensor = None
    # [bsz, 1, extra_topk] ring-buffer positions
    extra_indices_in_kvcache: torch.Tensor = None
    # [bsz] int32
    extra_topk_length: torch.Tensor = None
    # Sequence-length metadata, copied once from attn_metadata / step_ctx
    block_offsets: torch.Tensor = None
    cu_q_seqlens: torch.Tensor = None
    kv_seqlens: torch.Tensor = None
    q_seqlens: torch.Tensor = None
    max_kv_seqlen: int = None
    max_q_seqlen: int = None
    block_size: int = 0
    cu_seqlens_k: torch.Tensor = None
    sum_kv_seqlen: int = None
    # [bsz] long
    start_pos: torch.Tensor = None

    @classmethod
    def from_step_context(cls, attn_metadata, step_ctx, **kwargs) -> 'V4AttentionMetadata':
        """Create the metadata from the scheduler's attn_metadata and step_ctx.

        While decoding, ``max_kv_seqlen`` is ``block_size * num_gpu_blocks``
        taken from the cache config; otherwise it is ``step_ctx.max_kv_seqlen``.
        ``start_pos`` is ``kv_seqlens - q_seqlens`` computed in ``torch.long``.

        Extra keyword arguments are accepted so that subclasses can take
        backend-specific inputs; this base implementation does not use them.
        """
        cache_cfg = step_ctx.cache_config
        decoding = attn_metadata.is_decoding

        if decoding:
            kv_len_bound = cache_cfg.block_size * cache_cfg.num_gpu_blocks
        else:
            kv_len_bound = step_ctx.max_kv_seqlen

        kv_lens = attn_metadata.kv_seqlens
        q_lens = attn_metadata.q_seqlens
        first_pos = kv_lens.to(torch.long) - q_lens.to(torch.long)

        field_values = dict(
            is_decoding=decoding,
            block_offsets=attn_metadata.block_offsets,
            cu_q_seqlens=attn_metadata.cu_seqlens_q,
            kv_seqlens=kv_lens,
            q_seqlens=q_lens,
            max_kv_seqlen=kv_len_bound,
            max_q_seqlen=step_ctx.max_q_seqlen,
            block_size=cache_cfg.block_size,
            sum_kv_seqlen=step_ctx.sum_kv_seqlen,
            cu_seqlens_k=attn_metadata.cu_seqlens_k,
            start_pos=first_pos,
        )
        return cls(**field_values)


T = TypeVar('T', bound=AttentionMetadata)


Expand Down
15 changes: 15 additions & 0 deletions lmdeploy/pytorch/backends/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,12 @@ class OpType(Enum):
FusedMoEW8A8 = auto()
LinearBlockedF8 = auto()
FusedMoEBlockedF8 = auto()
FusedMoEV4FP4 = auto()
NSAIndexFP8 = auto()
V4Attention = auto()
V4Indexer = auto()
V4Compressor = auto()
HcSplitSinkhorn = auto()
Embedding = auto()

# MoE router
Expand Down Expand Up @@ -62,6 +67,16 @@ def get_attention_metadata_cls():
"""Get attention metadata class."""
raise NotImplementedError

@staticmethod
def get_v4_attention_metadata_cls():
    """Return the metadata class used for DeepSeek V4 attention.

    The default is ``V4AttentionMetadata``. A backend that performs
    V4-specific pre-computation should override this method and return its
    own subclass instead.
    """
    # Imported inside the function, as in the original implementation.
    from lmdeploy.pytorch.backends.attention import V4AttentionMetadata as metadata_cls

    return metadata_cls

@staticmethod
@abstractmethod
def get_k_block_shape(
Expand Down
Loading
Loading