@@ -89,7 +89,7 @@ class Qwen3NextDynamicCache:
     cache (which has a constant shape regardless of seq_len).
 
     This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states`
-    and `ssm_states` for mamba cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor
+    and `ssm_states` for gated deltanet cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor
     For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`,
     while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors).
     For linear attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors),
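To make the docstring's layout concrete, here is a minimal sketch of the shape convention using plain tensors rather than the library class; `layer_types`, the sizes, and the linear-attention state shapes are illustrative assumptions, not the model's actual dimensions.

```python
import torch

# Illustrative sizes only; the real values come from Qwen3NextConfig.
batch_size, num_heads, head_dim = 2, 4, 64
layer_types = ["linear_attention", "full_attention"]

key_cache, value_cache, conv_states, ssm_states = [], [], [], []
for layer_type in layer_types:
    if layer_type == "linear_attention":
        # Attention slots stay as empty (batch_size, 0) placeholders;
        # the recurrent state lives in conv_states / ssm_states instead.
        key_cache.append(torch.zeros(batch_size, 0))
        value_cache.append(torch.zeros(batch_size, 0))
        conv_states.append(torch.zeros(batch_size, num_heads * head_dim, 4))       # shape is a guess
        ssm_states.append(torch.zeros(batch_size, num_heads, head_dim, head_dim))  # shape is a guess
    else:
        # Full-attention layers grow along seq_len (0 here, before any token is cached).
        key_cache.append(torch.zeros(batch_size, num_heads, 0, head_dim))
        value_cache.append(torch.zeros(batch_size, num_heads, 0, head_dim))
        conv_states.append(torch.zeros(batch_size, 0))
        ssm_states.append(torch.zeros(batch_size, 0))
```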
@@ -108,7 +108,7 @@ def __init__(self, config: Qwen3NextConfig, batch_size, dtype=torch.float16, dev
         self.recurrent_states = []
         self.transformer_layers = []
         for i in range(config.num_hidden_layers):
-            # NOTE: only use mamba2 and full attention now! need to change future for more blocks.
+            # NOTE: only use gated deltanet and full attention now! need to change future for more blocks.
             if self.layer_types[i] == "linear_attention":
                 self.conv_states += [
                     torch.zeros(
@@ -1196,7 +1196,7 @@ def forward(
         input_ids: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[Cache] = None,
+        past_key_values: Optional[Qwen3NextDynamicCache] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
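A hedged usage sketch of the updated annotation: an explicitly constructed cache is passed through `past_key_values`. The import path and the placeholder checkpoint name are assumptions, and the constructor keywords are read off the `__init__` hunk above rather than a verified API.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# Assumed import location for the cache class shown in this diff.
from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextDynamicCache

model_id = "path/to/qwen3-next-checkpoint"  # placeholder
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)

inputs = tokenizer("Hello", return_tensors="pt")
# Constructor arguments follow the __init__ signature in the hunk above.
cache = Qwen3NextDynamicCache(model.config, batch_size=1, dtype=model.dtype)
outputs = model(**inputs, past_key_values=cache, use_cache=True)
```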