
Commit 76712f8

fix
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
1 parent 2cb84f2 commit 76712f8

2 files changed: +30 additions, −13 deletions

tests/test_utils.py

Lines changed: 29 additions & 11 deletions
@@ -6,6 +6,7 @@
 import pytest
 import torch

+from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
 from vllm.utils import (FlexibleArgumentParser, StoreBoolean, bind_kv_cache,
                         deprecate_kwargs, get_open_port, memory_profiling,
                         merge_async_iterators, supports_kw)
@@ -323,11 +324,11 @@ def test_bind_kv_cache():
         torch.zeros((1, )),
         torch.zeros((1, )),
     ]
-    bind_kv_cache(ctx, kv_cache)
-    assert ctx['layers.0.self_attn'].kv_cache is kv_cache[0]
-    assert ctx['layers.1.self_attn'].kv_cache is kv_cache[1]
-    assert ctx['layers.2.self_attn'].kv_cache is kv_cache[2]
-    assert ctx['layers.3.self_attn'].kv_cache is kv_cache[3]
+    bind_kv_cache(ctx, [kv_cache])
+    assert ctx['layers.0.self_attn'].kv_cache[0] is kv_cache[0]
+    assert ctx['layers.1.self_attn'].kv_cache[0] is kv_cache[1]
+    assert ctx['layers.2.self_attn'].kv_cache[0] is kv_cache[2]
+    assert ctx['layers.3.self_attn'].kv_cache[0] is kv_cache[3]

 def test_bind_kv_cache_non_attention():
     from vllm.attention import Attention
@@ -341,9 +342,9 @@ def test_bind_kv_cache_non_attention():
         torch.zeros((1, )),
         torch.zeros((1, )),
     ]
-    bind_kv_cache(ctx, kv_cache)
-    assert ctx['model.layers.20.attn'].kv_cache is kv_cache[0]
-    assert ctx['model.layers.28.attn'].kv_cache is kv_cache[1]
+    bind_kv_cache(ctx, [kv_cache])
+    assert ctx['model.layers.20.attn'].kv_cache[0] is kv_cache[0]
+    assert ctx['model.layers.28.attn'].kv_cache[0] is kv_cache[1]


 def test_bind_kv_cache_encoder_decoder():
@@ -364,7 +365,24 @@ def test_bind_kv_cache_encoder_decoder():
     ]
     encoder_kv_cache = ctx['encoder.layers.0.self_attn.attn'].kv_cache

-    bind_kv_cache(ctx, kv_cache)
+    bind_kv_cache(ctx, [kv_cache])
     assert ctx['encoder.layers.0.self_attn.attn'].kv_cache is encoder_kv_cache
-    assert ctx['decoder.layers.0.encoder_attn.attn'].kv_cache is kv_cache[0]
-    assert ctx['decoder.layers.0.self_attn.attn'].kv_cache is kv_cache[0]
+    assert ctx['decoder.layers.0.encoder_attn.attn'].kv_cache[0] is kv_cache[0]
+    assert ctx['decoder.layers.0.self_attn.attn'].kv_cache[0] is kv_cache[0]
+
+
+def test_bind_kv_cache_pp():
+    cfg = VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=2))
+    with set_current_vllm_config(cfg):
+        from vllm.attention import Attention
+
+        ctx = {
+            'layers.0.self_attn': Attention(32, 128, 0.1),
+        }
+    kv_cache = [
+        [torch.zeros((1, ))],
+        [torch.zeros((1, ))]
+    ]
+    bind_kv_cache(ctx, kv_cache)
+    assert ctx['layers.0.self_attn'].kv_cache[0] is kv_cache[0][0]
+    assert ctx['layers.0.self_attn'].kv_cache[1] is kv_cache[1][0]
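For reference, the updated calling convention groups KV caches by pipeline-parallel virtual engine: bind_kv_cache now receives kv_caches[virtual_engine][layer] instead of a flat per-layer list, and each attention module's kv_cache attribute becomes a per-engine list. The following minimal sketch reproduces only the indexing the tests above assert; _Attn and bind_kv_cache_sketch are hypothetical stand-ins, not vLLM's real Attention class or helper, which resolves layer indices from module names.

import torch

class _Attn:
    # Hypothetical stand-in for vllm.attention.Attention; models only
    # the per-virtual-engine kv_cache list the tests above rely on.
    def __init__(self, num_virtual_engines: int):
        self.kv_cache = [torch.tensor([]) for _ in range(num_virtual_engines)]

def bind_kv_cache_sketch(ctx, kv_caches):
    # kv_caches[ve][layer_idx]: one tensor per virtual engine and layer.
    # Assumes sorted module names follow layer order; the real helper
    # extracts layer indices from the module names instead.
    layer_names = sorted(ctx)
    for layer_idx, name in enumerate(layer_names):
        for ve, ve_caches in enumerate(kv_caches):
            ctx[name].kv_cache[ve] = ve_caches[layer_idx]

ctx = {'layers.0.self_attn': _Attn(num_virtual_engines=2)}
kv_caches = [[torch.zeros((1, ))], [torch.zeros((1, ))]]
bind_kv_cache_sketch(ctx, kv_caches)
assert ctx['layers.0.self_attn'].kv_cache[0] is kv_caches[0][0]
assert ctx['layers.0.self_attn'].kv_cache[1] is kv_caches[1][0]

This also shows why the existing single-engine tests now wrap their cache as [kv_cache]: one virtual engine means one outer list entry.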

vllm/worker/hpu_worker.py

Lines changed: 1 addition & 2 deletions
@@ -208,8 +208,7 @@ def _init_cache_engine(self):
208208
assert self.cache_config.num_gpu_blocks is not None
209209
self.cache_engine = [
210210
HPUCacheEngine(self.cache_config, self.model_config,
211-
self.parallel_config, self.device_config,
212-
self.compilation_config)
211+
self.parallel_config, self.device_config)
213212
for _ in range(self.parallel_config.pipeline_parallel_size)
214213
]
215214
self.hpu_cache = [
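The fix drops the stray compilation_config argument so the call matches the four-parameter constructor; the worker still builds one cache engine per pipeline-parallel virtual engine. A rough sketch of that pattern, using stub classes rather than the real vLLM types (configs elided as None):

from dataclasses import dataclass

@dataclass
class _ParallelConfig:
    # Stub standing in for vllm.config.ParallelConfig.
    pipeline_parallel_size: int = 1

class _HPUCacheEngine:
    # Stub with the four-parameter signature the diff restores; the
    # real class allocates the HPU KV-cache blocks.
    def __init__(self, cache_config, model_config, parallel_config,
                 device_config):
        self.parallel_config = parallel_config

parallel_config = _ParallelConfig(pipeline_parallel_size=2)
# One cache engine per pipeline-parallel virtual engine, mirroring
# _init_cache_engine above.
cache_engine = [
    _HPUCacheEngine(None, None, parallel_config, None)
    for _ in range(parallel_config.pipeline_parallel_size)
]
assert len(cache_engine) == 2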
