forked from vllm-project/vllm
Asynchronous tokenization (vllm-project#2879)
Showing 17 changed files with 658 additions and 153 deletions.
This file was deleted.
@@ -0,0 +1,53 @@
import pytest
from transformers import AutoTokenizer, PreTrainedTokenizerBase
from vllm.lora.request import LoRARequest
from vllm.transformers_utils.tokenizer_group import get_tokenizer_group
from vllm.transformers_utils.tokenizer import get_lora_tokenizer
from ..conftest import get_tokenizer_pool_config


@pytest.mark.asyncio
@pytest.mark.parametrize("tokenizer_group_type", [None, "ray"])
async def test_tokenizer_group_lora(sql_lora_files, tokenizer_group_type):
    reference_tokenizer = AutoTokenizer.from_pretrained(sql_lora_files)
    tokenizer_group = get_tokenizer_group(
        get_tokenizer_pool_config(tokenizer_group_type),
        tokenizer_id="gpt2",
        enable_lora=True,
        max_num_seqs=1,
        max_input_length=None,
    )
    lora_request = LoRARequest("1", 1, sql_lora_files)
    assert reference_tokenizer.encode("prompt") == tokenizer_group.encode(
        request_id="request_id", prompt="prompt", lora_request=lora_request)
    assert reference_tokenizer.encode(
        "prompt") == await tokenizer_group.encode_async(
            request_id="request_id",
            prompt="prompt",
            lora_request=lora_request)
    assert isinstance(tokenizer_group.get_lora_tokenizer(None),
                      PreTrainedTokenizerBase)
    assert tokenizer_group.get_lora_tokenizer(
        None) == await tokenizer_group.get_lora_tokenizer_async(None)

    assert isinstance(tokenizer_group.get_lora_tokenizer(lora_request),
                      PreTrainedTokenizerBase)
    assert tokenizer_group.get_lora_tokenizer(
        lora_request) != tokenizer_group.get_lora_tokenizer(None)
    assert tokenizer_group.get_lora_tokenizer(
        lora_request) == await tokenizer_group.get_lora_tokenizer_async(
            lora_request)


def test_get_lora_tokenizer(sql_lora_files, tmpdir):
    lora_request = None
    tokenizer = get_lora_tokenizer(lora_request)
    assert not tokenizer

    lora_request = LoRARequest("1", 1, sql_lora_files)
    tokenizer = get_lora_tokenizer(lora_request)
    assert tokenizer.get_added_vocab()

    lora_request = LoRARequest("1", 1, str(tmpdir))
    tokenizer = get_lora_tokenizer(lora_request)
    assert not tokenizer
Empty file.
@@ -0,0 +1,20 @@
from copy import deepcopy
from vllm.transformers_utils.tokenizer import get_cached_tokenizer
from transformers import AutoTokenizer


def test_cached_tokenizer():
    reference_tokenizer = AutoTokenizer.from_pretrained("gpt2")
    reference_tokenizer.add_special_tokens({"cls_token": "<CLS>"})
    reference_tokenizer.add_special_tokens(
        {"additional_special_tokens": ["<SEP>"]})
    cached_tokenizer = get_cached_tokenizer(deepcopy(reference_tokenizer))

    assert reference_tokenizer.encode("prompt") == cached_tokenizer.encode(
        "prompt")
    assert set(reference_tokenizer.all_special_ids) == set(
        cached_tokenizer.all_special_ids)
    assert set(reference_tokenizer.all_special_tokens) == set(
        cached_tokenizer.all_special_tokens)
    assert set(reference_tokenizer.all_special_tokens_extended) == set(
        cached_tokenizer.all_special_tokens_extended)
File renamed without changes.
@@ -0,0 +1,100 @@
import os
import pytest
import asyncio
from unittest.mock import patch

from transformers import AutoTokenizer, PreTrainedTokenizerBase
from vllm.transformers_utils.tokenizer_group import get_tokenizer_group
from vllm.transformers_utils.tokenizer_group.ray_tokenizer_group import (
    RayTokenizerGroupPool)
from vllm.transformers_utils.tokenizer_group.tokenizer_group import (
    TokenizerGroup)
from ..conftest import get_tokenizer_pool_config


@pytest.mark.asyncio
@pytest.mark.parametrize("tokenizer_group_type", [None, "ray"])
async def test_tokenizer_group(tokenizer_group_type):
    reference_tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer_group = get_tokenizer_group(
        get_tokenizer_pool_config(tokenizer_group_type),
        tokenizer_id="gpt2",
        enable_lora=False,
        max_num_seqs=1,
        max_input_length=None,
    )
    assert reference_tokenizer.encode("prompt") == tokenizer_group.encode(
        request_id="request_id", prompt="prompt", lora_request=None)
    assert reference_tokenizer.encode(
        "prompt") == await tokenizer_group.encode_async(
            request_id="request_id", prompt="prompt", lora_request=None)
    assert isinstance(tokenizer_group.get_lora_tokenizer(None),
                      PreTrainedTokenizerBase)
    assert tokenizer_group.get_lora_tokenizer(
        None) == await tokenizer_group.get_lora_tokenizer_async(None)


@pytest.mark.asyncio
@pytest.mark.parametrize("tokenizer_group_type", ["ray"])
async def test_tokenizer_group_pool(tokenizer_group_type):
    reference_tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer_group_pool = get_tokenizer_group(
        get_tokenizer_pool_config(tokenizer_group_type),
        tokenizer_id="gpt2",
        enable_lora=False,
        max_num_seqs=1,
        max_input_length=None,
    )
    # Send multiple requests to the tokenizer group pool
    # (more than the pool size)
    # and check that all requests are processed correctly.
    num_requests = tokenizer_group_pool.pool_size * 5
    requests = [
        tokenizer_group_pool.encode_async(request_id=str(i),
                                          prompt=f"prompt {i}",
                                          lora_request=None)
        for i in range(num_requests)
    ]
    results = await asyncio.gather(*requests)
    expected_results = [
        reference_tokenizer.encode(f"prompt {i}") for i in range(num_requests)
    ]
    assert results == expected_results


@pytest.mark.asyncio
@pytest.mark.parametrize("tokenizer_group_type", ["ray"])
async def test_tokenizer_group_ray_pool_env_var_propagation(
        tokenizer_group_type):
    """Test that env vars from caller process are propagated to
    tokenizer Ray actors."""
    env_var = "MY_ENV_VAR"

    class EnvVarCheckerTokenizerGroup(TokenizerGroup):

        def ping(self):
            assert os.environ.get(env_var) == "1"
            return super().ping()

    class EnvVarCheckerRayTokenizerGroupPool(RayTokenizerGroupPool):
        _worker_cls = EnvVarCheckerTokenizerGroup

    tokenizer_pool_config = get_tokenizer_pool_config(tokenizer_group_type)
    tokenizer_pool = EnvVarCheckerRayTokenizerGroupPool.from_config(
        tokenizer_pool_config,
        tokenizer_id="gpt2",
        enable_lora=False,
        max_num_seqs=1,
        max_input_length=None)
    with pytest.raises(AssertionError):
        tokenizer_pool.ping()

    with patch.dict(os.environ, {env_var: "1"}):
        tokenizer_pool_config = get_tokenizer_pool_config(tokenizer_group_type)
        tokenizer_pool = EnvVarCheckerRayTokenizerGroupPool.from_config(
            tokenizer_pool_config,
            tokenizer_id="gpt2",
            enable_lora=False,
            max_num_seqs=1,
            max_input_length=None)
        tokenizer_pool.ping()