
Commit 83debd5

nie3e (Christos Malliopoulos) authored and committed
[Model] GPT2ForSequenceClassification model (vllm-project#19663)
Signed-off-by: nie3e <adrcwiek@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

Included commit messages:

- added notebooks to playground
- updates
- removed verbatim HF secrets from all files
- updates
- [custom_op][vllm-plugin] update custom_op class to use op_registry (vllm-project#19164), Signed-off-by: Chendi.Xue <chendi.xue@intel.com>
- Export NaNs in logits to scheduler_stats if output is corrupted (vllm-project#18777), Signed-off-by: Vlad Mihailescu <vtmihailescu@gmail.com>
- [CPU][CI] Fallback sliding window to v0 and fix CPU pooling model tests (vllm-project#19901), Signed-off-by: jiang1.li <jiang1.li@intel.com>
- [Kernel] mark TorchSDPABackend swap_blocks NotImplementedError (vllm-project#19749)
1 parent 7771d1d commit 83debd5

File tree: 24 files changed (+1364, -11 lines)


NOTES.md

Lines changed: 41 additions & 0 deletions
# TL;DR

I log here my reading comprehension notes and tasks regarding the repository.
The `var` directory is listed in `.gitignore` so it does not mix with the repo code.
# `docker/Dockerfile.arm`

The command `ENV LD_PRELOAD="/usr/lib/aarch64-linux-gnu/libtcmalloc_minimal.so.4"`
sets the container environment variable `LD_PRELOAD` to the referenced shared library
(note this is an Ubuntu image created on an arm64 host).

(ChatGPT): `LD_PRELOAD` instructs the dynamic linker to load a shared library before any other library when running executables.
It allows you to override functions in system libraries or inject extra functionality without changing the application binary.

`tcmalloc_minimal` is the minimal version of Google's TCMalloc library (Thread-Caching Malloc),
an optimized memory allocator from the Google Performance Tools suite. It provides faster malloc/free
than the default system allocator (glibc malloc) and helps improve the performance of memory-intensive applications.
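A quick sanity check, as a sketch (it assumes the `vllm-openai:arm` image tag built later in these notes; any tag works), that the variable is set and the library actually exists inside the container:

```bash
# Print LD_PRELOAD and confirm the tcmalloc shared object is present.
docker run --rm vllm-openai:arm /bin/sh -c \
    'echo "LD_PRELOAD=$LD_PRELOAD" && ls -l "$LD_PRELOAD"'
```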
The vLLM documentation on docker installation contains a build command for x86 CPUs:

```bash
$ docker build -f docker/Dockerfile.cpu --tag vllm-cpu-env --target vllm-openai .
```

If we replace (as requested) `Dockerfile.cpu` by `Dockerfile.arm`, the build fails. This is because `--target vllm-openai`
refers to a stage in `Dockerfile.cpu` that is not contained in `Dockerfile.arm`.
See the [docker documentation](https://docs.docker.com/build/building/multi-stage/) on multi-stage builds (there
it is also explained what happens with dockerfiles that have multiple `FROM` commands).
Here we use the command `$ docker build -f docker/Dockerfile.arm --tag vllm-openai:arm .`.
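To see which stage names (i.e. valid `--target` values) each Dockerfile actually defines, you can list their `FROM ... AS <stage>` lines:

```bash
# Stage names in multi-stage Dockerfiles follow the AS keyword.
grep -inE '^FROM .+ AS ' docker/Dockerfile.cpu docker/Dockerfile.arm
```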
# Additional installation requirements

1. For the chat REPL of `transformers` you need to `pip install accelerate`.

2. To run the `generate` SDK it is recommended to `pip install bitsandbytes` (by Hugging Face).
   `bitsandbytes` has methods for quantizing LLMs when loading them to memory, which greatly improves performance.
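A minimal sketch of that quantized-loading path (the model id is only an example; `BitsAndBytesConfig` is the standard transformers entry point for bitsandbytes quantization, `device_map="auto"` is what pulls in `accelerate`, and 4-bit loading generally requires a CUDA-capable GPU):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_id = "gpt2"  # example model; any causal LM on the Hub works the same way

# Quantize weights to 4-bit while loading; needs bitsandbytes + accelerate.
bnb_config = BitsAndBytesConfig(load_in_4bit=True,
                                bnb_4bit_compute_dtype=torch.float16)

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id,
                                             quantization_config=bnb_config,
                                             device_map="auto")

inputs = tokenizer("Hello, world", return_tensors="pt").to(model.device)
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=8)[0]))
```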

tests/models/language/pooling/test_embedding.py

Lines changed: 6 additions & 1 deletion
```diff
@@ -1,5 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import os
+
 import pytest
 
 from vllm.config import PoolerConfig
@@ -33,7 +35,7 @@ def v1(run_with_both_engines):
     # To avoid this problem, for now we skip v0 since it will be
     # deprecated anyway.
     pytest.param("ssmits/Qwen2-7B-Instruct-embed-base",
-                 marks=[pytest.mark.skip_v0]),
+                 marks=[pytest.mark.skip_v0, pytest.mark.cpu_model]),
     # [Encoder-only]
     pytest.param("BAAI/bge-base-en-v1.5",
                  marks=[
@@ -58,6 +60,9 @@ def test_models(
     model,
     monkeypatch,
 ) -> None:
+    if model == "intfloat/e5-mistral-7b-instruct" and current_platform.is_cpu(
+    ) and os.environ.get("VLLM_USE_V1", "0") == "1":
+        pytest.skip("CPU V1 doesn't support sliding window")
 
     if model == "BAAI/bge-multilingual-gemma2" and current_platform.is_rocm():
         # ROCm Triton FA does not currently support sliding window attention
```

tests/models/registry.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -267,6 +267,7 @@ def check_available_online(
     # [Text-only]
     "BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5", v0_only=True),
     "Gemma2Model": _HfExamplesInfo("BAAI/bge-multilingual-gemma2", v0_only=True),  # noqa: E501
+    "GPT2ForSequenceClassification": _HfExamplesInfo("nie3e/sentiment-polish-gpt2-small"),  # noqa: E501
     "GritLM": _HfExamplesInfo("parasail-ai/GritLM-7B-vllm"),
     "GteModel": _HfExamplesInfo("Snowflake/snowflake-arctic-embed-m-v2.0",
                                 trust_remote_code=True),
```
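With the model registered, a sketch of how the new classification head could be exercised end to end (the `task="classify"` argument and `LLM.classify()` exist in recent vLLM releases; treat the exact output field names as an assumption):

```python
from vllm import LLM

# Load the newly registered GPT2ForSequenceClassification example model.
llm = LLM(model="nie3e/sentiment-polish-gpt2-small", task="classify")

# classify() runs the pooling/classification path instead of generation.
(result,) = llm.classify(["Ten film był świetny!"])  # Polish: "This movie was great!"
print(result.outputs.probs)  # per-class probabilities (assumed field name)
```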

tests/plugins/vllm_add_dummy_platform/setup.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -10,5 +10,7 @@
     entry_points={
         'vllm.platform_plugins': [
             "dummy_platform_plugin = vllm_add_dummy_platform:dummy_platform_plugin"  # noqa
-        ]
+        ],
+        "vllm.general_plugins":
+        ["dummy_custom_ops = vllm_add_dummy_platform:register_ops"],
     })
```
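For context, a rough sketch of how a `vllm.general_plugins` entry point gets picked up (this approximates what vLLM's `load_general_plugins()` does; the actual implementation may differ):

```python
from importlib.metadata import entry_points

# Find every installed package advertising a "vllm.general_plugins" entry
# point and call the function it names, e.g.
# "dummy_custom_ops = vllm_add_dummy_platform:register_ops".
for ep in entry_points(group="vllm.general_plugins"):
    register_fn = ep.load()
    register_fn()
```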

tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/__init__.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -6,3 +6,7 @@
 
 def dummy_platform_plugin() -> Optional[str]:
     return "vllm_add_dummy_platform.dummy_platform.DummyPlatform"
+
+
+def register_ops():
+    import vllm_add_dummy_platform.dummy_custom_ops  # noqa
```

tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_attention_backend.py

Lines changed: 3 additions & 2 deletions
```diff
@@ -1,10 +1,11 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-from vllm.attention.backends.flash_attn import FlashAttentionBackend
+from vllm.attention.backends.placeholder_attn import (
+    PlaceholderAttentionBackend)
 
 
-class DummyAttentionBackend(FlashAttentionBackend):
+class DummyAttentionBackend(PlaceholderAttentionBackend):
 
     @staticmethod
     def get_name() -> str:
```
tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_custom_ops.py

Lines changed: 20 additions & 0 deletions

```diff
@@ -0,0 +1,20 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import torch
+
+from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
+
+
+# Register CustomRotaryEmbedding to CustomOP.
+@RotaryEmbedding.register_oot
+class DummyRotaryEmbedding(RotaryEmbedding):
+    """Original rotary positional embedding."""
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.addition_config = True
+
+    def forward_oot(self, *args,
+                    **kwargs) -> tuple[torch.Tensor, torch.Tensor]:
+        return super().forward_oot(*args, **kwargs)
```
tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_platform.py

Lines changed: 20 additions & 3 deletions

```diff
@@ -1,12 +1,29 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from typing import TYPE_CHECKING
 
-from vllm.platforms.cuda import CudaPlatform
+from vllm.platforms.interface import Platform, PlatformEnum
 
+if TYPE_CHECKING:
+    from vllm.config import VllmConfig
+else:
+    VllmConfig = None
+from vllm import envs
 
-class DummyPlatform(CudaPlatform):
+
+class DummyPlatform(Platform):
+    _enum = PlatformEnum.OOT
     device_name = "DummyDevice"
+    device_type: str = "privateuseone"
+    dispatch_key: str = "PrivateUse1"
+
+    @classmethod
+    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
+        if envs.VLLM_USE_V1:
+            compilation_config = vllm_config.compilation_config
+            # Activate custom ops for v1.
+            compilation_config.custom_ops = ["all"]
 
     def get_attn_backend_cls(self, backend_name, head_size, dtype,
                              kv_cache_dtype, block_size, use_v1, use_mla):
-        return "vllm_add_dummy_platform.dummy_attention_backend.DummyAttentionBackend"  # noqa E501
+        return "vllm_add_dummy_platform.dummy_attention_backend.DummyAttentionBackend"  # noqa E501
```

tests/plugins_tests/test_platform_plugins.py

Lines changed: 14 additions & 0 deletions
```diff
@@ -5,6 +5,7 @@
 import torch
 
 from vllm.attention.selector import get_attn_backend
+from vllm.plugins import load_general_plugins
 from vllm.utils import STR_BACKEND_ENV_VAR, STR_INVALID_VAL
 
 
@@ -32,3 +33,16 @@ def test_oot_attention_backend(monkeypatch: pytest.MonkeyPatch):
         m.setenv(STR_BACKEND_ENV_VAR, STR_INVALID_VAL)
         backend = get_attn_backend(16, torch.float16, "auto", 16, False)
     assert backend.get_name() == "Dummy_Backend"
+
+
+def test_oot_custom_op(monkeypatch: pytest.MonkeyPatch):
+    # simulate workload by running an example
+    load_general_plugins()
+    from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
+    layer = RotaryEmbedding(16, 16, 16, 16, True, torch.float16)
+    assert layer.__class__.__name__ == "DummyRotaryEmbedding", (
+        f"Expected DummyRotaryEmbedding, got {layer.__class__.__name__}, "
+        "possibly because the custom op is not registered correctly.")
+    assert hasattr(layer, "addition_config"), (
+        "Expected DummyRotaryEmbedding to have an 'addition_config' attribute, "
+        "which is set by the custom op.")
```

tests/v1/worker/test_gpu_model_runner.py

Lines changed: 49 additions & 0 deletions
```diff
@@ -4,6 +4,7 @@
 import random
 
 import pytest
+import torch
 
 from vllm.attention import Attention
 from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
@@ -277,6 +278,54 @@ def test_update_states_request_resumed(model_runner):
     assert _is_req_state_block_table_match(model_runner, req_id)
 
 
+def test_get_nans_in_logits(model_runner):
+    req_ids = ("req_0", "req_1")
+
+    scheduler_output = _schedule_new_request(*req_ids)
+    model_runner._update_states(scheduler_output)
+
+    logits = torch.tensor([
+        [1.0, 2.0, 3.0],
+        [3.0, 2.0, 1.0],
+    ], device=DEVICE)
+    result = model_runner._get_nans_in_logits(logits)
+    assert result == {"req_0": 0, "req_1": 0}
+
+    logits = torch.tensor([
+        [1.0, float('nan'), 3.0],
+        [4.0, float('nan'), float('nan')],
+    ],
+                          device=DEVICE)
+    result = model_runner._get_nans_in_logits(logits)
+    assert result == {"req_0": 1, "req_1": 2}
+
+    logits = torch.tensor([
+        [1.0, 2.0, 3.0],
+        [4.0, float('nan'), float('nan')],
+    ],
+                          device=DEVICE)
+    result = model_runner._get_nans_in_logits(logits)
+    assert result == {"req_0": 0, "req_1": 2}
+
+    result = model_runner._get_nans_in_logits(logits=None)
+    assert result == {"req_0": 0, "req_1": 0}
+
+    logits = torch.tensor([
+        [1.0, float('nan'), 3.0],
+    ], device=DEVICE)
+    result = model_runner._get_nans_in_logits(logits)
+    assert result == {'req_0': 1, 'req_1': 0}
+
+    logits = torch.tensor([
+        [float('nan'), float('nan'), 2.0],
+        [1.0, 2.0, 3.0],
+        [float('nan'), 2.0, 3.0],
+    ],
+                          device=DEVICE)
+    result = model_runner._get_nans_in_logits(logits)
+    assert result == {'req_0': 2, 'req_1': 0}
+
+
 def test_update_states_no_changes(model_runner):
     req_id = "req_0"
```
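For intuition, a minimal sketch of the per-request NaN counting this test exercises (illustrative only; the name, signature, and shapes here are assumptions, not vLLM's actual `_get_nans_in_logits`):

```python
import torch


def count_nans_per_request(req_ids: tuple[str, ...],
                           logits: torch.Tensor | None) -> dict[str, int]:
    """Map each request id to the number of NaNs in its logits row."""
    if logits is None:
        return {req_id: 0 for req_id in req_ids}
    nan_counts = torch.isnan(logits).sum(dim=-1)  # one count per row
    return {
        # Requests without a logits row (fewer rows than requests) count as 0,
        # matching the single-row case asserted in the test above.
        req_id: int(nan_counts[i]) if i < logits.shape[0] else 0
        for i, req_id in enumerate(req_ids)
    }
```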
