
Commit 1025f93

NickLucche committed
raise error to client when per-request seed is set
Signed-off-by: NickLucche <nlucches@redhat.com>
1 parent e1dab88 commit 1025f93

File tree

.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh
tests/v1/tpu/test_sampler.py
vllm/v1/engine/processor.py
vllm/v1/worker/tpu_model_runner.py

4 files changed: +12 -10 lines

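In practical terms (a minimal client-side sketch; the TPU host and the model name are assumptions, not part of this commit): a request that sets a per-request seed now fails fast with a ValueError instead of being accepted and silently de-seeded.

# Hypothetical client-side view of the change; model name and TPU host are
# assumptions. Before this commit the seed was only warned about and ignored.
from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct")  # illustrative model

try:
    params = SamplingParams(temperature=0.8, seed=1234)  # per-request seed
    llm.generate(["Hello"], params)
except ValueError as err:
    # On TPU this now surfaces: "Torch XLA does not support per-request seed."
    print(err)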

.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh

Lines changed: 1 addition & 1 deletion
@@ -42,4 +42,4 @@ docker run --privileged --net host --shm-size=16G -it \
     && echo TEST_9 \
     && pytest -s -v /workspace/vllm/tests/v1/tpu/test_pallas.py \
     && echo TEST_10 \
-    && pytest -s -v /workspace/vllm/tests/tpu/test_custom_dispatcher.py" \
+    && pytest -s -v /workspace/vllm/tests/tpu/test_custom_dispatcher.py" \

tests/v1/tpu/test_sampler.py

Lines changed: 6 additions & 4 deletions
@@ -28,12 +28,14 @@ def test_sampler_different(model_name: str):
     prompts = [
         "Write a short story about a robot that dreams for the first time."
     ]
-    sampling_params = SamplingParams(temperature=0.9,
-                                     min_p=0.2,
-                                     max_tokens=64,
-                                     seed=42)
+    sampling_params = SamplingParams(temperature=0.9, min_p=0.2, max_tokens=64)
     output = llm.generate(prompts, sampling_params)
 
     sampling_params = SamplingParams(temperature=0.1, min_p=0.8, max_tokens=64)
     output2 = llm.generate(prompts, sampling_params)
     assert output[0].outputs[0].text != output2[0].outputs[0].text
+
+    with pytest.raises(ValueError):
+        # Unsupported `seed` param.
+        sampling_params = SamplingParams(temperature=0.3, seed=42)
+        output2 = llm.generate(prompts, sampling_params)
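The new test case relies on SamplingParams deriving its sampling_type from the seed: once seed is set, the request is classified as SamplingType.RANDOM_SEED, which is exactly what the processor now rejects on TPU. A rough sketch of that relationship (paraphrasing the upstream property, not part of this diff):

# Rough illustration of how `seed` maps to SamplingType; the real logic lives
# in vllm.sampling_params and may differ in detail (e.g. the greedy epsilon).
from vllm import SamplingParams
from vllm.sampling_params import SamplingType

assert SamplingParams(temperature=0.3, seed=42).sampling_type == SamplingType.RANDOM_SEED
assert SamplingParams(temperature=0.9).sampling_type == SamplingType.RANDOM
assert SamplingParams(temperature=0.0).sampling_type == SamplingType.GREEDY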

vllm/v1/engine/processor.py

Lines changed: 5 additions & 1 deletion
@@ -13,9 +13,10 @@
                              MultiModalRegistry)
 from vllm.multimodal.inputs import PlaceholderRange
 from vllm.multimodal.utils import merge_and_sort_multimodal_metadata
+from vllm.platforms import current_platform
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
-from vllm.sampling_params import SamplingParams
+from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
 from vllm.v1.engine import EngineCoreRequest
 from vllm.v1.structured_output.backend_guidance import (
@@ -73,6 +74,9 @@ def _validate_sampling_params(
         params: SamplingParams,
     ) -> None:
         self._validate_structured_output(params)
+        if (current_platform.is_tpu()
+                and params.sampling_type == SamplingType.RANDOM_SEED):
+            raise ValueError("Torch XLA does not support per-request seed.")
 
         if params.allowed_token_ids is None:
             return
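For readability, here is the added check in isolation (a paraphrase of Processor._validate_sampling_params; the free function below is hypothetical, not part of the codebase). Validating in the processor, before the request reaches the worker, is what turns the old runner-side warning into an error the client actually sees.

# Standalone paraphrase of the added validation; the real check lives inside
# Processor._validate_sampling_params.
from vllm.platforms import current_platform
from vllm.sampling_params import SamplingParams, SamplingType


def reject_seed_on_tpu(params: SamplingParams) -> None:
    # Torch XLA has no per-request RNG seeding, so fail at request-validation
    # time instead of silently ignoring the seed in the model runner.
    if (current_platform.is_tpu()
            and params.sampling_type == SamplingType.RANDOM_SEED):
        raise ValueError("Torch XLA does not support per-request seed.")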

vllm/v1/worker/tpu_model_runner.py

Lines changed: 0 additions & 4 deletions
@@ -23,7 +23,6 @@
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
 from vllm.multimodal.utils import group_mm_inputs_by_modality
-from vllm.sampling_params import SamplingType
 from vllm.sequence import IntermediateTensors
 from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available
 from vllm.v1.attention.backends.pallas import (PallasAttentionBackend,
@@ -265,9 +264,6 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
         for new_req_data in scheduler_output.scheduled_new_reqs:
             req_id = new_req_data.req_id
             sampling_params = new_req_data.sampling_params
-            if sampling_params.sampling_type == SamplingType.RANDOM_SEED:
-                logger.warning("Torch XLA does not support per-request seed."
-                               "Seed {sampling_params.seed} will be ignored")
 
             self.requests[req_id] = CachedRequestState(
                 req_id=req_id,
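Incidentally, the removed warning also carried a formatting bug: the second string literal was not an f-string, so "{sampling_params.seed}" would have been logged verbatim. If a runner-side log were ever reintroduced, lazy %-style formatting (the style vLLM's loggers generally use) avoids that pitfall; a tiny self-contained sketch:

import logging

logger = logging.getLogger(__name__)
seed = 42  # stand-in for sampling_params.seed

# Lazy %-style formatting interpolates the value only when the record is
# emitted, and sidesteps the missing f-prefix from the removed code.
logger.warning(
    "Torch XLA does not support per-request seed. Seed %s will be ignored.",
    seed)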
