Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions python/sglang/srt/server_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -1201,6 +1201,9 @@ def _handle_speculative_decoding(self):
)
if self.max_running_requests is None:
self.max_running_requests = 48
logger.warning(
"Max running requests is reset to 48 for speculative decoding."
)

if self.speculative_algorithm == "EAGLE" and self.enable_beta_spec:
self.disable_overlap_schedule = False
Expand Down
1 change: 1 addition & 0 deletions test/srt/run_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ class TestFile:
],
"per-commit-8-gpu-h200-deepseek-v32": [
TestFile("test_deepseek_v32_basic.py", 275),
TestFile("test_deepseek_v32_mtp.py", 275),
],
"vllm_dependency_test": [
TestFile("quant/test_awq.py", 163),
Expand Down
6 changes: 3 additions & 3 deletions test/srt/test_deepseek_v32_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
DEEPSEEK_V32_MODEL_PATH = "deepseek-ai/DeepSeek-V3.2-Exp"


class TestDeepseekV3Basic(CustomTestCase):
class TestDeepseekV32Basic(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = DEEPSEEK_V32_MODEL_PATH
Expand Down Expand Up @@ -57,7 +57,7 @@ def test_a_gsm8k(

if is_in_ci():
write_github_step_summary(
f"### test_gsm8k (deepseek-v3)\n" f'{metrics["accuracy"]=:.3f}\n'
f"### test_gsm8k (deepseek-v32)\n" f'{metrics["accuracy"]=:.3f}\n'
)
self.assertGreater(metrics["accuracy"], 0.935)

Expand All @@ -69,7 +69,7 @@ def test_bs_1_speed(self):

if is_in_ci():
write_github_step_summary(
f"### test_bs_1_speed (deepseek-v3)\n" f"{speed=:.2f} token/s\n"
f"### test_bs_1_speed (deepseek-v32)\n" f"{speed=:.2f} token/s\n"
)
self.assertGreater(speed, 50)

Expand Down
105 changes: 105 additions & 0 deletions test/srt/test_deepseek_v32_mtp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
import unittest
from types import SimpleNamespace

import requests

from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.send_one import BenchArgs, send_one_prompt
from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
is_in_ci,
popen_launch_server,
write_github_step_summary,
)

# Model path assigned to cls.model and passed to popen_launch_server in setUpClass.
# NOTE(review): the constant's name says "V3" but the value points at
# DeepSeek-V3.2-Exp — consider renaming for consistency with the value.
FULL_DEEPSEEK_V3_MODEL_PATH = "deepseek-ai/DeepSeek-V3.2-Exp"


class TestDeepseekV32MTP(CustomTestCase):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This test class shares a lot of duplicate code with TestDeepseekV32Basic in test/srt/test_deepseek_v32_basic.py. For example, the tearDownClass method and large parts of the test methods (test_a_gsm8k, test_bs_1_speed) are identical or very similar.

To improve maintainability and follow the DRY (Don't Repeat Yourself) principle, consider refactoring the common code into a base test class. The subclasses (TestDeepseekV32Basic and TestDeepseekV32MTP) would then only need to define their specific configurations (e.g., other_args in setUpClass) and any specific assertions.

A possible structure could be:

# In a shared location, e.g., a new base test file
class TestDeepseekV32Base(CustomTestCase):
    model = "deepseek-ai/DeepSeek-V3.2-Exp"
    base_url = DEFAULT_URL_FOR_TEST
    process = None
    other_args = []

    @classmethod
    def setUpClass(cls):
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=cls.other_args,
        )

    @classmethod
    def tearDownClass(cls):
        if cls.process:
            kill_process_tree(cls.process.pid)

    # ... common test logic can be extracted into helper methods ...

This would make the tests cleaner and easier to maintain in the long run.

@classmethod
def setUpClass(cls):
cls.model = FULL_DEEPSEEK_V3_MODEL_PATH
Comment on lines +18 to +24
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The constant FULL_DEEPSEEK_V3_MODEL_PATH has the same value as DEEPSEEK_V32_MODEL_PATH in test_deepseek_v32_basic.py. To maintain consistency and avoid confusion, it's better to use the same name. The FULL_ prefix is also redundant.

Ideally, this constant should be defined once in a shared location and imported where needed to adhere to the DRY (Don't Repeat Yourself) principle.

Suggested change
FULL_DEEPSEEK_V3_MODEL_PATH = "deepseek-ai/DeepSeek-V3.2-Exp"
class TestDeepseekV32MTP(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = FULL_DEEPSEEK_V3_MODEL_PATH
DEEPSEEK_V32_MODEL_PATH = "deepseek-ai/DeepSeek-V3.2-Exp"
class TestDeepseekV32MTP(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = DEEPSEEK_V32_MODEL_PATH

cls.base_url = DEFAULT_URL_FOR_TEST
other_args = [
"--trust-remote-code",
"--tp",
"8",
"--dp",
"8",
"--enable-dp-attention",
"--speculative-algorithm",
"EAGLE",
"--speculative-num-steps",
"3",
"--speculative-eagle-topk",
"1",
"--speculative-num-draft-tokens",
"4",
"--mem-frac",
"0.7",
]
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=other_args,
)

@classmethod
def tearDownClass(cls):
    """Kill the server process stored on the class by setUpClass."""
    server_pid = cls.process.pid
    kill_process_tree(server_pid)

def test_a_gsm8k(self):
    """Run the few-shot GSM8K eval and check accuracy and acceptance length.

    The "a" in the method name makes this test sort first alphabetically,
    so it runs before the others and warms up the server.
    """
    # Hit the server's /flush_cache endpoint before starting the evaluation.
    requests.get(self.base_url + "/flush_cache")

    # The port is the last ":"-separated component of the base URL.
    port = int(self.base_url.split(":")[-1])
    eval_args = SimpleNamespace(
        num_shots=8,
        data_path=None,
        num_questions=1400,
        parallel=1400,
        max_new_tokens=512,
        host="http://127.0.0.1",
        port=port,
    )
    metrics = run_eval_few_shot_gsm8k(eval_args)
    print(f"{metrics=}")

    # Read the average speculative acceptance length from the first entry
    # of the server's reported internal states.
    info = requests.get(self.base_url + "/get_server_info").json()
    avg_spec_accept_length = info["internal_states"][0]["avg_spec_accept_length"]
    print(f"{avg_spec_accept_length=}")

    if is_in_ci():
        summary = (
            f"### test_gsm8k (deepseek-v32 mtp)\n"
            f'{metrics["accuracy"]=:.3f}\n'
            f"{avg_spec_accept_length=:.2f}\n"
        )
        write_github_step_summary(summary)

    self.assertGreater(metrics["accuracy"], 0.935)
    self.assertGreater(avg_spec_accept_length, 2.9)

def test_bs_1_speed(self):
    """Send one prompt and check the acceptance length and speed it reports."""
    # The port is the last ":"-separated component of the base URL.
    port = int(self.base_url.split(":")[-1])
    bench_args = BenchArgs(port=port, max_new_tokens=2048)
    acc_length, speed = send_one_prompt(bench_args)

    print(f"{acc_length=:.2f} {speed=:.2f}")

    if is_in_ci():
        summary = (
            f"### test_bs_1_speed (deepseek-v32 mtp)\n"
            f"{acc_length=:.2f}\n"
            f"{speed=:.2f} token/s\n"
        )
        write_github_step_summary(summary)

    # Thresholds: acceptance length above 2.9 and speed above 75.
    self.assertGreater(acc_length, 2.9)
    self.assertGreater(speed, 75)


# Run the tests in this module via unittest when the file is executed directly.
if __name__ == "__main__":
    unittest.main()
Loading