[TPU]Fix KV cache sharing tests #19371
Changes from all commits: 410c452, dfd8276, 50deb49, a7b2c32, 65fc127, 9107b51, ef53a49, 6b20e54
@@ -1,6 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import unittest.mock as mock

 import pytest

@@ -17,24 +16,8 @@
     TPUModelRunner, _get_padded_num_reqs_with_upper_limit,
     _get_padded_token_len, _get_req_paddings, _get_token_paddings)

-# Mock torch_xla module since it may not be available in the test environments
-torch_xla_patcher = mock.patch.dict(
-    "sys.modules", {
-        "torch_xla": mock.MagicMock(),
-        "torch_xla.core.xla_model": mock.MagicMock(),
-        "torch_xla.runtime": mock.MagicMock(),
-    })
-torch_xla_patcher.start()
-
-# Mock the PallasAttentionBackend
-pallas_attention_backend_patcher = mock.patch(
-    "vllm.v1.worker.tpu_model_runner.PallasAttentionBackend", )
-pallas_attention_backend_patcher.start()
-
-
-@pytest.fixture
-def model_runner():
-    # Patchers have already been started at module level.
+def get_vllm_config():
     scheduler_config = SchedulerConfig(
         max_num_seqs=10,
         max_num_batched_tokens=512,
@@ -60,18 +43,19 @@ def model_runner():
         cache_config=cache_config,
         scheduler_config=scheduler_config,
     )
+    return vllm_config
+
+
+def get_model_runner(vllm_config):
     device = "xla:0"  # Mocking TPU device
-    with mock.patch("vllm.v1.worker.tpu_model_runner.torch"), \
-        mock.patch("vllm.v1.worker.tpu_model_runner.xm"), \
-        mock.patch("vllm.v1.worker.tpu_model_runner.xr"):
-        return TPUModelRunner(vllm_config, device)
+    return TPUModelRunner(vllm_config, device)


-@pytest.fixture(autouse=True, scope="session")
-def cleanup_patches():
-    yield
-    torch_xla_patcher.stop()
-    pallas_attention_backend_patcher.stop()
+@pytest.fixture
+def model_runner():
+    # Patchers have already been started at module level.
+    vllm_config = get_vllm_config()
+    return get_model_runner(vllm_config)


 def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
@@ -370,12 +354,14 @@ def test_get_req_paddings():
     assert _get_req_paddings(8, 36) == [8, 16, 32, 36]


-@pytest.mark.skip(reason="Test is broken on TPU when it's added.")
-def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order():
+def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order(
+        model_runner):
     layer_0 = "model.layers.0.self_attn.attn"
     layer_1 = "model.layers.1.self_attn.attn"
     error_msg = f"{layer_1} must come before the current layer"
-    with pytest.raises(ValueError, match=error_msg):
+    vllm_config = model_runner.vllm_config
+    with pytest.raises(ValueError, match=error_msg), \
+            set_current_vllm_config(vllm_config):
         fwd_context = {
             # initialization below will fail because target layer is invalid;
             # the target layer needs to come before layer 1
@@ -399,13 +385,14 @@ def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order():
     assert fwd_context is not None


-@pytest.mark.skip(reason="Test is broken on TPU when it's added.")
-def test_init_kv_cache_with_kv_sharing_target_layer_not_exist():
+def test_init_kv_cache_with_kv_sharing_target_layer_not_exist(model_runner):
     layer_0 = "model.layers.0.self_attn.attn"
     layer_1 = "model.layers.1.self_attn.attn"
     invalid_layer = "model.layers.0.cross_attn.attn"
     error_msg = f"{invalid_layer} is not a valid Attention layer in the model"
-    with pytest.raises(ValueError, match=error_msg):
+    vllm_config = model_runner.vllm_config
+    with pytest.raises(ValueError, match=error_msg), \
+            set_current_vllm_config(vllm_config):
         fwd_context = {
             layer_0:
             Attention(
@@ -428,12 +415,13 @@ def test_init_kv_cache_with_kv_sharing_target_layer_not_exist():
     assert fwd_context is not None


-@pytest.mark.skip(reason="Test is broken on TPU when it's added.")
-def test_init_kv_cache_with_kv_sharing_target_same_as_current():
+def test_init_kv_cache_with_kv_sharing_target_same_as_current(model_runner):
     layer_0 = "model.layers.0.self_attn.attn"
     layer_1 = "model.layers.1.self_attn.attn"
     error_msg = f"{layer_1} cannot be the same as the current layer"
-    with pytest.raises(ValueError, match=error_msg):
+    vllm_config = model_runner.vllm_config
+    with pytest.raises(ValueError, match=error_msg), \
+            set_current_vllm_config(vllm_config):
         fwd_context = {
             # initialization below will fail because target layer is invalid;
             # the target layer needs to come before layer 1
@@ -457,11 +445,10 @@ def test_init_kv_cache_with_kv_sharing_target_same_as_current():
     assert fwd_context is not None


-@pytest.mark.skip(reason="Test is broken on TPU when it's added.")
-def test_init_kv_cache_without_kv_sharing(model_runner):
+def test_init_kv_cache_without_kv_sharing():
     layer_0 = "model.layers.0.self_attn.attn"
     layer_1 = "model.layers.1.self_attn.attn"
-    vllm_config = model_runner.vllm_config
+    vllm_config = get_vllm_config()
     with set_current_vllm_config(vllm_config):
         fwd_context = {
             layer_0:
@@ -482,33 +469,38 @@ def test_init_kv_cache_without_kv_sharing(model_runner):
     # suppress var not used error
     assert fwd_context is not None
     # Set high context length to test max context length estimation
-    vllm_config.model_config.max_model_len = 3_000_000
+    vllm_config.model_config.max_model_len = 1_000_000
     vllm_ctx = vllm_config.compilation_config.static_forward_context
+    model_runner = get_model_runner(vllm_config)
     kv_cache_spec = model_runner.get_kv_cache_spec()
     assert len(kv_cache_spec) == 2
     assert len(model_runner.shared_kv_cache_layers) == 0

     available_memory = 20 * GiB_bytes
-    # page size for layer 0's kv_cache_spec is 32KB
-    num_expected_blocks = 327680  # 20GB / 32KB / 2 (num layers)
+    # page size for each layer KV can be calculated as
+    # 2 (non-MLA) * 8 (num_heads) * 128 (head_dim)
+    # * 2 (bfloat16, kv_cache dtype) * 128 (block_size) = 512KB
+    num_expected_blocks = 20480  # 20GB / 512KB / 2 (num layers)
     kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec,
                                           available_memory)
     assert kv_cache_config.num_blocks == num_expected_blocks
-    assert len(kv_cache_config.tensors) == 2
-    assert kv_cache_config.tensors[layer_0].size == available_memory // 2
-    assert kv_cache_config.tensors[layer_1].size == available_memory // 2
+    assert len(kv_cache_config.kv_cache_tensors) == 2
+    assert kv_cache_config.kv_cache_tensors[0].size == available_memory // 2
+    assert kv_cache_config.kv_cache_tensors[1].size == available_memory // 2

     max_context_len =\
         estimate_max_model_len(vllm_config, kv_cache_spec, 5 * GiB_bytes)
-    # max context len with KV sharing should be 2x as large as without
-    assert max_context_len == 1310720
+    # max_context_len = available_memory / (page_size / block_size) / num_caches
+    # max_context_len = 5GB / (512KB / 128) / 2 = 655360
+    assert max_context_len == 655360

     # important: override tensor size to prevent large mem alloc during test
-    # this will only allocate 2 block worth of memory (2 * 32kb)
+    # this will only allocate 2 block worth of memory (2 * 512kb)
     kv_cache_config.num_blocks = 1
-    for layer in kv_cache_config.tensors:
-        kv_cache_config.tensors[layer].size =\
-            kv_cache_spec[layer].page_size_bytes
+    for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
+        kv_cache_tensor.size = (
+            kv_cache_spec[kv_cache_tensor.shared_by[0]].page_size_bytes)
Review comment on lines +501 to +503: The change from iterating …

     model_runner.initialize_kv_cache(kv_cache_config)
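For readers following the inline review comment above: the size-override loop in the new test no longer indexes a per-layer `tensors` dict but walks the `kv_cache_tensors` list and resolves the page size through `shared_by[0]`. Below is a minimal, self-contained sketch of that pattern; `FakeKVCacheSpec` and `FakeKVCacheTensor` are hypothetical stand-ins used only for illustration, not vLLM's real `KVCacheSpec`/`KVCacheTensor` types.

```python
from dataclasses import dataclass, field


@dataclass
class FakeKVCacheSpec:
    # bytes needed for one block (page) of this layer's KV cache
    page_size_bytes: int


@dataclass
class FakeKVCacheTensor:
    # total bytes reserved for this KV cache buffer
    size: int
    # names of the attention layers that read/write this buffer
    shared_by: list = field(default_factory=list)


# Two layers, each with the 512 KiB page size described in the diff comments.
kv_cache_spec = {
    "model.layers.0.self_attn.attn": FakeKVCacheSpec(page_size_bytes=512 * 1024),
    "model.layers.1.self_attn.attn": FakeKVCacheSpec(page_size_bytes=512 * 1024),
}
kv_cache_tensors = [
    FakeKVCacheTensor(size=10 * 2**30, shared_by=["model.layers.0.self_attn.attn"]),
    FakeKVCacheTensor(size=10 * 2**30, shared_by=["model.layers.1.self_attn.attn"]),
]

# Same shape as the loop added in the diff: shrink every tensor to a single
# page so the test never allocates gigabytes of KV cache during setup.
for kv_cache_tensor in kv_cache_tensors:
    kv_cache_tensor.size = (
        kv_cache_spec[kv_cache_tensor.shared_by[0]].page_size_bytes)

assert all(t.size == 512 * 1024 for t in kv_cache_tensors)
```

Because each tensor carries a `shared_by` list rather than being keyed by a single layer name, the test looks up the page size via the first layer listed in `shared_by` instead of assuming one tensor per layer.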
@@ -524,11 +516,10 @@ def test_init_kv_cache_without_kv_sharing(model_runner):
     assert kv_cache_config.kv_cache_groups[0].layer_names[1] == layer_1


-@pytest.mark.skip(reason="Test is broken on TPU when it's added.")
-def test_init_kv_cache_with_kv_sharing_valid(model_runner):
+def test_init_kv_cache_with_kv_sharing_valid():
     layer_0 = "model.layers.0.self_attn.attn"
     layer_1 = "model.layers.1.self_attn.attn"
-    vllm_config = model_runner.vllm_config
+    vllm_config = get_vllm_config()
     with set_current_vllm_config(vllm_config):
         fwd_context = {
             layer_0:
@@ -552,33 +543,34 @@ def test_init_kv_cache_with_kv_sharing_valid(model_runner):
     # Set high context length to test max context length estimation
     vllm_config.model_config.max_model_len = 3_000_000
     vllm_ctx = vllm_config.compilation_config.static_forward_context
+    model_runner = get_model_runner(vllm_config)
     kv_cache_spec = model_runner.get_kv_cache_spec()
     assert len(kv_cache_spec) == 1
     assert layer_0 in kv_cache_spec
     assert model_runner.shared_kv_cache_layers[layer_1] == layer_0

     available_memory = 20 * GiB_bytes
-    # page size for layer 0's kv_cache_spec is 32KB
+    # page size for layer 0's kv_cache_spec is 512KB
     # with KV sharing, we can allocate (available_mem//page_size//1) blocks
     # which is twice as many as without KV sharing
-    num_expected_blocks = 655360  # 20GB / 32KB
+    num_expected_blocks = 2 * 20480  # 20GB / 512KB
     kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec,
                                           available_memory)
     assert kv_cache_config.num_blocks == num_expected_blocks
-    assert len(kv_cache_config.tensors) == 1
+    assert len(kv_cache_config.kv_cache_tensors) == 1
     # Each layer now has twice the available memory for KV cache
     # compared to no KV sharing
-    assert kv_cache_config.tensors[layer_0].size == available_memory
+    assert kv_cache_config.kv_cache_tensors[0].size == available_memory

     max_context_len =\
         estimate_max_model_len(vllm_config, kv_cache_spec, 5 * GiB_bytes)
     # max context len with KV sharing should be 2x as large as without
-    assert max_context_len == 2 * 1310720
+    assert max_context_len == (2 * 655360)

     # important: override tensor size to prevent large mem alloc during test
-    # this will only allocate 1 block worth of memory (32kb)
+    # this will only allocate 1 block worth of memory (512kb)
     kv_cache_config.num_blocks = 1
-    kv_cache_config.tensors[layer_0].size =\
+    kv_cache_config.kv_cache_tensors[0].size =\
         kv_cache_spec[layer_0].page_size_bytes

     model_runner.initialize_kv_cache(kv_cache_config)
Review comment: The updated comment explaining the page size calculation is much clearer and more detailed. This is a great improvement for understanding the test's assumptions!
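The arithmetic behind those updated comments can be checked in a few lines. This is only a back-of-the-envelope sketch of the numbers asserted in the tests (512 KiB page size, 20480 blocks, 655360 max context length), assuming the values the diff comments cite: 8 KV heads, head_dim 128, bfloat16 (2 bytes), block_size 128.

```python
GiB = 1 << 30
KiB = 1 << 10

# Page size per layer, as spelled out in the new comment:
# 2 (K and V, non-MLA) * num_kv_heads * head_dim * dtype_bytes * block_size
num_kv_heads, head_dim, dtype_bytes, block_size = 8, 128, 2, 128
page_size = 2 * num_kv_heads * head_dim * dtype_bytes * block_size
assert page_size == 512 * KiB

# Without KV sharing: 20 GiB of cache memory split across 2 layers.
available_memory = 20 * GiB
num_layers = 2
assert available_memory // page_size // num_layers == 20480

# Max context length estimate: a 5 GiB budget, page_size / block_size bytes
# per token per cache, divided over the 2 independent caches.
budget = 5 * GiB
assert budget // (page_size // block_size) // num_layers == 655360

# With KV sharing, layer 1 reuses layer 0's cache, so there is a single
# cache and both numbers double.
assert available_memory // page_size == 2 * 20480
assert budget // (page_size // block_size) == 2 * 655360
```

These figures match the expected values hard-coded in `test_init_kv_cache_without_kv_sharing` and `test_init_kv_cache_with_kv_sharing_valid` above.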