
Commit fd6affc

bozheng-hit authored and Cyrilvallez committed
fix
1 parent 4b742f9 commit fd6affc

File tree

3 files changed: +8 -8 lines changed


src/transformers/models/qwen3_next/modeling_qwen3_next.py

Lines changed: 3 additions & 3 deletions
@@ -89,7 +89,7 @@ class Qwen3NextDynamicCache:
     cache (which has a constant shape regardless of seq_len).
 
     This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states`
-    and `ssm_states` for mamba cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor
+    and `ssm_states` for gated deltanet cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor
     For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`,
     while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors).
     For linear attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors),
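
The docstring above describes a hybrid layout worth spelling out: every layer index owns an entry in all four lists, and whichever pair does not apply to that layer type holds an empty `(batch_size, 0)` placeholder so per-layer indexing stays uniform. Below is a minimal sketch of that layout; only the attention KV shape and the `(batch_size, 0)` placeholders come from the docstring, the deltanet state shapes are invented for illustration and this is not the actual Qwen3NextDynamicCache implementation.

import torch

# Sketch of the hybrid per-layer cache layout described in the docstring above.
batch_size, num_heads, head_dim = 2, 4, 64
layer_types = ["linear_attention", "full_attention", "linear_attention"]

key_cache, value_cache, conv_states, ssm_states = [], [], [], []
for layer_type in layer_types:
    if layer_type == "linear_attention":
        # Gated deltanet layers: real conv/recurrent state (shapes here are
        # made up for illustration), empty attention placeholders.
        key_cache.append(torch.zeros(batch_size, 0))
        value_cache.append(torch.zeros(batch_size, 0))
        conv_states.append(torch.zeros(batch_size, 128, 4))
        ssm_states.append(torch.zeros(batch_size, num_heads, head_dim, head_dim))
    else:
        # Attention layers: KV tensors that grow along seq_len, empty
        # deltanet placeholders.
        key_cache.append(torch.zeros(batch_size, num_heads, 0, head_dim))
        value_cache.append(torch.zeros(batch_size, num_heads, 0, head_dim))
        conv_states.append(torch.zeros(batch_size, 0))
        ssm_states.append(torch.zeros(batch_size, 0))
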
@@ -108,7 +108,7 @@ def __init__(self, config: Qwen3NextConfig, batch_size, dtype=torch.float16, dev
         self.recurrent_states = []
         self.transformer_layers = []
         for i in range(config.num_hidden_layers):
-            # NOTE: only use mamba2 and full attention now! need to change future for more blocks.
+            # NOTE: only use gated deltanet and full attention now! need to change future for more blocks.
             if self.layer_types[i] == "linear_attention":
                 self.conv_states += [
                     torch.zeros(
@@ -1196,7 +1196,7 @@ def forward(
         input_ids: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[Cache] = None,
+        past_key_values: Optional[Qwen3NextDynamicCache] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
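
Since `forward` now advertises `Qwen3NextDynamicCache` rather than the generic `Cache`, a caller can pre-build the hybrid cache and pass it in explicitly. A hedged usage sketch follows, relying only on the constructor signature visible in the `__init__` hunk above (config and batch size as positional parameters); the checkpoint id is an example, not something this commit pins down.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextDynamicCache

model_id = "Qwen/Qwen3-Next-80B-A3B-Instruct"  # example checkpoint id, not taken from this commit
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)

inputs = tokenizer("Hello", return_tensors="pt")
# Pre-build the hybrid cache and pass it through `past_key_values`, matching
# the narrowed annotation in the diff above.
cache = Qwen3NextDynamicCache(model.config, batch_size=1)
outputs = model(**inputs, past_key_values=cache, use_cache=True)
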

src/transformers/models/qwen3_next/modular_qwen3_next.py

Lines changed: 2 additions & 2 deletions
@@ -97,7 +97,7 @@ class Qwen3NextDynamicCache:
     cache (which has a constant shape regardless of seq_len).
 
     This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states`
-    and `ssm_states` for mamba cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor
+    and `ssm_states` for gated deltanet cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor
     For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`,
     while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors).
     For linear attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors),
@@ -116,7 +116,7 @@ def __init__(self, config: Qwen3NextConfig, batch_size, dtype=torch.float16, dev
         self.recurrent_states = []
         self.transformer_layers = []
         for i in range(config.num_hidden_layers):
-            # NOTE: only use mamba2 and full attention now! need to change future for more blocks.
+            # NOTE: only use gated deltanet and full attention now! need to change future for more blocks.
             if self.layer_types[i] == "linear_attention":
                 self.conv_states += [
                     torch.zeros(

tests/models/qwen3_next/test_modeling_qwen3_next.py

Lines changed: 3 additions & 3 deletions
@@ -96,7 +96,7 @@ class Qwen3NextModelTest(CausalLMModelTest, unittest.TestCase):
     model_tester_class = Qwen3NextModelTester
 
     def _check_past_key_values_for_generate(self, batch_size, decoder_past_key_values, cache_length, config):
-        "Qwen3-Next has a special Cache as it alternates with Mamba layers"
+        "Qwen3-Next has a special Cache as it alternates with gated deltanet layers"
         self.assertIsInstance(decoder_past_key_values, Qwen3NextDynamicCache)
 
         # (batch, head, seq_length, head_features)
@@ -119,7 +119,7 @@ def _check_past_key_values_for_generate(self, batch_size, decoder_past_key_value
 
     @pytest.mark.generate
     def test_past_key_values_format(self):
-        "Needs to be overwritten as Qwen3-Next alternates between attention layers and mamba layers."
+        "Needs to be overwritten as Qwen3-Next alternates between attention layers and gated deltanet layers."
         for model_class in self.all_generative_model_classes:
             config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
 
@@ -150,7 +150,7 @@ def test_past_key_values_format(self):
             self.assertEqual(self_attention_layer_values.shape, default_self_attention_shape)
 
     def test_attention_outputs(self):
-        "Needs to be overwritten as Qwen3-Next alternates between attention layers and mamba layers."
+        "Needs to be overwritten as Qwen3-Next alternates between attention layers and gated deltanet layers."
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
         config.return_dict = True
         # force eager attention to support output attentions
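
These tests are overridden because the generic checks assume every layer contributes a `(batch, head, seq_length, head_features)` KV entry, which fails on the gated deltanet layers. A sketch of the branching such a check needs is below; the attribute names `layer_types`, `key_cache`, and `conv_states` follow the diffs above, while the helper itself and the exact assertions are assumptions, not the test suite's code.

def check_cache_shapes(cache, config, batch_size):
    # Per-layer-type shape check for the hybrid cache; assumes a config
    # exposing `layer_types` and a cache with `key_cache`/`conv_states` lists.
    for i, layer_type in enumerate(config.layer_types):
        if layer_type == "linear_attention":
            # Deltanet layers: empty KV placeholders, real conv state.
            assert cache.key_cache[i].shape == (batch_size, 0)
            assert cache.conv_states[i].shape[0] == batch_size
        else:
            # Attention layers: 4-D KV tensors, empty deltanet placeholders.
            assert cache.key_cache[i].ndim == 4  # (batch, head, seq_length, head_features)
            assert cache.conv_states[i].shape == (batch_size, 0)
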
