
Commit ddb08da

Isotr0py authored and minpeter committed
[Bugfix] Fix TP inference for Flex attention backend (vllm-project#19657)
Signed-off-by: Isotr0py <2037008807@qq.com>
Signed-off-by: minpeter <kali2005611@gmail.com>
1 parent 227f607 commit ddb08da

File tree (5 files changed: +54 −2 lines changed)

tests/v1/engine/test_engine_core.py
vllm/v1/attention/backends/flex_attention.py
vllm/v1/engine/core.py
vllm/v1/worker/gpu_worker.py
vllm/v1/worker/tpu_worker.py


tests/v1/engine/test_engine_core.py

Lines changed: 35 additions & 1 deletion
@@ -19,7 +19,7 @@
 from vllm.v1.kv_cache_interface import KVCacheConfig
 from vllm.v1.outputs import ModelRunnerOutput

-from ...utils import create_new_process_for_each_test
+from ...utils import create_new_process_for_each_test, multi_gpu_test

 if not current_platform.is_cuda():
     pytest.skip(reason="V1 currently only supported on CUDA.",
@@ -378,3 +378,37 @@ def shutdown(self):
         # Odd steps schedules a new batch.
         assert output is None
         step += 1
+
+
+@multi_gpu_test(num_gpus=2)
+def test_engine_core_tp(monkeypatch: pytest.MonkeyPatch):
+    """
+    Test engine can initialize worker in tp properly
+    """
+
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "1")
+        """Setup the EngineCore."""
+        engine_args = EngineArgs(
+            model=MODEL_NAME,
+            tensor_parallel_size=2,
+            # Reduce startup time.
+            enforce_eager=True,
+        )
+        vllm_config = engine_args.create_engine_config()
+        executor_class = Executor.get_class(vllm_config)
+
+        with set_default_torch_num_threads(1):
+            engine_core = EngineCore(vllm_config=vllm_config,
+                                     executor_class=executor_class,
+                                     log_stats=True)
+
+        def get_worker_cache_config_field(worker, key: str):
+            return getattr(worker.cache_config, key)
+
+        num_gpu_blocks = engine_core.collective_rpc(
+            get_worker_cache_config_field, args=("num_gpu_blocks", ))
+        num_cpu_blocks = engine_core.collective_rpc(
+            get_worker_cache_config_field, args=("num_cpu_blocks", ))
+        assert all(x is not None for x in num_gpu_blocks)
+        assert all(x is not None for x in num_cpu_blocks)

vllm/v1/attention/backends/flex_attention.py

Lines changed: 7 additions & 1 deletion
@@ -13,6 +13,7 @@
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType,
                                               is_quantized_kv_cache)
+from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
@@ -236,7 +237,12 @@ def final_mask_mod(

     def build_block_mask(self) -> BlockMask:
         assert self.mask_mod is not None
-        return create_block_mask_compiled(
+        # FIXME: With TP>1, create_block_mask_compiled will raise
+        # CUDA error: an illegal memory access was encountered
+        create_block_mask_fn = (create_block_mask_compiled
+                                if get_tensor_model_parallel_world_size() == 1
+                                else create_block_mask)
+        return create_block_mask_fn(
             self.mask_mod,
             None,
             None,

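For context, a standalone sketch (not vLLM code) of the pattern this hunk introduces in build_block_mask: compile PyTorch's create_block_mask only when running on a single tensor-parallel rank, and fall back to the eager builder otherwise, since the compiled path hits an illegal memory access under TP > 1. The causal mask_mod, the sequence lengths, and the explicit tp_world_size argument below are illustrative placeholders.

import torch
from torch.nn.attention.flex_attention import create_block_mask

# Compiled builder, analogous to the module-level helper used in the backend.
create_block_mask_compiled = torch.compile(create_block_mask)

def causal_mask_mod(b, h, q_idx, kv_idx):
    # Standard causal masking: a query attends only to itself and earlier keys.
    return q_idx >= kv_idx

def build_block_mask(tp_world_size: int, q_len: int, kv_len: int):
    # Avoid the compiled builder when TP > 1 (illegal memory access observed).
    build_fn = (create_block_mask_compiled
                if tp_world_size == 1 else create_block_mask)
    return build_fn(causal_mask_mod, None, None, q_len, kv_len, device="cuda")

Passing None for the batch and head dimensions lets the mask broadcast across both, matching the call shape shown in the diff above.
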
vllm/v1/engine/core.py

Lines changed: 2 additions & 0 deletions
@@ -84,6 +84,8 @@ def __init__(self,

         vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
         vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks
+        self.collective_rpc("initialize_cache",
+                            args=(num_gpu_blocks, num_cpu_blocks))

         self.structured_output_manager = StructuredOutputManager(vllm_config)


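Taken together, the core.py and worker changes make the engine push the determined KV-cache block counts down to every worker, which is what the new TP test checks. Below is a minimal, self-contained sketch (simplified stand-ins, not vLLM's actual Executor/Worker classes) of that fan-out pattern.

from dataclasses import dataclass
from typing import Optional

@dataclass
class CacheConfig:
    num_gpu_blocks: Optional[int] = None
    num_cpu_blocks: Optional[int] = None

class Worker:
    """Stand-in for a per-rank worker that owns a cache config."""

    def __init__(self) -> None:
        self.cache_config = CacheConfig()

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        # Mirror of the new worker hook: record the block counts locally.
        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks

class Engine:
    """Stand-in for the engine core driving several TP workers."""

    def __init__(self, workers: list) -> None:
        self.workers = workers

    def collective_rpc(self, method: str, args: tuple = ()) -> list:
        # Invoke the named method on every worker and gather the results.
        return [getattr(worker, method)(*args) for worker in self.workers]

# The engine determines the block counts once, then broadcasts them, so every
# worker's cache_config ends up with non-None values (as the TP test asserts).
engine = Engine([Worker(), Worker()])
engine.collective_rpc("initialize_cache", args=(8192, 1024))
assert all(w.cache_config.num_gpu_blocks == 8192 for w in engine.workers)
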
vllm/v1/worker/gpu_worker.py

Lines changed: 5 additions & 0 deletions
@@ -112,6 +112,11 @@ def wake_up(self, tags: Optional[list[str]] = None) -> None:
                     buffer.data.copy_(self._sleep_saved_buffers[name].data)
             self._sleep_saved_buffers = {}

+    def initialize_cache(self, num_gpu_blocks: int,
+                         num_cpu_blocks: int) -> None:
+        self.cache_config.num_gpu_blocks = num_gpu_blocks
+        self.cache_config.num_cpu_blocks = num_cpu_blocks
+
     def init_device(self):
         if self.device_config.device.type == "cuda":
             # torch.distributed.all_reduce does not free the input tensor until

vllm/v1/worker/tpu_worker.py

Lines changed: 5 additions & 0 deletions
@@ -93,6 +93,11 @@ def __init__(
         if self.model_config.seed is None:
             self.model_config.seed = 0

+    def initialize_cache(self, num_gpu_blocks: int,
+                         num_cpu_blocks: int) -> None:
+        self.cache_config.num_gpu_blocks = num_gpu_blocks
+        self.cache_config.num_cpu_blocks = num_cpu_blocks
+
     def init_device(self):
         os.environ["PJRT_DEVICE"] = "TPU"
         # Note: Currently the XLA compiler wrongly uses 2D ring strategy on 1D
