From 92ef175cfecc826bd456eff7eac5458e41b9d466 Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Tue, 29 Oct 2024 16:13:20 -0500 Subject: [PATCH] [Bugfix][Frontend] Guard against bad token ids (#9634) Signed-off-by: Joe Runde Signed-off-by: Maxime Fournioux <55544262+mfournioux@users.noreply.github.com> --- .../entrypoints/llm/test_prompt_validation.py | 8 +++- tests/entrypoints/openai/test_completion.py | 18 ++++----- .../openai/test_prompt_validation.py | 15 +++++++ vllm/engine/async_llm_engine.py | 15 +++++-- vllm/engine/llm_engine.py | 40 +++++++++++++++++-- vllm/transformers_utils/tokenizer.py | 5 +++ vllm/transformers_utils/tokenizers/mistral.py | 5 +++ 7 files changed, 89 insertions(+), 17 deletions(-) diff --git a/tests/entrypoints/llm/test_prompt_validation.py b/tests/entrypoints/llm/test_prompt_validation.py index 565dfa01346cc..675a980ab3f3f 100644 --- a/tests/entrypoints/llm/test_prompt_validation.py +++ b/tests/entrypoints/llm/test_prompt_validation.py @@ -4,6 +4,12 @@ def test_empty_prompt(): - llm = LLM(model="gpt2") + llm = LLM(model="gpt2", enforce_eager=True) with pytest.raises(ValueError, match='Prompt cannot be empty'): llm.generate([""]) + + +def test_out_of_vocab_token(): + llm = LLM(model="gpt2", enforce_eager=True) + with pytest.raises(ValueError, match='out of vocabulary'): + llm.generate({"prompt_token_ids": [999999]}) diff --git a/tests/entrypoints/openai/test_completion.py b/tests/entrypoints/openai/test_completion.py index f03bdb045f640..c81cfdbbe5cff 100644 --- a/tests/entrypoints/openai/test_completion.py +++ b/tests/entrypoints/openai/test_completion.py @@ -157,15 +157,15 @@ async def test_added_lora_tokens(client: openai.AsyncOpenAI): @pytest.mark.asyncio async def test_added_lora_tokens_base_model(client: openai.AsyncOpenAI): # test using token IDs - completion = await client.completions.create( - model=MODEL_NAME, - prompt=[0, 0, 32000, 32001, 32002], - echo=True, - max_tokens=5, - temperature=0.0, - ) - # Added tokens should not appear in tokenized prompt - assert "vllm" not in completion.choices[0].text + with pytest.raises(openai.BadRequestError, match="out of vocabulary"): + # Added tokens should be rejected by the base model + await client.completions.create( + model=MODEL_NAME, + prompt=[0, 0, 32000, 32001, 32002], + echo=True, + max_tokens=5, + temperature=0.0, + ) @pytest.mark.asyncio diff --git a/tests/entrypoints/openai/test_prompt_validation.py b/tests/entrypoints/openai/test_prompt_validation.py index 0a573a0066d32..58075f7023821 100644 --- a/tests/entrypoints/openai/test_prompt_validation.py +++ b/tests/entrypoints/openai/test_prompt_validation.py @@ -20,3 +20,18 @@ async def test_empty_prompt(): prompt="", max_tokens=5, temperature=0.0) + + +@pytest.mark.asyncio +async def test_out_of_vocab_token_ids(): + model_name = "gpt2" + server_args = ["--enforce-eager"] + with RemoteOpenAIServer(model_name, server_args) as remote_server: + client = remote_server.get_async_client() + + with pytest.raises(openai.BadRequestError, + match=re.compile('.*out of vocabulary.*')): + await client.completions.create(model=model_name, + prompt=[999999], + max_tokens=5, + temperature=0.0) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index e9848a14cbe17..5198467a6ac40 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -412,6 +412,12 @@ async def stop_remote_worker_execution_loop_async(self) -> None: """Stop the remote worker execution loop.""" await self.model_executor.stop_remote_worker_execution_loop_async() + async def get_tokenizer_async(self, + lora_request: Optional[LoRARequest] = None + ) -> AnyTokenizer: + return await ( + self.get_tokenizer_group().get_lora_tokenizer_async(lora_request)) + @overload # DEPRECATED async def add_request_async( self, @@ -472,6 +478,10 @@ async def add_request_async( if arrival_time is None: arrival_time = time.time() + if self.tokenizer is not None: + tokenizer = await self.get_tokenizer_async(lora_request) + self._validate_token_prompt(prompt, tokenizer=tokenizer) + preprocessed_inputs = await self.input_preprocessor.preprocess_async( prompt, request_id=request_id, @@ -488,7 +498,7 @@ async def add_request_async( # implementation in the LLMEngine params = await build_guided_decoding_logits_processor_async( sampling_params=params, - tokenizer=self.get_tokenizer(lora_request), + tokenizer=await self.get_tokenizer_async(lora_request), default_guided_backend=self.decoding_config. guided_decoding_backend) @@ -715,8 +725,7 @@ async def get_tokenizer( self, lora_request: Optional[LoRARequest] = None, ) -> AnyTokenizer: - return await (self.engine.get_tokenizer_group(). - get_lora_tokenizer_async(lora_request)) + return await self.engine.get_tokenizer_async(lora_request) def start_background_loop(self) -> None: """Start the background loop.""" diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 60575210c9386..fde768ed5165e 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -10,7 +10,7 @@ from typing import Set, Type, Union, cast, overload import torch -from typing_extensions import TypeVar +from typing_extensions import TypeIs, TypeVar import vllm.envs as envs from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, @@ -32,7 +32,8 @@ from vllm.executor.gpu_executor import GPUExecutor from vllm.executor.ray_utils import initialize_ray_cluster from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, - EncoderDecoderInputs, InputRegistry, PromptType) + EncoderDecoderInputs, InputRegistry, PromptType, + TokensPrompt) from vllm.inputs.preprocess import InputPreprocessor from vllm.logger import init_logger from vllm.logits_process import get_bad_words_logits_processors @@ -667,7 +668,7 @@ def _add_processed_request( ) return None - self._validate_model_inputs(processed_inputs) + self._validate_model_inputs(processed_inputs, lora_request) # Create the sequences. block_size = self.cache_config.block_size seq_id = next(self.seq_counter) @@ -829,6 +830,11 @@ def add_request( if arrival_time is None: arrival_time = time.time() + if self.tokenizer is not None: + self._validate_token_prompt( + prompt, + tokenizer=self.get_tokenizer(lora_request=lora_request)) + preprocessed_inputs = self.input_preprocessor.preprocess( prompt, request_id=request_id, @@ -855,6 +861,31 @@ def add_request( priority=priority, ) + def _validate_token_prompt(self, prompt: PromptType, + tokenizer: AnyTokenizer): + # Guard against out-of-vocab tokens. + # For some tokenizers, tokenizer.decode will happily return empty text + # for token ids that are out of vocab, and we don't detect token ids + # that are greater than the max token id before running the model. + # However, these token ids will later crash a cuda kernel at runtime + # with an index out of bounds error. This will crash the entire engine. + # This needs to happen before multimodal input pre-processing, which + # may add dummy tokens that aren't part of the tokenizer's + # vocabulary. + if self._is_token_prompt(prompt): + prompt_ids = prompt["prompt_token_ids"] + if len(prompt_ids) == 0: + # Empty prompt check is handled later + return + max_input_id = max(prompt_ids) + if max_input_id > tokenizer.max_token_id: + raise ValueError( + "Token id {} is out of vocabulary".format(max_input_id)) + + @staticmethod + def _is_token_prompt(prompt: PromptType) -> TypeIs[TokensPrompt]: + return isinstance(prompt, dict) and "prompt_token_ids" in prompt + def _create_sequence_group_with_sampling( self, request_id: str, @@ -1942,7 +1973,8 @@ def is_encoder_decoder_model(self): return self.input_preprocessor.is_encoder_decoder_model() def _validate_model_inputs(self, inputs: Union[DecoderOnlyInputs, - EncoderDecoderInputs]): + EncoderDecoderInputs], + lora_request: Optional[LoRARequest]): if self.model_config.is_multimodal_model: # For encoder-decoder multimodal models, the max_prompt_len # restricts the decoder prompt length diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py index 94af2388d79db..54f9f895fe541 100644 --- a/vllm/transformers_utils/tokenizer.py +++ b/vllm/transformers_utils/tokenizer.py @@ -35,6 +35,7 @@ def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer: tokenizer.all_special_tokens_extended) tokenizer_all_special_tokens = set(tokenizer.all_special_tokens) tokenizer_len = len(tokenizer) + max_token_id = max(tokenizer.get_vocab().values()) class CachedTokenizer(tokenizer.__class__): # type: ignore @@ -50,6 +51,10 @@ def all_special_tokens(self): def all_special_tokens_extended(self): return tokenizer_all_special_tokens_extended + @property + def max_token_id(self): + return max_token_id + def __len__(self): return tokenizer_len diff --git a/vllm/transformers_utils/tokenizers/mistral.py b/vllm/transformers_utils/tokenizers/mistral.py index 23ea657ffb0a9..80e21c2d32ecc 100644 --- a/vllm/transformers_utils/tokenizers/mistral.py +++ b/vllm/transformers_utils/tokenizers/mistral.py @@ -85,6 +85,7 @@ def __init__(self, tokenizer: PublicMistralTokenizer) -> None: raise TypeError(f"Unsupported tokenizer: {type(tokenizer_)}") self.tokenizer = tokenizer_ + self._max_token_id = max(self._vocab.values()) @classmethod def from_pretrained(cls, @@ -158,6 +159,10 @@ def is_fast(self) -> bool: def vocab_size(self) -> int: return len(self._vocab) + @property + def max_token_id(self) -> int: + return self._max_token_id + def __len__(self) -> int: return self.vocab_size