 import pytest
 import torch

+from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
 from vllm.utils import (FlexibleArgumentParser, StoreBoolean, bind_kv_cache,
                         deprecate_kwargs, get_open_port, memory_profiling,
                         merge_async_iterators, supports_kw)
@@ -323,11 +324,11 @@ def test_bind_kv_cache():
         torch.zeros((1, )),
         torch.zeros((1, )),
     ]
-    bind_kv_cache(ctx, kv_cache)
-    assert ctx['layers.0.self_attn'].kv_cache is kv_cache[0]
-    assert ctx['layers.1.self_attn'].kv_cache is kv_cache[1]
-    assert ctx['layers.2.self_attn'].kv_cache is kv_cache[2]
-    assert ctx['layers.3.self_attn'].kv_cache is kv_cache[3]
+    bind_kv_cache(ctx, [kv_cache])
+    assert ctx['layers.0.self_attn'].kv_cache[0] is kv_cache[0]
+    assert ctx['layers.1.self_attn'].kv_cache[0] is kv_cache[1]
+    assert ctx['layers.2.self_attn'].kv_cache[0] is kv_cache[2]
+    assert ctx['layers.3.self_attn'].kv_cache[0] is kv_cache[3]


 def test_bind_kv_cache_non_attention():
     from vllm.attention import Attention
@@ -341,9 +342,9 @@ def test_bind_kv_cache_non_attention():
         torch.zeros((1, )),
         torch.zeros((1, )),
     ]
-    bind_kv_cache(ctx, kv_cache)
-    assert ctx['model.layers.20.attn'].kv_cache is kv_cache[0]
-    assert ctx['model.layers.28.attn'].kv_cache is kv_cache[1]
+    bind_kv_cache(ctx, [kv_cache])
+    assert ctx['model.layers.20.attn'].kv_cache[0] is kv_cache[0]
+    assert ctx['model.layers.28.attn'].kv_cache[0] is kv_cache[1]


 def test_bind_kv_cache_encoder_decoder():
@@ -364,7 +365,24 @@ def test_bind_kv_cache_encoder_decoder():
     ]
     encoder_kv_cache = ctx['encoder.layers.0.self_attn.attn'].kv_cache

-    bind_kv_cache(ctx, kv_cache)
+    bind_kv_cache(ctx, [kv_cache])
     assert ctx['encoder.layers.0.self_attn.attn'].kv_cache is encoder_kv_cache
-    assert ctx['decoder.layers.0.encoder_attn.attn'].kv_cache is kv_cache[0]
-    assert ctx['decoder.layers.0.self_attn.attn'].kv_cache is kv_cache[0]
+    assert ctx['decoder.layers.0.encoder_attn.attn'].kv_cache[0] is kv_cache[0]
+    assert ctx['decoder.layers.0.self_attn.attn'].kv_cache[0] is kv_cache[0]
+
+
+def test_bind_kv_cache_pp():
+    cfg = VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=2))
+    with set_current_vllm_config(cfg):
+        from vllm.attention import Attention
+
+        ctx = {
+            'layers.0.self_attn': Attention(32, 128, 0.1),
+        }
+    kv_cache = [
+        [torch.zeros((1, ))],
+        [torch.zeros((1, ))]
+    ]
+    bind_kv_cache(ctx, kv_cache)
+    assert ctx['layers.0.self_attn'].kv_cache[0] is kv_cache[0][0]
+    assert ctx['layers.0.self_attn'].kv_cache[1] is kv_cache[1][0]
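The updated tests treat the second argument of bind_kv_cache as a list with one per-layer cache list per pipeline-parallel virtual engine, and each Attention module's kv_cache attribute as a list indexed by virtual engine. The sketch below is a minimal illustration consistent with the assertions above, not vLLM's actual implementation: the name bind_kv_cache_sketch and the regex-based layer-index extraction are assumptions, and the real vllm.utils.bind_kv_cache additionally leaves encoder-only attention layers untouched, as the encoder/decoder test checks.

import re
from typing import Dict, List

import torch


def bind_kv_cache_sketch(ctx: Dict[str, object],
                         kv_cache: List[List[torch.Tensor]]) -> None:
    """Hypothetical helper mirroring the behaviour asserted in the tests."""

    # Extract the layer index from names such as 'layers.3.self_attn'
    # or 'model.layers.28.attn'.
    def layer_of(name: str) -> int:
        return int(re.search(r"\.(\d+)\.", name).group(1))

    # Attention layers may be non-contiguous (e.g. hybrid Mamba/attention
    # stacks), so map the sorted layer indices onto consecutive cache slots.
    layers = sorted({layer_of(name) for name in ctx})
    slot = {layer: i for i, layer in enumerate(layers)}

    for name, module in ctx.items():
        # One cache tensor per pipeline-parallel virtual engine.
        module.kv_cache = [
            per_ve_cache[slot[layer_of(name)]] for per_ve_cache in kv_cache
        ]

With this reading, test_bind_kv_cache_pp simply passes two virtual-engine cache lists and checks that the same layer sees a distinct tensor per virtual engine, while the earlier tests wrap their single cache list as [kv_cache] and index kv_cache[0] on the module.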