
Commit c86b335

【Hackathon 9th No.78】add test_chat.py (#3958)
1 parent 06f4b49 commit c86b335

File tree

2 files changed: +106 −0 lines

tests/entrypoints/test_chat.py (+63)
tests/entrypoints/test_generation.py (+43)

tests/entrypoints/test_chat.py

Lines changed: 63 additions & 0 deletions
@@ -0,0 +1,63 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+import os
+import unittest
+import weakref
+
+from fastdeploy.entrypoints.llm import LLM
+
+MODEL_NAME = os.getenv("MODEL_PATH") + "/ERNIE-4.5-0.3B-Paddle"
+
+
+class TestChat(unittest.TestCase):
+    """Test case for chat functionality"""
+
+    PROMPTS = [
+        [{"content": "The color of tomato is ", "role": "user"}],
+        [{"content": "The equation 2+3= ", "role": "user"}],
+        [{"content": "The equation 4-1= ", "role": "user"}],
+        [{"content": "PaddlePaddle is ", "role": "user"}],
+    ]
+
+    @classmethod
+    def setUpClass(cls):
+        try:
+            llm = LLM(
+                model=MODEL_NAME,
+                max_num_batched_tokens=4096,
+                tensor_parallel_size=1,
+                engine_worker_queue_port=int(os.getenv("FD_ENGINE_QUEUE_PORT")),
+                cache_queue_port=int(os.getenv("FD_CACHE_QUEUE_PORT")),
+            )
+            cls.llm = weakref.proxy(llm)
+        except Exception as e:
+            print(f"Setting up LLM failed: {e}")
+            raise unittest.SkipTest(f"LLM initialization failed: {e}")
+
+    @classmethod
+    def tearDownClass(cls):
+        """Clean up after all tests have run"""
+        if hasattr(cls, "llm"):
+            del cls.llm
+
+    def test_chat(self):
+        outputs = self.llm.chat(messages=self.PROMPTS, sampling_params=None)
+        self.assertEqual(len(self.PROMPTS), len(outputs))
+
+
+if __name__ == "__main__":
+    unittest.main()
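
For context, the chat entrypoint this test exercises can be driven the same way outside the unittest harness. The following is a minimal sketch, not part of the commit, assuming the same environment variables (MODEL_PATH, FD_ENGINE_QUEUE_PORT, FD_CACHE_QUEUE_PORT) are set and the ERNIE-4.5-0.3B-Paddle weights are available:

import os

from fastdeploy.entrypoints.llm import LLM

# Mirrors setUpClass above; model path and queue ports come from the
# environment, exactly as the test assumes.
llm = LLM(
    model=os.getenv("MODEL_PATH") + "/ERNIE-4.5-0.3B-Paddle",
    max_num_batched_tokens=4096,
    tensor_parallel_size=1,
    engine_worker_queue_port=int(os.getenv("FD_ENGINE_QUEUE_PORT")),
    cache_queue_port=int(os.getenv("FD_CACHE_QUEUE_PORT")),
)

# Each inner list is one conversation in role/content form;
# sampling_params=None falls back to the engine's default sampling settings.
outputs = llm.chat(
    messages=[[{"role": "user", "content": "The color of tomato is "}]],
    sampling_params=None,
)
assert len(outputs) == 1  # one output per conversation, as test_chat asserts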

tests/entrypoints/test_generation.py

Lines changed: 43 additions & 0 deletions
@@ -14,6 +14,7 @@
 # limitations under the License.
 """
 
+import copy
 import os
 import unittest
 import weakref
@@ -120,6 +121,48 @@ def test_multiple_sampling_params(self):
         outputs = self.llm.generate(prompts=self.PROMPTS, sampling_params=None)
         self.assertEqual(len(self.PROMPTS), len(outputs))
 
+    def test_consistency_single_prompt_tokens_chat(self):
+        """Test consistency between different prompt input formats"""
+        sampling_params = SamplingParams(temperature=1.0, top_p=0.0)
+
+        for prompt_token_ids in self.TOKEN_IDS:
+            with self.subTest(prompt_token_ids=prompt_token_ids):
+                output1 = self.llm.chat(messages=[prompt_token_ids], sampling_params=sampling_params)
+                output2 = self.llm.chat(
+                    [{"prompt": "", "prompt_token_ids": prompt_token_ids}], sampling_params=sampling_params
+                )
+                self.assert_outputs_equal(output1, output2)
+
+    def test_multiple_sampling_params_chat(self):
+        """Test multiple sampling parameter combinations"""
+        sampling_params = [
+            SamplingParams(temperature=0.01, top_p=0.95),
+            SamplingParams(temperature=0.3, top_p=0.95),
+            SamplingParams(temperature=0.7, top_p=0.95),
+            SamplingParams(temperature=0.99, top_p=0.95),
+        ]
+
+        prompts = copy.copy(self.PROMPTS)
+        # Multiple SamplingParams are matched one-to-one with the prompts
+        outputs = self.llm.chat(messages=prompts, sampling_params=sampling_params)
+        self.assertEqual(len(self.PROMPTS), len(outputs))
+
+        prompts = copy.copy(self.PROMPTS)
+        # An exception is raised if the sizes mismatch
+        with self.assertRaises(ValueError):
+            self.llm.chat(messages=prompts, sampling_params=sampling_params[:3])
+
+        prompts = copy.copy(self.PROMPTS)
+        # A single SamplingParams is applied to every prompt
+        single_sampling_params = SamplingParams(temperature=0.3, top_p=0.95)
+        outputs = self.llm.chat(messages=prompts, sampling_params=single_sampling_params)
+        self.assertEqual(len(self.PROMPTS), len(outputs))
+
+        prompts = copy.copy(self.PROMPTS)
+        # When sampling_params is None, default params are applied
+        outputs = self.llm.chat(messages=prompts, sampling_params=None)
+        self.assertEqual(len(self.PROMPTS), len(outputs))
+
 
 if __name__ == "__main__":
     unittest.main()
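
Taken together, the new chat assertions pin down a simple dispatch rule: a list of SamplingParams must match the messages list one-to-one (a shorter list raises ValueError), a single SamplingParams is broadcast to every conversation, and None selects the engine defaults. Below is a minimal sketch of that rule, not part of the commit; the top-level `from fastdeploy import LLM, SamplingParams` import path is an assumption, since this diff does not show test_generation.py's imports:

import os

from fastdeploy import LLM, SamplingParams  # import path assumed, not shown in this diff

messages = [
    [{"role": "user", "content": "The color of tomato is "}],
    [{"role": "user", "content": "The equation 2+3= "}],
]

# Same construction as in test_chat.py; env vars assumed set.
llm = LLM(
    model=os.getenv("MODEL_PATH") + "/ERNIE-4.5-0.3B-Paddle",
    max_num_batched_tokens=4096,
    tensor_parallel_size=1,
    engine_worker_queue_port=int(os.getenv("FD_ENGINE_QUEUE_PORT")),
    cache_queue_port=int(os.getenv("FD_CACHE_QUEUE_PORT")),
)

# One SamplingParams per conversation: list lengths must match.
per_prompt = [
    SamplingParams(temperature=0.3, top_p=0.95),
    SamplingParams(temperature=0.7, top_p=0.95),
]
outputs = llm.chat(messages=messages, sampling_params=per_prompt)
assert len(outputs) == len(messages)

# A mismatched list is rejected, as test_multiple_sampling_params_chat asserts.
try:
    llm.chat(messages=messages, sampling_params=per_prompt[:1])
except ValueError:
    pass

# A single SamplingParams is broadcast to every conversation; None uses defaults.
outputs = llm.chat(messages=messages, sampling_params=SamplingParams(temperature=0.3, top_p=0.95))
assert len(outputs) == len(messages)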
