
Commit 7d0722e

mzusman authored and anko-intel committed
[Bugfix][Mamba] Fix Multistep on Mamba-like models (vllm-project#10705)
Signed-off-by: mzusman <mor.zusmann@gmail.com>
1 parent 7ea613e commit 7d0722e

File tree

4 files changed: +84 -4 lines changed


tests/models/decoder_only/language/test_jamba.py

Lines changed: 38 additions & 0 deletions
@@ -275,6 +275,44 @@ def test_state_cleanup(
                     "could be related to finished_requests_ids")
 
 
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", ["float"])
+def test_multistep(
+    vllm_runner,
+    model: str,
+    dtype: str,
+    example_prompts,
+) -> None:
+    # This test is verifying that multistep works correctly
+    # on mamba-like models
+    with vllm_runner(model, num_scheduler_steps=8,
+                     max_num_seqs=2) as vllm_model:
+        vllm_model.generate_greedy([example_prompts[0]] * 10, 1)
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", ["float"])
+@pytest.mark.parametrize("max_tokens", [64])
+def test_multistep_correctness(vllm_runner, model: str, dtype: str,
+                               max_tokens: int, example_prompts) -> None:
+    with vllm_runner(model, num_scheduler_steps=8,
+                     max_num_seqs=2) as vllm_model:
+        vllm_outputs_multistep = vllm_model.generate_greedy(
+            example_prompts, max_tokens)
+
+    with vllm_runner(model, num_scheduler_steps=1,
+                     max_num_seqs=2) as vllm_model:
+        vllm_outputs_single_step = vllm_model.generate_greedy(
+            example_prompts, max_tokens)
+
+    check_outputs_equal(
+        outputs_0_lst=vllm_outputs_multistep,
+        outputs_1_lst=vllm_outputs_single_step,
+        name_0="vllm_outputs_multistep",
+        name_1="vllm_outputs_single_step",
+    )
+
+
 @multi_gpu_test(num_gpus=2)
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("dtype", ["float"])

tests/models/decoder_only/language/test_mamba.py

Lines changed: 36 additions & 0 deletions
@@ -283,3 +283,39 @@ def test_state_cleanup(
     except ValueError:
         pytest.fail("Mamba inner state wasn't cleaned up between states, "
                     "could be related to finished_requests_ids")
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", ["float"])
+def test_multistep(
+    vllm_runner,
+    model: str,
+    dtype: str,
+    example_prompts,
+) -> None:
+    with vllm_runner(model, num_scheduler_steps=8,
+                     max_num_seqs=2) as vllm_model:
+        vllm_model.generate_greedy([example_prompts[0]] * 10, 1)
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", ["float"])
+@pytest.mark.parametrize("max_tokens", [64])
+def test_multistep_correctness(vllm_runner, model: str, dtype: str,
+                               max_tokens: int, example_prompts) -> None:
+    with vllm_runner(model, num_scheduler_steps=8,
+                     max_num_seqs=2) as vllm_model:
+        vllm_outputs_multistep = vllm_model.generate_greedy(
+            example_prompts, max_tokens)
+
+    with vllm_runner(model, num_scheduler_steps=1,
+                     max_num_seqs=2) as vllm_model:
+        vllm_outputs_single_step = vllm_model.generate_greedy(
+            example_prompts, max_tokens)
+
+    check_outputs_equal(
+        outputs_0_lst=vllm_outputs_multistep,
+        outputs_1_lst=vllm_outputs_single_step,
+        name_0="vllm_outputs_multistep",
+        name_1="vllm_outputs_single_step",
+    )
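The two new correctness tests (Jamba and Mamba) run greedy generation with multi-step scheduling (num_scheduler_steps=8) and again with single-step scheduling, then compare the outputs. vllm_runner and check_outputs_equal are fixtures from vLLM's test suite, so as a rough standalone equivalent one could use the public offline API instead; in the sketch below the model name and prompt are illustrative placeholders only, not values taken from the tests' MODELS list.

# Rough standalone sketch of the multistep-vs-single-step comparison using
# the public offline API. Model name and prompt are placeholders.
from vllm import LLM, SamplingParams

model = "state-spaces/mamba-130m-hf"  # placeholder Mamba-like model
prompts = ["The capital of France is"]
greedy = SamplingParams(temperature=0.0, max_tokens=64)

# Generate with multi-step scheduling (8 scheduler steps per iteration).
multistep_out = LLM(model=model, num_scheduler_steps=8,
                    max_num_seqs=2).generate(prompts, greedy)

# Generate again with ordinary single-step scheduling.
single_out = LLM(model=model, num_scheduler_steps=1,
                 max_num_seqs=2).generate(prompts, greedy)

# Greedy decoding is deterministic, so the texts should match.
assert [o.outputs[0].text for o in multistep_out] == \
       [o.outputs[0].text for o in single_out]

In practice the two engine instances would typically be created in separate processes (as the test fixtures do) to avoid holding GPU memory for both at once.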

vllm/engine/async_llm_engine.py

Lines changed: 5 additions & 2 deletions
@@ -301,6 +301,9 @@ async def step_async(
             ctx.seq_group_metadata_list = seq_group_metadata_list
             ctx.scheduler_outputs = scheduler_outputs
 
+            finished_requests_ids = self.scheduler[
+                virtual_engine].get_and_reset_finished_requests_ids()
+
             # Maybe switch from async mode to sync mode
             if not allow_async_output_proc and len(ctx.output_queue) > 0:
                 self._process_model_outputs(ctx=ctx)
@@ -312,13 +315,13 @@ async def step_async(
                 self._cache_scheduler_outputs_for_multi_step(
                     virtual_engine, seq_group_metadata_list, scheduler_outputs,
                     allow_async_output_proc)
+        else:
+            finished_requests_ids = list()
 
         assert seq_group_metadata_list is not None
         assert scheduler_outputs is not None
 
         if not scheduler_outputs.is_empty():
-            finished_requests_ids = self.scheduler[
-                virtual_engine].get_and_reset_finished_requests_ids()
 
             # Check if we have a cached last_output from the previous iteration.
             # For supporting PP this is probably the best way to pass the

vllm/engine/llm_engine.py

Lines changed: 5 additions & 2 deletions
@@ -1403,6 +1403,9 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
             ctx.seq_group_metadata_list = seq_group_metadata_list
             ctx.scheduler_outputs = scheduler_outputs
 
+            finished_requests_ids = self.scheduler[
+                virtual_engine].get_and_reset_finished_requests_ids()
+
             # Maybe switch from async mode to sync mode
             if not allow_async_output_proc and len(ctx.output_queue) > 0:
                 self._process_model_outputs(ctx=ctx)
@@ -1414,13 +1417,13 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
                 self._cache_scheduler_outputs_for_multi_step(
                     virtual_engine, seq_group_metadata_list, scheduler_outputs,
                     allow_async_output_proc)
+        else:
+            finished_requests_ids = list()
 
         assert seq_group_metadata_list is not None
         assert scheduler_outputs is not None
 
         if not scheduler_outputs.is_empty():
-            finished_requests_ids = self.scheduler[
-                virtual_engine].get_and_reset_finished_requests_ids()
 
             # Check if we have a cached last_output from the previous iteration.
             # For supporting PP this is probably the best way to pass the
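Both engine files receive the same change: get_and_reset_finished_requests_ids() is now called where a fresh scheduling iteration runs, and an empty list is used on the cached steps of a multi-step iteration, instead of fetching (and resetting) the ids inside the is_empty() check on every step. This keeps finished_requests_ids aligned with newly scheduled work, which the Mamba tests above tie to cleaning up inner state. A simplified, illustrative sketch of the resulting control flow follows; the names are stand-ins for engine internals, not the actual step() implementation.

# Illustrative sketch of the control flow after this change (not the real
# LLMEngine.step() code; `scheduler` and `ctx` stand in for engine internals).
def prepare_step(scheduler, ctx, has_remaining_multi_steps: bool):
    if not has_remaining_multi_steps:
        # Fresh scheduling iteration: this is now the only point where the
        # finished request ids are fetched (and reset) from the scheduler.
        seq_group_metadata_list, scheduler_outputs = scheduler.schedule()
        finished_requests_ids = scheduler.get_and_reset_finished_requests_ids()
    else:
        # Remaining steps of a multi-step iteration reuse the cached schedule
        # and pass an empty list rather than consuming finished ids mid-run.
        seq_group_metadata_list = ctx.seq_group_metadata_list
        scheduler_outputs = ctx.scheduler_outputs
        finished_requests_ids = []
    return seq_group_metadata_list, scheduler_outputs, finished_requests_ids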
