
Commit 4338cc4

[Tokenizer] Add an option to specify tokenizer (#284)
1 parent bdd6b4c commit 4338cc4

10 files changed: 61 additions & 60 deletions

benchmarks/benchmark_latency.py

Lines changed: 2 additions & 0 deletions
@@ -17,6 +17,7 @@ def main(args: argparse.Namespace):
     # the engine will automatically process the request in multiple batches.
     llm = LLM(
         model=args.model,
+        tokenizer=args.tokenizer,
         tensor_parallel_size=args.tensor_parallel_size,
         max_num_seqs=args.batch_size,
         max_num_batched_tokens=args.batch_size * args.input_len,
@@ -63,6 +64,7 @@ def run_to_completion(profile: bool = False):
         description='Benchmark the latency of processing a single batch of '
                     'requests till completion.')
     parser.add_argument('--model', type=str, default='facebook/opt-125m')
+    parser.add_argument('--tokenizer', type=str, default=None)
     parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
     parser.add_argument('--input-len', type=int, default=32)
     parser.add_argument('--output-len', type=int, default=128)
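The latency benchmark now simply forwards the new flag into the engine. A minimal sketch of the resulting call, assuming the default model and no explicit tokenizer (the fallback to the model path happens in EngineArgs, shown further below):

from vllm import LLM

# Illustrative values for --model and --tokenizer; --tokenizer defaults to None.
llm = LLM(
    model="facebook/opt-125m",
    tokenizer=None,  # None -> EngineArgs falls back to the model path
    tensor_parallel_size=1,
)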

benchmarks/benchmark_serving.py

Lines changed: 2 additions & 9 deletions
@@ -24,20 +24,13 @@
 
 import aiohttp
 import numpy as np
-from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizerBase
+from transformers import PreTrainedTokenizerBase
+from vllm.transformers_utils.tokenizer import get_tokenizer
 
 # (prompt len, output len, latency)
 REQUEST_LATENCY: List[Tuple[int, int, float]] = []
 
 
-def get_tokenizer(model_name: str) -> PreTrainedTokenizerBase:
-    config = AutoConfig.from_pretrained(model_name)
-    if config.model_type == "llama":
-        # A workaround for potential protobuf errors.
-        model_name = "hf-internal-testing/llama-tokenizer"
-    return AutoTokenizer.from_pretrained(model_name)
-
-
 def sample_requests(
     dataset_path: str,
     num_requests: int,
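With the local helper deleted, the serving benchmark relies on the shared utility. A minimal sketch of the replacement call (the tokenizer name is just an example):

from vllm.transformers_utils.tokenizer import get_tokenizer

# Equivalent of the removed per-script helper.
tokenizer = get_tokenizer("facebook/opt-125m")
prompt_len = len(tokenizer("Hello, my name is").input_ids)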

benchmarks/benchmark_throughput.py

Lines changed: 15 additions & 20 deletions
@@ -6,23 +6,11 @@
 from typing import List, Tuple
 
 import torch
-from transformers import (AutoConfig, AutoTokenizer, AutoModelForCausalLM,
-                          PreTrainedTokenizerBase)
+from transformers import AutoModelForCausalLM, PreTrainedTokenizerBase
 from tqdm import tqdm
 
 from vllm import LLM, SamplingParams
-
-
-def get_tokenizer(model_name: str) -> PreTrainedTokenizerBase:
-    config = AutoConfig.from_pretrained(model_name)
-    if config.model_type == "llama":
-        # A workaround for potential protobuf errors.
-        model_name = "hf-internal-testing/llama-tokenizer"
-        tokenizer = AutoTokenizer.from_pretrained(model_name)
-        # To enable padding in the HF backend.
-        tokenizer.pad_token = tokenizer.eos_token
-        return tokenizer
-    return AutoTokenizer.from_pretrained(model_name)
+from vllm.transformers_utils.tokenizer import get_tokenizer
 
 
 def sample_requests(
@@ -74,13 +62,15 @@ def sample_requests(
 def run_vllm(
     requests: List[Tuple[str, int, int]],
     model: str,
+    tokenizer: str,
     tensor_parallel_size: int,
     seed: int,
     n: int,
     use_beam_search: bool,
 ) -> float:
     llm = LLM(
         model=model,
+        tokenizer=tokenizer,
         tensor_parallel_size=tensor_parallel_size,
         seed=seed,
     )
@@ -118,9 +108,10 @@ def run_hf(
     max_batch_size: int,
 ) -> float:
     assert not use_beam_search
-    tokenizer = get_tokenizer(model)
-    llm = AutoModelForCausalLM.from_pretrained(
-        model, torch_dtype=torch.float16)
+    llm = AutoModelForCausalLM.from_pretrained(model, torch_dtype=torch.float16)
+    if llm.config.model_type == "llama":
+        # To enable padding in the HF backend.
+        tokenizer.pad_token = tokenizer.eos_token
     llm = llm.cuda()
 
     pbar = tqdm(total=len(requests))
@@ -170,13 +161,13 @@ def main(args: argparse.Namespace):
     random.seed(args.seed)
 
     # Sample the requests.
-    tokenizer = get_tokenizer(args.model)
+    tokenizer = get_tokenizer(args.tokenizer)
     requests = sample_requests(args.dataset, args.num_prompts, tokenizer)
 
     if args.backend == "vllm":
         elapsed_time = run_vllm(
-            requests, args.model, args.tensor_parallel_size, args.seed, args.n,
-            args.use_beam_search)
+            requests, args.model, args.tokenizer, args.tensor_parallel_size,
+            args.seed, args.n, args.use_beam_search)
     elif args.backend == "hf":
         assert args.tensor_parallel_size == 1
         elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
@@ -198,6 +189,7 @@ def main(args: argparse.Namespace):
     parser.add_argument("--dataset", type=str, required=True,
                         help="Path to the dataset.")
     parser.add_argument("--model", type=str, default="facebook/opt-125m")
+    parser.add_argument("--tokenizer", type=str, default=None)
     parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
     parser.add_argument("--n", type=int, default=1,
                         help="Number of generated sequences per prompt.")
@@ -208,11 +200,14 @@ def main(args: argparse.Namespace):
     parser.add_argument("--hf-max-batch-size", type=int, default=None,
                         help="Maximum batch size for HF backend.")
     args = parser.parse_args()
+
     if args.backend == "vllm":
         if args.hf_max_batch_size is not None:
             raise ValueError("HF max batch size is only for HF backend.")
     elif args.backend == "hf":
         if args.hf_max_batch_size is None:
             raise ValueError("HF max batch size is required for HF backend.")
+    if args.tokenizer is None:
+        args.tokenizer = args.model
 
     main(args)
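The throughput benchmark applies the same convention: --tokenizer is optional and falls back to --model right after argument parsing. A condensed sketch of that fallback, with illustrative argument values:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, default="facebook/opt-125m")
parser.add_argument("--tokenizer", type=str, default=None)
args = parser.parse_args([])      # no --tokenizer given

if args.tokenizer is None:
    args.tokenizer = args.model   # same fallback as in the diff above
assert args.tokenizer == "facebook/opt-125m"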

vllm/config.py

Lines changed: 3 additions & 0 deletions
@@ -16,6 +16,7 @@ class ModelConfig:
 
     Args:
         model: Name or path of the huggingface model to use.
+        tokenizer: Name or path of the huggingface tokenizer to use.
         download_dir: Directory to download and load the weights, default to the
             default cache directory of huggingface.
         use_np_weights: Save a numpy copy of model weights for faster loading.
@@ -30,13 +31,15 @@ class ModelConfig:
     def __init__(
         self,
         model: str,
+        tokenizer: Optional[str],
         download_dir: Optional[str],
        use_np_weights: bool,
        use_dummy_weights: bool,
        dtype: str,
        seed: int,
     ) -> None:
         self.model = model
+        self.tokenizer = tokenizer
         self.download_dir = download_dir
         self.use_np_weights = use_np_weights
         self.use_dummy_weights = use_dummy_weights
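ModelConfig now carries the tokenizer name next to the model name. A minimal construction sketch following the positional order shown above (all values are illustrative, not taken from this commit):

from vllm.config import ModelConfig

model_config = ModelConfig(
    "facebook/opt-125m",   # model
    "facebook/opt-125m",   # tokenizer (may point at a different repo)
    None,                  # download_dir
    False,                 # use_np_weights
    False,                 # use_dummy_weights
    "auto",                # dtype
    0,                     # seed
)
print(model_config.tokenizer)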

vllm/engine/arg_utils.py

Lines changed: 6 additions & 1 deletion
@@ -11,6 +11,7 @@
 class EngineArgs:
     """Arguments for vLLM engine."""
     model: str
+    tokenizer: Optional[str] = None
     download_dir: Optional[str] = None
     use_np_weights: bool = False
     use_dummy_weights: bool = False
@@ -27,6 +28,8 @@ class EngineArgs:
     disable_log_stats: bool = False
 
     def __post_init__(self):
+        if self.tokenizer is None:
+            self.tokenizer = self.model
         self.max_num_seqs = min(self.max_num_seqs, self.max_num_batched_tokens)
 
     @staticmethod
@@ -37,6 +40,8 @@ def add_cli_args(
         # Model arguments
         parser.add_argument('--model', type=str, default='facebook/opt-125m',
                             help='name or path of the huggingface model to use')
+        parser.add_argument('--tokenizer', type=str, default=EngineArgs.tokenizer,
+                            help='name or path of the huggingface tokenizer to use')
         parser.add_argument('--download-dir', type=str,
                             default=EngineArgs.download_dir,
                             help='directory to download and load the weights, '
@@ -104,7 +109,7 @@ def create_engine_configs(
     ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]:
         # Initialize the configs.
         model_config = ModelConfig(
-            self.model, self.download_dir, self.use_np_weights,
+            self.model, self.tokenizer, self.download_dir, self.use_np_weights,
             self.use_dummy_weights, self.dtype, self.seed)
         cache_config = CacheConfig(self.block_size, self.gpu_memory_utilization,
                                    self.swap_space)
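Because the fallback lives in EngineArgs.__post_init__, every entry point that builds EngineArgs inherits it. A small sketch of the two cases (the second model name is hypothetical):

from vllm.engine.arg_utils import EngineArgs

# No tokenizer given: __post_init__ copies the model into the tokenizer field.
args = EngineArgs(model="facebook/opt-125m")
assert args.tokenizer == "facebook/opt-125m"

# Explicit tokenizer: kept as-is.
args = EngineArgs(model="my-org/llama-7b-finetune",   # hypothetical model name
                  tokenizer="hf-internal-testing/llama-tokenizer")
assert args.tokenizer == "hf-internal-testing/llama-tokenizer"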

vllm/engine/llm_engine.py

Lines changed: 4 additions & 2 deletions
@@ -6,11 +6,12 @@
 from vllm.core.scheduler import Scheduler
 from vllm.engine.arg_utils import EngineArgs
 from vllm.engine.ray_utils import DeviceID, initialize_cluster, ray
-from vllm.engine.tokenizer_utils import detokenize_incrementally, get_tokenizer
 from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import Sequence, SequenceGroup, SequenceStatus
+from vllm.transformers_utils.tokenizer import (detokenize_incrementally,
+                                               get_tokenizer)
 from vllm.utils import Counter
 from vllm.worker.worker import Worker
 
@@ -59,6 +60,7 @@ def __init__(
         logger.info(
             "Initializing an LLM engine with config: "
             f"model={model_config.model!r}, "
+            f"tokenizer={model_config.tokenizer!r}, "
             f"dtype={model_config.dtype}, "
             f"use_dummy_weights={model_config.use_dummy_weights}, "
             f"download_dir={model_config.download_dir!r}, "
@@ -75,7 +77,7 @@ def __init__(
         self.log_stats = log_stats
         self._verify_args()
 
-        self.tokenizer = get_tokenizer(model_config.model)
+        self.tokenizer = get_tokenizer(model_config.tokenizer)
         self.seq_counter = Counter()
 
         # Create the parallel GPU workers.

vllm/entrypoints/llm.py

Lines changed: 3 additions & 0 deletions
@@ -25,6 +25,7 @@ class LLM:
 
     Args:
         model: The name or path of a HuggingFace Transformers model.
+        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
         tensor_parallel_size: The number of GPUs to use for distributed
             execution with tensor parallelism.
         dtype: The data type for the model weights and activations. Currently,
@@ -38,6 +39,7 @@ class LLM:
     def __init__(
         self,
         model: str,
+        tokenizer: Optional[str] = None,
         tensor_parallel_size: int = 1,
         dtype: str = "auto",
         seed: int = 0,
@@ -47,6 +49,7 @@ def __init__(
             kwargs["disable_log_stats"] = True
         engine_args = EngineArgs(
             model=model,
+            tokenizer=tokenizer,
             tensor_parallel_size=tensor_parallel_size,
             dtype=dtype,
             seed=seed,
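On the user side, this makes it possible to pair a model with a different tokenizer, e.g. the pre-converted fast LLaMA tokenizer referenced elsewhere in this commit. A minimal usage sketch (the model path is a placeholder for a local LLaMA checkpoint):

from vllm import LLM, SamplingParams

llm = LLM(
    model="/path/to/llama-7b",                         # placeholder local checkpoint
    tokenizer="hf-internal-testing/llama-tokenizer",   # fast LLaMA tokenizer
)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.8, max_tokens=16))
print(outputs[0].outputs[0].text)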

vllm/entrypoints/openai/api_server.py

Lines changed: 1 addition & 1 deletion
@@ -15,14 +15,14 @@
 
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
-from vllm.engine.tokenizer_utils import get_tokenizer
 from vllm.entrypoints.openai.protocol import (
     CompletionRequest, CompletionResponse, CompletionResponseChoice,
     CompletionResponseStreamChoice, CompletionStreamResponse, ErrorResponse,
     LogProbs, ModelCard, ModelList, ModelPermission, UsageInfo)
 from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import SamplingParams
+from vllm.transformers_utils.tokenizer import get_tokenizer
 from vllm.utils import random_uuid
 
 TIMEOUT_KEEP_ALIVE = 5  # seconds

vllm/transformers_utils/__init__.py

Whitespace-only changes.

vllm/engine/tokenizer_utils.py renamed to vllm/transformers_utils/tokenizer.py

Lines changed: 25 additions & 27 deletions
@@ -1,46 +1,44 @@
 from typing import List, Tuple, Union
 
-from transformers import (AutoConfig, AutoTokenizer, PreTrainedTokenizer,
+from transformers import (AutoTokenizer, PreTrainedTokenizer,
                           PreTrainedTokenizerFast)
 
 from vllm.logger import init_logger
 
 logger = init_logger(__name__)
 
-_MODEL_TYPES_WITH_SLOW_TOKENIZER = []
+# A fast LLaMA tokenizer with the pre-processed `tokenizer.json` file.
+_FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer"
 
 
 def get_tokenizer(
-    model_name: str,
+    tokenizer_name: str,
     *args,
     **kwargs,
 ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
     """Gets a tokenizer for the given model name via Huggingface."""
-    config = AutoConfig.from_pretrained(model_name)
-    if "open_llama" in model_name:
-        kwargs["use_fast"] = False
+    if "llama" in tokenizer_name.lower() and kwargs.get("use_fast", True):
         logger.info(
-            "OpenLLaMA models do not support the fast tokenizer. "
-            "Using the slow tokenizer instead.")
-    elif config.model_type == "llama" and kwargs.get("use_fast", True):
-        # LLaMA fast tokenizer causes protobuf errors in some environments.
-        # However, we found that the below LLaMA fast tokenizer works well in
-        # most environments.
-        model_name = "hf-internal-testing/llama-tokenizer"
-        logger.info(
-            f"Using the LLaMA fast tokenizer in '{model_name}' to avoid "
-            "potential protobuf errors.")
-    elif config.model_type in _MODEL_TYPES_WITH_SLOW_TOKENIZER:
-        if kwargs.get("use_fast", False) == True:
-            raise ValueError(
-                f"Cannot use the fast tokenizer for {config.model_type} due to "
-                "bugs in the fast tokenizer.")
-        logger.info(
-            f"Using the slow tokenizer for {config.model_type} due to bugs in "
-            "the fast tokenizer. This could potentially lead to performance "
-            "degradation.")
-        kwargs["use_fast"] = False
-    return AutoTokenizer.from_pretrained(model_name, *args, **kwargs)
+            "For some LLaMA-based models, initializing the fast tokenizer may "
+            "take a long time. To eliminate the initialization time, consider "
+            f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original "
+            "tokenizer.")
+    try:
+        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, *args,
+                                                  **kwargs)
+    except TypeError as e:
+        # The LLaMA tokenizer causes a protobuf error in some environments.
+        err_msg = (
+            "Failed to load the tokenizer. If you are using a LLaMA-based "
+            f"model, use '{_FAST_LLAMA_TOKENIZER}' instead of the original "
+            "tokenizer.")
+        raise RuntimeError(err_msg) from e
+
+    if not isinstance(tokenizer, PreTrainedTokenizerFast):
+        logger.warning(
+            "Using a slow tokenizer. This might cause a significant "
+            "slowdown. Consider using a fast tokenizer instead.")
+    return tokenizer
 
 
 def detokenize_incrementally(
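The rewritten helper no longer inspects the model config; it decides based only on the tokenizer name and what AutoTokenizer returns, warning when a slow tokenizer is loaded. A small sketch of calling it directly, assuming access to the Hugging Face Hub:

from transformers import PreTrainedTokenizerFast
from vllm.transformers_utils.tokenizer import get_tokenizer

# The fast LLaMA tokenizer referenced above loads as a PreTrainedTokenizerFast.
tokenizer = get_tokenizer("hf-internal-testing/llama-tokenizer")
assert isinstance(tokenizer, PreTrainedTokenizerFast)

# Forcing use_fast=False loads a slow tokenizer and triggers the new warning.
slow_tokenizer = get_tokenizer("facebook/opt-125m", use_fast=False)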
