Skip to content

Commit

Permalink
[Misc] Remove unnecessary ModelRunner imports (vllm-project#4703)
Browse files Browse the repository at this point in the history
  • Loading branch information
WoosukKwon authored May 9, 2024
1 parent 1927716 commit 5654ff4
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 73 deletions.
81 changes: 24 additions & 57 deletions tests/samplers/test_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
from vllm.utils import Counter
from vllm.worker.model_runner import ModelRunner
from vllm.utils import Counter, is_pin_memory_available


class MockLogitsSampler(Sampler):
Expand All @@ -26,20 +25,14 @@ def forward(self, *args, **kwargs):


def _prepare_test(
batch_size: int
) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsSampler, ModelRunner]:
batch_size: int
) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsSampler]:
input_tensor = torch.rand((batch_size, 1024), dtype=torch.float16)
fake_logits = torch.full((batch_size, VOCAB_SIZE),
1e-2,
dtype=input_tensor.dtype)
sampler = MockLogitsSampler(fake_logits)
model_runner = ModelRunner(model_config=None,
parallel_config=None,
scheduler_config=None,
device_config=None,
load_config=None,
lora_config=None)
return input_tensor, fake_logits, sampler, model_runner
return input_tensor, fake_logits, sampler


VOCAB_SIZE = 32000
Expand All @@ -53,7 +46,6 @@ def _do_sample(
batch_size: int,
input_tensor: torch.Tensor,
sampler: MockLogitsSampler,
model_runner: ModelRunner,
sampling_params: SamplingParams,
device: str,
):
Expand All @@ -75,7 +67,7 @@ def _do_sample(
seq_lens,
query_lens=seq_lens,
device=device,
pin_memory=model_runner.pin_memory)
pin_memory=is_pin_memory_available())
return sampler(logits=input_tensor, sampling_metadata=sampling_metadata)


Expand All @@ -85,28 +77,24 @@ def test_sampler_all_greedy(seed: int, device: str):
set_random_seed(seed)
torch.set_default_device(device)
batch_size = random.randint(1, 256)
input_tensor, fake_logits, sampler, model_runner = _prepare_test(
batch_size)
input_tensor, fake_logits, sampler = _prepare_test(batch_size)

sampling_params = SamplingParams(temperature=0)
sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner,
sampler_output = _do_sample(batch_size, fake_logits, sampler,
sampling_params, device)
expected = torch.argmax(fake_logits, dim=-1)
for i, sequence_output in enumerate(sampler_output):
for nth_output in sequence_output.samples:
assert nth_output.output_token == expected[i].item()

del model_runner


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_random(seed: int, device: str):
set_random_seed(seed)
torch.set_default_device(device)
batch_size = random.randint(1, 256)
input_tensor, fake_logits, sampler, model_runner = _prepare_test(
batch_size)
_, fake_logits, sampler = _prepare_test(batch_size)

for i in range(batch_size):
fake_logits[i, i] = 1e2
Expand All @@ -115,23 +103,21 @@ def test_sampler_all_random(seed: int, device: str):
temperature=1.0,
n=random.randint(1, 10),
)
sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner,
sampler_output = _do_sample(batch_size, fake_logits, sampler,
sampling_params, device)

for i, sequence_output in enumerate(sampler_output):
for nth_output in sequence_output.samples:
assert nth_output.output_token == i

del model_runner


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_random_seed(seed: int, device: str):
set_random_seed(seed)
torch.set_default_device(device)
batch_size = random.randint(1, 256)
_, fake_logits, sampler, model_runner = _prepare_test(batch_size)
_, fake_logits, sampler = _prepare_test(batch_size)

for i in range(batch_size):
fake_logits[i, i] = 1e2
Expand All @@ -141,60 +127,54 @@ def test_sampler_all_random_seed(seed: int, device: str):
n=random.randint(1, 10),
seed=random.randint(0, 10000),
)
sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner,
sampler_output = _do_sample(batch_size, fake_logits, sampler,
sampling_params, device)

for i, sequence_output in enumerate(sampler_output):
for nth_output in sequence_output.samples:
assert nth_output.output_token == i

del model_runner


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_random_seed_deterministic(seed: int, device: str):
set_random_seed(seed)
torch.set_default_device(device)
batch_size = random.randint(1, 256)
_, fake_logits, sampler, model_runner = _prepare_test(batch_size)
_, fake_logits, sampler = _prepare_test(batch_size)

sampling_params = SamplingParams(
temperature=1.0,
n=random.randint(1, 10),
seed=random.randint(0, 10000),
)
first_sampler_output = _do_sample(batch_size, fake_logits, sampler,
model_runner, sampling_params, device)
sampling_params, device)

second_sampler_output = _do_sample(batch_size, fake_logits, sampler,
model_runner, sampling_params, device)
sampling_params, device)

assert first_sampler_output == second_sampler_output

del model_runner


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_beam(seed: int, device: str):
set_random_seed(seed)
torch.set_default_device(device)
batch_size = random.randint(1, 256)
_, fake_logits, sampler, model_runner = _prepare_test(batch_size)
_, fake_logits, sampler = _prepare_test(batch_size)

