
[V1][TPU] Fix TPU kv sharing tests #19155

Closed · wants to merge 7 commits
89 changes: 53 additions & 36 deletions tests/v1/tpu/worker/test_tpu_model_runner.py
@@ -32,9 +32,7 @@
pallas_attention_backend_patcher.start()


@pytest.fixture
def model_runner():
# Patchers have already been started at module level.
def get_vllm_config():
scheduler_config = SchedulerConfig(
max_num_seqs=10,
max_num_batched_tokens=512,
@@ -60,13 +58,24 @@ def model_runner():
cache_config=cache_config,
scheduler_config=scheduler_config,
)
return vllm_config


def get_model_runner(vllm_config):
device = "xla:0" # Mocking TPU device
with mock.patch("vllm.v1.worker.tpu_model_runner.torch"), \
mock.patch("vllm.v1.worker.tpu_model_runner.xm"), \
mock.patch("vllm.v1.worker.tpu_model_runner.xr"):
return TPUModelRunner(vllm_config, device)


@pytest.fixture
def model_runner():
# Patchers have already been started at module level.
vllm_config = get_vllm_config()
return get_model_runner(vllm_config)


@pytest.fixture(autouse=True, scope="session")
def cleanup_patches():
yield
@@ -370,12 +379,14 @@ def test_get_req_paddings():
assert _get_req_paddings(8, 36) == [8, 16, 32, 36]


@pytest.mark.skip(reason="Test is broken on TPU when it's added.")
def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order():
def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order(
model_runner):
layer_0 = "model.layers.0.self_attn.attn"
layer_1 = "model.layers.1.self_attn.attn"
error_msg = f"{layer_1} must come before the current layer"
with pytest.raises(ValueError, match=error_msg):
vllm_config = model_runner.vllm_config
with pytest.raises(ValueError, match=error_msg), \
set_current_vllm_config(vllm_config):
fwd_context = {
# initialization below will fail because target layer is invalid;
# the target layer needs to come before layer 1
@@ -399,13 +410,14 @@ def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order():
assert fwd_context is not None


@pytest.mark.skip(reason="Test is broken on TPU when it's added.")
def test_init_kv_cache_with_kv_sharing_target_layer_not_exist():
def test_init_kv_cache_with_kv_sharing_target_layer_not_exist(model_runner):
layer_0 = "model.layers.0.self_attn.attn"
layer_1 = "model.layers.1.self_attn.attn"
invalid_layer = "model.layers.0.cross_attn.attn"
error_msg = f"{invalid_layer} is not a valid Attention layer in the model"
with pytest.raises(ValueError, match=error_msg):
vllm_config = model_runner.vllm_config
with pytest.raises(ValueError, match=error_msg), \
set_current_vllm_config(vllm_config):
fwd_context = {
layer_0:
Attention(
@@ -428,12 +440,13 @@ def test_init_kv_cache_with_kv_sharing_target_layer_not_exist():
assert fwd_context is not None


@pytest.mark.skip(reason="Test is broken on TPU when it's added.")
def test_init_kv_cache_with_kv_sharing_target_same_as_current():
def test_init_kv_cache_with_kv_sharing_target_same_as_current(model_runner):
layer_0 = "model.layers.0.self_attn.attn"
layer_1 = "model.layers.1.self_attn.attn"
error_msg = f"{layer_1} cannot be the same as the current layer"
with pytest.raises(ValueError, match=error_msg):
vllm_config = model_runner.vllm_config
with pytest.raises(ValueError, match=error_msg), \
set_current_vllm_config(vllm_config):
fwd_context = {
# initialization below will fail because target layer is invalid;
# the target layer needs to come before layer 1
@@ -457,11 +470,10 @@ def test_init_kv_cache_with_kv_sharing_target_same_as_current():
assert fwd_context is not None


@pytest.mark.skip(reason="Test is broken on TPU when it's added.")
def test_init_kv_cache_without_kv_sharing(model_runner):
def test_init_kv_cache_without_kv_sharing():
layer_0 = "model.layers.0.self_attn.attn"
layer_1 = "model.layers.1.self_attn.attn"
vllm_config = model_runner.vllm_config
vllm_config = get_vllm_config()
with set_current_vllm_config(vllm_config):
fwd_context = {
layer_0:
@@ -482,33 +494,38 @@ def test_init_kv_cache_without_kv_sharing(model_runner):
# suppress var not used error
assert fwd_context is not None
# Set high context length to test max context length estimation
vllm_config.model_config.max_model_len = 3_000_000
vllm_config.model_config.max_model_len = 1_000_000
vllm_ctx = vllm_config.compilation_config.static_forward_context
model_runner = get_model_runner(vllm_config)
kv_cache_spec = model_runner.get_kv_cache_spec()
assert len(kv_cache_spec) == 2
assert len(model_runner.shared_kv_cache_layers) == 0

available_memory = 20 * GiB_bytes
# page size for layer 0's kv_cache_spec is 32KB
num_expected_blocks = 327680 # 20GB / 32KB / 2 (num layers)
# page size for each layer KV can be calculated as
# 2 (non-MLA) * 8 (num_heads) * 128 (head_dim)
# * 2 (bfloat16, kv_cache dtype) * 128 (block_size) = 512KB
num_expected_blocks = 20480 # 20GB / 512KB / 2 (num layers)
kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec,
available_memory)
assert kv_cache_config.num_blocks == num_expected_blocks
Collaborator:
Looks like kv_cache_config.tensors[layer_0].size should be kv_cache_config.kv_cache_tensors[0]; there are several similar cases like this.

Collaborator (Author):
Yeah, this was changed in #17996.

Collaborator:
Thanks!
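
For reference, a minimal, self-contained sketch of the access-pattern change discussed above. The KVCacheTensorStub class and its values are hypothetical stand-ins for illustration, not vLLM code; only the two access expressions mirror what this diff removes and adds:

from dataclasses import dataclass

# Hypothetical stand-in (not a vLLM class) to illustrate the shape change:
# pre-#17996 code indexed per-layer tensors by layer name, while the layout
# used in this diff stores a list of tensors, each recording the layers that
# share it in `shared_by`.
@dataclass
class KVCacheTensorStub:
    size: int
    shared_by: list[str]

layer_0 = "model.layers.0.self_attn.attn"
layer_1 = "model.layers.1.self_attn.attn"
available_memory = 20 * 1024**3

# Old-style access (removed in this diff): kv_cache_config.tensors[layer_0].size
old_style = {layer_0: KVCacheTensorStub(available_memory // 2, [layer_0]),
             layer_1: KVCacheTensorStub(available_memory // 2, [layer_1])}
assert old_style[layer_0].size == available_memory // 2

# New-style access (as asserted in the updated test):
# kv_cache_config.kv_cache_tensors[0].size
new_style = [KVCacheTensorStub(available_memory // 2, [layer_0]),
             KVCacheTensorStub(available_memory // 2, [layer_1])]
assert new_style[0].size == available_memory // 2
assert new_style[0].shared_by[0] == layer_0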

assert len(kv_cache_config.tensors) == 2
assert kv_cache_config.tensors[layer_0].size == available_memory // 2
assert kv_cache_config.tensors[layer_1].size == available_memory // 2
assert len(kv_cache_config.kv_cache_tensors) == 2
assert kv_cache_config.kv_cache_tensors[0].size == available_memory // 2
assert kv_cache_config.kv_cache_tensors[1].size == available_memory // 2

max_context_len =\
estimate_max_model_len(vllm_config, kv_cache_spec, 5 * GiB_bytes)
Collaborator:
May I ask why we are setting 5 GB here, instead of the available memory (20 GB)?

Collaborator (Author):
I'm just using an arbitrarily small memory budget (5 GB) so we can test the logic that constrains the maximum possible model length (3M in this case).

Collaborator:
Got it, thank you for the explanation, that makes sense.
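
To make the numbers concrete, here is a small worked example of the estimate asserted below, using only the quantities stated in this test (512 KiB page size, 128-token blocks, two KV caches). It mirrors the formula in the updated comment rather than calling estimate_max_model_len itself:

# Worked example of the max-context-length arithmetic used in the assertion.
GiB = 1024**3
page_size_bytes = 512 * 1024   # per-layer page size computed in this test
block_size = 128               # tokens per block
num_kv_caches = 2              # two attention layers, no KV sharing

bytes_per_token_per_cache = page_size_bytes // block_size   # 4 KiB per token
max_context_len = 5 * GiB // bytes_per_token_per_cache // num_kv_caches
assert max_context_len == 655360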

# max context len with KV sharing should be 2x as large as without
assert max_context_len == 1310720
# max_context_len = available_memory / (page_size / block_size) / num_caches
# max_context_len = 5GB / (512KB / 128) / 2 = 655360
assert max_context_len == 655360

# important: override tensor size to prevent large mem alloc during test
# this will only allocate 2 block worth of memory (2 * 32kb)
# this will only allocate 2 block worth of memory (2 * 512kb)
kv_cache_config.num_blocks = 1
for layer in kv_cache_config.tensors:
kv_cache_config.tensors[layer].size =\
kv_cache_spec[layer].page_size_bytes
for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
kv_cache_tensor.size = (
kv_cache_spec[kv_cache_tensor.shared_by[0]].page_size_bytes)

model_runner.initialize_kv_cache(kv_cache_config)

@@ -524,11 +541,10 @@ def test_init_kv_cache_without_kv_sharing(model_runner):
assert kv_cache_config.kv_cache_groups[0].layer_names[1] == layer_1


@pytest.mark.skip(reason="Test is broken on TPU when it's added.")
def test_init_kv_cache_with_kv_sharing_valid(model_runner):
def test_init_kv_cache_with_kv_sharing_valid():
layer_0 = "model.layers.0.self_attn.attn"
layer_1 = "model.layers.1.self_attn.attn"
vllm_config = model_runner.vllm_config
vllm_config = get_vllm_config()
with set_current_vllm_config(vllm_config):
fwd_context = {
layer_0:
@@ -552,33 +568,34 @@ def test_init_kv_cache_with_kv_sharing_valid(model_runner):
# Set high context length to test max context length estimation
vllm_config.model_config.max_model_len = 3_000_000
vllm_ctx = vllm_config.compilation_config.static_forward_context
model_runner = get_model_runner(vllm_config)
kv_cache_spec = model_runner.get_kv_cache_spec()
assert len(kv_cache_spec) == 1
assert layer_0 in kv_cache_spec
assert model_runner.shared_kv_cache_layers[layer_1] == layer_0

available_memory = 20 * GiB_bytes
# page size for layer 0's kv_cache_spec is 32KB
# page size for layer 0's kv_cache_spec is 512KB
# with KV sharing, we can allocate (available_mem//page_size//1) blocks
# which is twice as many as without KV sharing
num_expected_blocks = 655360 # 20GB / 32KB
num_expected_blocks = 2 * 20480 # 20GB / 512KB
kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec,
available_memory)
assert kv_cache_config.num_blocks == num_expected_blocks
assert len(kv_cache_config.tensors) == 1
assert len(kv_cache_config.kv_cache_tensors) == 1
# Each layer now has twice the available memory for KV cache
# compared to no KV sharing
assert kv_cache_config.tensors[layer_0].size == available_memory
assert kv_cache_config.kv_cache_tensors[0].size == available_memory

max_context_len =\
estimate_max_model_len(vllm_config, kv_cache_spec, 5 * GiB_bytes)
# max context len with KV sharing should be 2x as large as without
assert max_context_len == 2 * 1310720
assert max_context_len == (2 * 655360)

# important: override tensor size to prevent large mem alloc during test
# this will only allocate 1 block worth of memory (32kb)
# this will only allocate 1 block worth of memory (512kb)
kv_cache_config.num_blocks = 1
kv_cache_config.tensors[layer_0].size =\
kv_cache_config.kv_cache_tensors[0].size =\
kv_cache_spec[layer_0].page_size_bytes

model_runner.initialize_kv_cache(kv_cache_config)