@@ -2,12 +2,14 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import itertools
+import math
 from collections.abc import Generator
 from typing import get_args
 
 import pytest
 import torch
 
+from tests.utils import large_gpu_mark
 from tests.v1.sample.utils import (
     BatchLogprobsComposition,
     BatchLogprobsSpecType,
@@ -17,6 +19,7 @@
 )
 from vllm import SamplingParams
 from vllm.config.model import LogprobsMode
+from vllm.distributed import cleanup_dist_env_and_memory
 
 from ...conftest import HfRunner, VllmRunner
 
@@ -508,3 +511,103 @@ def test_logprobs_mode(logprobs_mode: LogprobsMode):
     if logprobs_mode in ("raw_logits", "processed_logits"):
         assert positive_values > 0
     del llm
+
+
+@pytest.mark.parametrize("logprobs_mode", get_args(LogprobsMode))
+@pytest.mark.parametrize(
+    "model_setup",
+    [
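+        # Requires a large GPU (>= 32 GB); otherwise the case is skipped.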
+        pytest.param(
+            (
+                "eagle",
+                "meta-llama/Llama-3.1-8B-Instruct",
+                "yuhuili/EAGLE-LLaMA3.1-Instruct-8B",
+            ),
+            marks=large_gpu_mark(min_gb=32),
+        ),
+    ],
+)
+def test_spec_decode_logprobs(
+    logprobs_mode: LogprobsMode,
+    model_setup: tuple[str, str, str],
+    monkeypatch: pytest.MonkeyPatch,
+):
+    """Spec decode logprobs should match those of the base model.
+
+    Args:
+        logprobs_mode: Logprobs mode under test.
+        model_setup: Spec decode method, base model name, and
+            draft model name.
+    """
+    from vllm import LLM
+
+    with monkeypatch.context() as m:
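+        # Force the V1 engine code path.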
+        m.setenv("VLLM_USE_V1", "1")
+        prompt = "Hello world"
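+        # Greedy sampling keeps both runs deterministic, so per-token
+        # logprobs from the two engines can be compared directly.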
+        sampling_params = SamplingParams(
+            temperature=0, logprobs=3, max_tokens=10, ignore_eos=False
+        )
+        method, model_name, spec_model_name = model_setup
+        max_model_len = 256
+
+        # Run base LLM.
+        ref_llm = LLM(
+            model=model_name,
+            max_logprobs=5,
+            max_model_len=max_model_len,
+            seed=42,
+            logprobs_mode=logprobs_mode,
+            gpu_memory_utilization=0.4,
+        )
+        ref_results = ref_llm.generate([prompt], sampling_params)
+        # Collect logprobs outputs from reference LLM.
+        ref_logprobs = []
+        for output in ref_results[0].outputs:
+            for logprobs in output.logprobs:
+                for token_id in logprobs:
+                    ref_logprobs.append(logprobs[token_id])
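+        # Free GPU memory from the reference engine before starting the
+        # spec decode engine in the same process.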
+        del ref_llm
+        torch.cuda.empty_cache()
+        cleanup_dist_env_and_memory()
+
+        # Run spec decode LLM.
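+        # The draft model proposes tokens that the base model verifies, so
+        # the returned logprobs should still match the base model's.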
+        spec_llm = LLM(
+            model_name,
+            speculative_config={
+                "method": method,
+                "model": spec_model_name,
+                "num_speculative_tokens": 3,
+                "max_model_len": max_model_len,
+            },
+            max_logprobs=5,
+            max_model_len=max_model_len,
+            seed=42,
+            logprobs_mode=logprobs_mode,
+            gpu_memory_utilization=0.4,
+        )
+        spec_results = spec_llm.generate([prompt], sampling_params)
+        # Collect logprobs outputs from spec decode LLM.
+        spec_logprobs = []
+        for output in spec_results[0].outputs:
+            for logprobs in output.logprobs:
+                for token_id in logprobs:
+                    spec_logprobs.append(logprobs[token_id])
+        del spec_llm
+        torch.cuda.empty_cache()
+        cleanup_dist_env_and_memory()
+
+        # Per-token logprobs are expected to be the same.
+        assert len(ref_logprobs) == len(spec_logprobs)
+        for ref_logprob, spec_logprob in zip(ref_logprobs, spec_logprobs):
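+            # Allow a small absolute tolerance for numerical noise.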
+            assert math.isclose(ref_logprob.logprob, spec_logprob.logprob, abs_tol=1e-3)
+            assert ref_logprob.rank == spec_logprob.rank
+            assert ref_logprob.decoded_token == spec_logprob.decoded_token