sampling_params = SamplingParams(
temperature=0,
best_of=2,
use_beam_search=True,
)
_do_sample(batch_size, fake_logits, sampler, model_runner, sampling_params,
device)
_do_sample(batch_size, fake_logits, sampler, sampling_params, device)
# no assertion here as I am not sure how to determine whether
# the outputs are expected - in other words, this just tests
# whether there are no exceptions in the sampler
# when handling an all-beam search case.
del model_runner


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
Expand Down Expand Up @@ -448,13 +428,13 @@ def run_test_case(*,
("Invalid test case, expected_penalization does not match computed"
"batch size")

_, fake_logits, sampler, model_runner = _prepare_test(batch_size)
_, fake_logits, sampler = _prepare_test(batch_size)
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list,
seq_lens=seq_lens if seq_lens else None,
query_lens=seq_lens if seq_lens else None,
device=device,
pin_memory=model_runner.pin_memory)
pin_memory=is_pin_memory_available())
# the logits tensor is modified in-place by the sampler
_ = sampler(logits=fake_logits, sampling_metadata=sampling_metadata)

Expand All @@ -480,8 +460,6 @@ def run_test_case(*,
fake_logits[logits_idx, :] ==
-float('inf')) == 0, "No tokens should have been penalized"

del model_runner

for test_case in test_cases:
run_test_case(**test_case)

Expand All @@ -492,8 +470,7 @@ def test_sampler_mixed(seed: int, device: str):
set_random_seed(seed)
torch.set_default_device(device)
batch_size = random.randint(1, 256)
input_tensor, fake_logits, sampler, model_runner = _prepare_test(
batch_size)
input_tensor, fake_logits, sampler = _prepare_test(batch_size)

seq_group_metadata_list = []
expected_tokens: List[Optional[List[int]]] = []
Expand Down Expand Up @@ -534,13 +511,13 @@ def test_sampler_mixed(seed: int, device: str):
))
seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

def test_sampling(model_runner: ModelRunner):
def test_sampling():
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list,
seq_lens,
query_lens=seq_lens,
device=device,
pin_memory=model_runner.pin_memory)
pin_memory=is_pin_memory_available())
sampler_output = sampler(logits=fake_logits,
sampling_metadata=sampling_metadata)

Expand Down Expand Up @@ -570,7 +547,7 @@ def test_sampling(model_runner: ModelRunner):
assert nth_output.output_token in expected_tokens[i]

# Test batch
test_sampling(model_runner)
test_sampling()

# Shuffle the batch and resample
target_index = list(range(batch_size))
Expand All @@ -583,9 +560,7 @@ def test_sampling(model_runner: ModelRunner):

# This time, results of seeded random samples will be compared with
# the corresponding sample in the pre-shuffled batch
test_sampling(model_runner)

del model_runner
test_sampling()


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
Expand All @@ -605,12 +580,6 @@ def test_sampler_top_k_top_p(seed: int, device: str):
device=input_tensor.device,
dtype=input_tensor.dtype)
sampler = MockLogitsSampler(fake_logits)
model_runner = ModelRunner(model_config=None,
parallel_config=None,
scheduler_config=None,
device_config=None,
load_config=None,
lora_config=None)

generation_model = GenerationMixin()
generation_config = GenerationConfig(top_k=top_k,
Expand Down Expand Up @@ -641,7 +610,7 @@ def test_sampler_top_k_top_p(seed: int, device: str):
seq_lens,
query_lens=seq_lens,
device=device,
pin_memory=model_runner.pin_memory)
pin_memory=is_pin_memory_available())

sample_probs = None

Expand All @@ -657,5 +626,3 @@ def mock_sample(probs, *args, **kwargs):
hf_probs = torch.softmax(hf_probs, dim=-1, dtype=torch.float)
assert torch.allclose(hf_probs, sample_probs, atol=1e-5)
assert torch.equal(hf_probs.eq(0), sample_probs.eq(0))

del model_runner
23 changes: 7 additions & 16 deletions tests/test_logits_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
from vllm.worker.model_runner import ModelRunner
from vllm.utils import is_pin_memory_available


class MockLogitsProcessor(LogitsProcessor):
Expand All @@ -30,21 +30,15 @@ def forward(self, *args, **kwargs):


def _prepare_test(
batch_size: int
) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsProcessor, ModelRunner]:
batch_size: int
) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsProcessor]:
vocab_size = 32000
input_tensor = torch.rand((batch_size, 1024), dtype=torch.float16)
fake_logits = torch.full((batch_size, vocab_size),
1e-2,
dtype=input_tensor.dtype)
logits_processor = MockLogitsProcessor(32000, 0.5, fake_logits)
model_runner = ModelRunner(model_config=None,
parallel_config=None,
scheduler_config=None,
device_config=None,
load_config=None,
lora_config=None)
return input_tensor, fake_logits, logits_processor, model_runner
return input_tensor, fake_logits, logits_processor


RANDOM_SEEDS = list(range(128))
Expand All @@ -59,8 +53,7 @@ def test_logits_processors(seed: int, device: str):
set_random_seed(seed)
torch.set_default_device(device)
batch_size = random.randint(1, 256)
input_tensor, fake_logits, logits_processor, model_runner = _prepare_test(
batch_size)
input_tensor, fake_logits, logits_processor = _prepare_test(batch_size)

# This sample logits processor gives infinite score to the i-th token,
# where i is the length of the input sequence.
Expand All @@ -87,8 +80,8 @@ def pick_ith(token_ids, logits):
seq_group_metadata_list,
seq_lens,
query_lens=seq_lens,
device=model_runner.device,
pin_memory=model_runner.pin_memory)
device=device,
pin_memory=is_pin_memory_available())
logits_processor_output = logits_processor(
embedding=None,
hidden_states=input_tensor,
Expand All @@ -99,5 +92,3 @@ def pick_ith(token_ids, logits):
fake_logits *= logits_processor.scale
assert torch.allclose(logits_processor_output[:, 1], fake_logits[:, 1],
1e-4)

del model_runner

0 comments on commit 5654ff4

Please sign in to comment.