Skip to content

Commit 6149601

Browse files
committed
fix
1 parent 9d9c4d2 commit 6149601

File tree

1 file changed

+46
-22
lines changed

1 file changed

+46
-22
lines changed

tests/models/phimoe/test_modeling_phimoe.py

Lines changed: 46 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,14 @@
1414

1515
"""Testing suite for the PyTorch PhiMoE model."""
1616

17+
import copy
1718
import unittest
1819

1920
from parameterized import parameterized
2021

2122
from transformers import PhimoeConfig, StaticCache, is_torch_available
2223
from transformers.testing_utils import (
24+
cleanup,
2325
require_torch,
2426
slow,
2527
torch_device,
@@ -130,31 +132,47 @@ def test_model_rope_scaling_from_config(self, scaling_type):
130132
@slow
131133
@require_torch
132134
class PhimoeIntegrationTest(unittest.TestCase):
133-
def test_model_phimoe_instruct_logits(self):
134-
input_ids = {
135-
"input_ids": torch.tensor(
136-
[[1212, 318, 281, 1672, 2643, 290, 428, 318, 257, 1332]], dtype=torch.long, device=torch_device
135+
model = None
136+
137+
@classmethod
138+
def get_model(cls):
139+
if cls.model is None:
140+
cls.model = PhimoeForCausalLM.from_pretrained(
141+
"microsoft/Phi-3.5-MoE-instruct", dtype="auto", device_map="auto"
137142
)
138-
}
143+
return cls.model
144+
145+
@classmethod
146+
def tearDownClass(cls):
147+
del cls.model
148+
cleanup(torch_device, gc_collect=True)
149+
150+
def setUp(self):
151+
cleanup(torch_device, gc_collect=True)
152+
153+
def tearDown(self):
154+
cleanup(torch_device, gc_collect=True)
155+
156+
def test_model_phimoe_instruct_logits(self):
157+
input_ids = {"input_ids": torch.tensor([[1212, 318, 281, 1672]], dtype=torch.long, device=torch_device)}
139158

140-
model = PhimoeForCausalLM.from_pretrained("microsoft/Phi-3.5-MoE-instruct").to(torch_device)
159+
model = self.get_model()
141160
model.eval()
142161

143-
output = model(**input_ids).logits
162+
with torch.no_grad():
163+
output = model(**input_ids).logits
144164

145-
EXPECTED_OUTPUT = torch.tensor([[-3.5312, -2.5000, -1.2734, 0.3555, -0.7578, -0.4727, 0.5977, -0.4316,
146-
0.2256, -1.2188, -1.6797, 0.9961, 3.7656, 11.3125, -1.3828, -4.8438,
147-
-5.7500, -1.9375, 0.7227, -0.3438, -0.2100, -0.4277, -0.0444, -0.5352,
148-
-0.6406, -0.1016, -0.4258, -1.0234, 0.4297, -0.6250],
149-
[-0.9883, 0.1455, -0.4902, 2.3594, 0.7031, 3.1406, 0.4375, 0.2559,
150-
0.6172, -2.1094, -1.3359, 2.5938, 4.9062, 10.8125, -0.1094, 1.5781,
151-
-4.9375, 0.7148, -0.0972, 1.7656, -0.0801, 0.2217, 0.1875, -0.4629,
152-
1.5781, 0.3535, 0.0874, 0.6836, -0.0518, -1.2969]]).to(torch_device) # fmt: skip
165+
EXPECTED_OUTPUT = torch.tensor(
166+
[
167+
[-3.4844, -2.4531, -1.1719, 0.6055, -0.4922, -0.1001, 0.8086, -0.2422, 0.3477, -1.0078],
168+
[-0.9766, 0.1631, -0.5508, 2.3594, 0.7031, 3.1719, 0.4141, 0.2305, 0.6055, -2.1250],
169+
]
170+
).to(device=torch_device, dtype=output.dtype) # fmt: skip
153171

154-
torch.testing.assert_close(EXPECTED_OUTPUT, output[0, :2, :30], rtol=1e-4, atol=1e-4)
172+
torch.testing.assert_close(output[0, :2, :10], EXPECTED_OUTPUT, rtol=1e-4, atol=1e-4)
155173

156174
def test_phimoe_instruct_generation(self):
157-
model = PhimoeForCausalLM.from_pretrained("microsoft/Phi-3.5-MoE-instruct")
175+
model = self.get_model()
158176
tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3.5-MoE-instruct")
159177

160178
messages = [
@@ -166,17 +184,22 @@ def test_phimoe_instruct_generation(self):
166184
]
167185
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")
168186

169-
outputs = model.generate(inputs, max_new_tokens=32)
187+
outputs = model.generate(inputs, max_new_tokens=10)
170188
output_text = tokenizer.batch_decode(outputs)
171189

172190
EXPECTED_OUTPUT = [
173-
"<|system|> You are a helpful digital assistant. Please provide safe, ethical and accurate information to the user.<|end|><|user|> Can you provide ways to eat combinations of bananas and dragonfruits?<|end|><|assistant|> Certainly! Bananas and dragonfruits are both delicious and nutritious fruits that can be combined in various ways to create tast"
191+
"<|system|> You are a helpful digital assistant. Please provide safe, ethical and accurate information to the user.<|end|><|user|> Can you provide ways to eat combinations of bananas and dragonfruits?<|end|><|assistant|> Certainly! Bananas and dragonf",
174192
]
175193

176194
self.assertListEqual(output_text, EXPECTED_OUTPUT)
177195

178196
def test_phimoe_instruct_with_static_cache(self):
179-
model = PhimoeForCausalLM.from_pretrained("microsoft/Phi-3.5-MoE-instruct")
197+
model = self.get_model()
198+
# Can't run with the real checkpoint, even if offloaded. Let's just use a tiny dummy one
199+
config = copy.deepcopy(model.config)
200+
config.num_hidden_layers = 2
201+
torch.manual_seed(42)
202+
model = type(model)(config)
180203
tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3.5-MoE-instruct")
181204

182205
messages = [
@@ -188,12 +211,13 @@ def test_phimoe_instruct_with_static_cache(self):
188211
]
189212
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")
190213

191-
response_tokens = PhimoeMiniWithStaticCache.generate(model, inputs, 64)
214+
response_tokens = PhimoeMiniWithStaticCache.generate(model, inputs, max_seq_len=10)
192215

193216
output_text = tokenizer.batch_decode(torch.tensor([response_tokens], dtype=torch.long, device=torch_device))
194217

218+
# This is dummy outputs. We actually check if it could run with static cache, not the output quality.
195219
EXPECTED_OUTPUT = [
196-
"<|system|> You are a helpful digital assistant. Please provide safe, ethical and accurate information to the user.<|end|><|user|> Can you provide ways to eat combinations of bananas and dragonfruits?<|end|><|assistant|> Certainly! Bananas and dragonfruits are both delicious and nutritious fruits that can"
220+
"<|system|> You are a helpful digital assistant. Please provide safe, ethical and accurate information to the user.<|end|><|user|> Can you provide ways to eat combinations of bananas and dragonfruits?<|end|><|assistant|>ington"
197221
]
198222

199223
self.assertListEqual(output_text, EXPECTED_OUTPUT)

0 commit comments

Comments
 (0)