|
1 | 1 | # SPDX-License-Identifier: Apache-2.0 |
2 | 2 | # SPDX-FileCopyrightText: Copyright contributors to the vLLM project |
3 | 3 |
|
4 | | -import pytest |
5 | 4 | import torch |
6 | 5 |
|
7 | | -from vllm.model_executor.layers.sampler import SamplerOutput |
8 | | -from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors, |
9 | | - SequenceData, SequenceOutput) |
10 | | - |
11 | | -from .core.utils import create_dummy_prompt |
12 | | - |
13 | | - |
14 | | -@pytest.fixture |
15 | | -def sample_outputs(): |
16 | | - return [ |
17 | | - CompletionSequenceGroupOutput(samples=[ |
18 | | - SequenceOutput(parent_seq_id=0, output_token=i, logprobs={}) |
19 | | - ], |
20 | | - prompt_logprobs=None) for i in range(5) |
21 | | - ] |
22 | | - |
23 | | - |
24 | | -@pytest.fixture |
25 | | -def sampler_output(sample_outputs): |
26 | | - return SamplerOutput(outputs=sample_outputs) |
27 | | - |
28 | | - |
29 | | -def test_sampler_output_initialization(sampler_output, sample_outputs): |
30 | | - assert len(sampler_output) == len(sample_outputs) |
31 | | - assert sampler_output.sampled_token_probs is None |
32 | | - assert sampler_output.sampled_token_ids is None |
33 | | - |
34 | | - |
35 | | -def test_sampler_output_getitem(sampler_output, sample_outputs): |
36 | | - assert sampler_output[2] == sample_outputs[2] |
37 | | - |
38 | | - |
39 | | -def test_sampler_output_setitem(sampler_output): |
40 | | - new_output = CompletionSequenceGroupOutput(samples=[ |
41 | | - SequenceOutput(parent_seq_id=0, output_token=99, logprobs={}) |
42 | | - ], |
43 | | - prompt_logprobs=None) |
44 | | - sampler_output[2] = new_output |
45 | | - assert sampler_output[2] == new_output |
46 | | - |
47 | | - |
48 | | -def test_sampler_output_len(sampler_output, sample_outputs): |
49 | | - assert len(sampler_output) == len(sample_outputs) |
50 | | - |
51 | | - |
52 | | -def test_sampler_output_eq(sample_outputs): |
53 | | - sampler_output1 = SamplerOutput(outputs=sample_outputs) |
54 | | - sampler_output2 = SamplerOutput(outputs=sample_outputs.copy()) |
55 | | - sampler_output3 = SamplerOutput(outputs=sample_outputs[:-1]) |
56 | | - assert sampler_output1 == sampler_output2 |
57 | | - assert sampler_output1 != sampler_output3 |
58 | | - |
59 | | - |
60 | | -def test_sequence_data_prefill(): |
61 | | - seq_data = SequenceData.from_seqs([1, 2, 3, 4]) |
62 | | - assert seq_data.get_num_uncomputed_tokens() == 4 |
63 | | - assert seq_data.get_num_computed_tokens() == 0 |
64 | | - # advance by 2 |
65 | | - seq_data.update_num_computed_tokens(2) |
66 | | - assert seq_data.get_num_uncomputed_tokens() == 2 |
67 | | - assert seq_data.get_num_computed_tokens() == 2 |
68 | | - |
69 | | - # advance by 1 |
70 | | - seq_data.update_num_computed_tokens(1) |
71 | | - assert seq_data.get_num_uncomputed_tokens() == 1 |
72 | | - assert seq_data.get_num_computed_tokens() == 3 |
73 | | - |
74 | | - # append tokens and reset, simulating recompute |
75 | | - seq_data.append_token_id(1, logprob=0.0) |
76 | | - seq_data.reset_state_for_recompute() |
77 | | - assert seq_data.get_num_uncomputed_tokens() == 5 |
78 | | - assert seq_data.get_num_computed_tokens() == 0 |
79 | | - |
80 | | - |
81 | | -def test_sequence_group_stage(): |
82 | | - _, seq_group = create_dummy_prompt("1", 12) |
83 | | - assert seq_group.is_prefill() is True |
84 | | - seq_group.update_num_computed_tokens(6) |
85 | | - assert seq_group.is_prefill() is True |
86 | | - seq_group.update_num_computed_tokens(5) |
87 | | - assert seq_group.is_prefill() is True |
88 | | - seq_group.update_num_computed_tokens(1) |
89 | | - assert seq_group.is_prefill() is False |
90 | | - seqs = seq_group.get_seqs() |
91 | | - assert len(seqs) == 1 |
92 | | - seqs[0].data.append_token_id(1, logprob=0.0) |
93 | | - for seq in seq_group.get_seqs(): |
94 | | - seq.reset_state_for_recompute() |
95 | | - assert seq_group.is_prefill() is True |
96 | | - seq_group.update_num_computed_tokens(5) |
97 | | - assert seq_group.is_prefill() is True |
98 | | - seq_group.update_num_computed_tokens(7) |
99 | | - assert seq_group.is_prefill() is True |
100 | | - seq_group.update_num_computed_tokens(1) |
101 | | - assert seq_group.is_prefill() is False |
| 6 | +from vllm.sequence import IntermediateTensors |
102 | 7 |
|
103 | 8 |
|
104 | 9 | def test_sequence_intermediate_tensors_equal(): |
|
0 commit comments