Skip to content
This repository was archived by the owner on Oct 11, 2024. It is now read-only.

Commit 00f60d2

Browse files
DarkLight1337 and Robert Shaw
authored and committed
[Bugfix] Support eos_token_id from config.json (vllm-project#5954)
1 parent 33fecd4 commit 00f60d2

File tree

3 files changed

+67
-11
lines changed

3 files changed

+67
-11
lines changed

tests/tokenization/test_get_eos.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
"""
2+
This test file includes some cases where it is inappropriate to
3+
only get the `eos_token_id` from the tokenizer as defined by
4+
:meth:`vllm.LLMEngine._get_eos_token_id`.
5+
"""
6+
from vllm.transformers_utils.config import try_get_generation_config
7+
from vllm.transformers_utils.tokenizer import get_tokenizer
8+
9+
10+
def test_get_llama3_eos_token():
    """Llama-3-Instruct reports a single EOS id on the tokenizer but a pair
    of EOS ids in its generation config, so the tokenizer alone is not a
    reliable source for ``eos_token_id``.
    """
    model_name = "meta-llama/Meta-Llama-3-8B-Instruct"

    assert get_tokenizer(model_name).eos_token_id == 128009

    gen_config = try_get_generation_config(model_name,
                                           trust_remote_code=False)
    assert gen_config is not None
    assert gen_config.eos_token_id == [128001, 128009]
20+
21+
22+
def test_get_blip2_eos_token():
    """BLIP-2's tokenizer and generation config disagree on the EOS id,
    so the generation config must be consulted as well.
    """
    model_name = "Salesforce/blip2-opt-2.7b"

    assert get_tokenizer(model_name).eos_token_id == 2

    gen_config = try_get_generation_config(model_name,
                                           trust_remote_code=False)
    assert gen_config is not None
    assert gen_config.eos_token_id == 50118

vllm/engine/llm_engine.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
import time
22
from contextlib import contextmanager
3-
from typing import TYPE_CHECKING, ClassVar, Dict, Iterable, List, Optional
3+
from typing import TYPE_CHECKING, Any, ClassVar, Dict, Iterable, List, Optional
44
from typing import Sequence as GenericSequence
55
from typing import Set, Type, TypeVar, Union
66

7-
from transformers import GenerationConfig, PreTrainedTokenizer
7+
from transformers import PreTrainedTokenizer
88

99
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig,
1010
LoRAConfig, ModelConfig, ObservabilityConfig,
@@ -34,6 +34,7 @@
3434
SequenceStatus)
3535
from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
3636
init_tracer)
37+
from vllm.transformers_utils.config import try_get_generation_config
3738
from vllm.transformers_utils.detokenizer import Detokenizer
3839
from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup,
3940
get_tokenizer_group)
@@ -46,16 +47,18 @@
4647
_LOCAL_LOGGING_INTERVAL_SEC = 5
4748

4849

def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]:
    """Return the model's generation config as a diff-dict.

    Delegates the lookup to :func:`try_get_generation_config`; when no
    generation config can be found for the model, an empty dict is
    returned instead.
    """
    generation_config = try_get_generation_config(
        model_config.model,
        trust_remote_code=model_config.trust_remote_code,
        revision=model_config.revision,
    )

    if generation_config is not None:
        return generation_config.to_diff_dict()

    return {}

6063
_O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput)
6164

vllm/transformers_utils/config.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import contextlib
22
from typing import Dict, Optional, Type
33

4-
from transformers import PretrainedConfig
4+
from transformers import GenerationConfig, PretrainedConfig
55

66
from vllm.envs import VLLM_USE_MODELSCOPE
77
from vllm.logger import init_logger
@@ -80,3 +80,25 @@ def get_hf_text_config(config: PretrainedConfig):
8080
return config.text_config
8181
else:
8282
return config
83+
84+
85+
def try_get_generation_config(
    model: str,
    trust_remote_code: bool,
    revision: Optional[str] = None,
) -> Optional[GenerationConfig]:
    """Best-effort load of a model's :class:`GenerationConfig`.

    First tries the model's standalone generation config file; if that
    is not found, falls back to deriving one from the model's main
    config. Returns ``None`` when neither source is available.
    """
    try:
        return GenerationConfig.from_pretrained(model, revision=revision)
    except OSError:  # no standalone generation config file
        pass

    try:
        hf_config = get_config(
            model,
            trust_remote_code=trust_remote_code,
            revision=revision,
        )
        return GenerationConfig.from_model_config(hf_config)
    except OSError:  # model config not found either
        return None

0 commit comments

Comments
 (0)