Skip to content

Commit 804a697

Browse files
Fix qwen2 & qwen2_5_omni
1 parent 06378d4 commit 804a697

File tree

3 files changed

+6
-7
lines changed

3 files changed

+6
-7
lines changed

src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2388,7 +2388,7 @@ def forward(
23882388
self.rope_deltas = rope_deltas
23892389

23902390
else:
2391-
batch_size, seq_length, _ = inputs_embeds.shape
2391+
batch_size, seq_length = input_ids.shape
23922392
delta = (past_key_values_length + self.rope_deltas).to(input_ids.device)
23932393
position_ids = torch.arange(seq_length, device=input_ids.device)
23942394
position_ids = position_ids.view(1, -1).expand(batch_size, -1)

src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2588,7 +2588,7 @@ def forward(
25882588
self.rope_deltas = rope_deltas
25892589

25902590
else:
2591-
batch_size, seq_length, _ = inputs_embeds.shape
2591+
batch_size, seq_length = input_ids.shape
25922592
delta = (past_key_values_length + self.rope_deltas).to(input_ids.device)
25932593
position_ids = torch.arange(seq_length, device=input_ids.device)
25942594
position_ids = position_ids.view(1, -1).expand(batch_size, -1)

tests/models/qwen2/test_modeling_qwen2.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -78,11 +78,10 @@ def test_model_450m_logits(self):
7878
with torch.no_grad():
7979
out = model(input_ids).logits.float().cpu()
8080
# Expected mean on dim = -1
81-
EXPECTED_MEAN = torch.tensor([[-1.9537, -1.6193, -1.4123, -1.4673, -1.8511, -1.9309, -1.9826, -2.1776]])
81+
EXPECTED_MEAN = torch.tensor([[-2.2121, -1.6335, -1.4816, -1.5035, -1.9110, -1.8979, -1.9682, -2.1980]])
8282
torch.testing.assert_close(out.mean(-1), EXPECTED_MEAN, rtol=1e-2, atol=1e-2)
8383
# slicing logits[0, 0, 0:30]
84-
EXPECTED_SLICE = torch.tensor([3.2025, 7.1265, 4.6058, 3.6423, 1.6357, 3.9265, 5.1883, 5.8760, 2.7942, 4.4823, 3.2571, 2.1063, 3.4275, 4.2028, 1.9767, 5.2115, 6.6756, 6.3999, 6.0483, 5.7378, 5.6660, 5.2298, 5.4103, 5.1248, 5.4376, 2.4570, 2.6107, 5.4039, 2.8077, 4.7777]) # fmt: skip
85-
print(out[0, 0, :30])
84+
EXPECTED_SLICE = torch.tensor([2.7344, 4.2812, 4.1562, 2.3906, 1.1875, 2.1562, 3.1719, 3.1406, 1.2891, 3.6094, 3.3125, 1.8203, 2.9219, 3.2344, 1.5938, 6.2500, 7.4062, 7.2188, 6.5938, 6.0312, 6.1562, 5.3750, 5.9688, 5.5938, 6.1250, 1.2656, 1.6016, 3.4062, 1.7891, 3.6406]) # fmt: skip
8685
torch.testing.assert_close(out[0, 0, :30], EXPECTED_SLICE, rtol=1e-4, atol=1e-4)
8786

8887
del model
@@ -92,7 +91,7 @@ def test_model_450m_logits(self):
9291
@slow
9392
def test_model_450m_generation(self):
9493
EXPECTED_TEXT_COMPLETION = (
95-
"""My favourite condiment is 100% natural, organic and vegan. I love to use it in my cooking and I"""
94+
"""My favourite condiment is 100% natural, organic and vegan. I love to use it in my cooking, but"""
9695
)
9796
prompt = "My favourite condiment is "
9897
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B", use_fast=False)
@@ -161,7 +160,7 @@ def test_model_450m_long_prompt_sdpa(self):
161160
gc.collect()
162161

163162
EXPECTED_TEXT_COMPLETION = (
164-
"My favourite condiment is 100% natural, organic and vegan. I love to use it in my cooking and I"
163+
"My favourite condiment is 100% natural, organic and vegan. I love to use it in my cooking, but"
165164
)
166165
prompt = "My favourite condiment is "
167166
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B", use_fast=False)

0 commit comments

Comments (0)