[Bugfix] Fix TP inference for Flex attention backend #19657

Merged
merged 4 commits on Jun 16, 2025
36 changes: 35 additions & 1 deletion tests/v1/engine/test_engine_core.py
@@ -19,7 +19,7 @@
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.outputs import ModelRunnerOutput

-from ...utils import create_new_process_for_each_test
+from ...utils import create_new_process_for_each_test, multi_gpu_test

if not current_platform.is_cuda():
pytest.skip(reason="V1 currently only supported on CUDA.",
@@ -378,3 +378,37 @@ def shutdown(self):
# Odd steps schedules a new batch.
assert output is None
step += 1


@multi_gpu_test(num_gpus=2)
def test_engine_core_tp(monkeypatch: pytest.MonkeyPatch):
"""
Test engine can initialize worker in tp properly
"""

with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
"""Setup the EngineCore."""
engine_args = EngineArgs(
model=MODEL_NAME,
tensor_parallel_size=2,
# Reduce startup time.
enforce_eager=True,
)
vllm_config = engine_args.create_engine_config()
executor_class = Executor.get_class(vllm_config)

with set_default_torch_num_threads(1):
engine_core = EngineCore(vllm_config=vllm_config,
executor_class=executor_class,
log_stats=True)

def get_worker_cache_config_field(worker, key: str):
return getattr(worker.cache_config, key)

num_gpu_blocks = engine_core.collective_rpc(
get_worker_cache_config_field, args=("num_gpu_blocks", ))
num_cpu_blocks = engine_core.collective_rpc(
get_worker_cache_config_field, args=("num_cpu_blocks", ))
assert all(x is not None for x in num_gpu_blocks)
assert all(x is not None for x in num_cpu_blocks)
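
For reference, on a machine with two GPUs this new case should be runnable on its own with something like `pytest tests/v1/engine/test_engine_core.py::test_engine_core_tp`; the multi_gpu_test(num_gpus=2) decorator should skip it when fewer GPUs are available.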
8 changes: 7 additions & 1 deletion vllm/v1/attention/backends/flex_attention.py
@@ -13,6 +13,7 @@
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType,
is_quantized_kv_cache)
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
@@ -236,7 +237,12 @@ def final_mask_mod(

def build_block_mask(self) -> BlockMask:
assert self.mask_mod is not None
return create_block_mask_compiled(
# FIXME: With TP>1, create_block_mask_compiled will raise
# CUDA error: an illegal memory access was encountered
Comment on lines +240 to +241

Contributor

medium

The FIXME comment clearly explains the issue with create_block_mask_compiled when the tensor parallel world size is greater than 1. To ensure this is addressed in the future, consider creating a GitHub issue to track this underlying CUDA error if one doesn't exist already. This would help in eventually enabling the compiled version universally.

Collaborator Author

The full traceback of the illegal memory access error:

Log
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]   File "/kaggle/working/vllm/vllm/v1/attention/backends/flex_attention.py", line 262, in __post_init__
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]     self.block_mask = self.build_block_mask()
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]                       ^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]   File "/kaggle/working/vllm/vllm/v1/attention/backends/flex_attention.py", line 246, in build_block_mask
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]     return create_block_mask_fn(
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]   File "/kaggle/working/vllm/.venv/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 655, in _fn
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]     return fn(*args, **kwargs)
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]            ^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]   File "/kaggle/working/vllm/.venv/lib/python3.12/site-packages/torch/nn/attention/flex_attention.py", line 824, in create_block_mask
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]     def create_block_mask(
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]   File "/kaggle/working/vllm/.venv/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 838, in _fn
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]     return fn(*args, **kwargs)
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]            ^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]   File "/kaggle/working/vllm/.venv/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 1201, in forward
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]     return compiled_fn(full_args)
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]            ^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]   File "/kaggle/working/vllm/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 328, in runtime_wrapper
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]     all_outs = call_func_at_runtime_with_args(
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]   File "/kaggle/working/vllm/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/utils.py", line 126, in call_func_at_runtime_with_args
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]     out = normalize_as_list(f(args))
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]                             ^^^^^^^
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]   File "/kaggle/working/vllm/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 689, in inner_fn
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]     outs = compiled_fn(args)
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]            ^^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]   File "/kaggle/working/vllm/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 495, in wrapper
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]     return compiled_fn(runtime_args)
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]            ^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]   File "/kaggle/working/vllm/.venv/lib/python3.12/site-packages/torch/_inductor/output_code.py", line 460, in __call__
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]     return self.current_callable(inputs)
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]   File "/kaggle/working/vllm/.venv/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 1372, in run
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]     return compiled_fn(new_inputs)
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]            ^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]   File "/kaggle/working/vllm/.venv/lib/python3.12/site-packages/torch/_inductor/cudagraph_trees.py", line 387, in deferred_cudagraphify
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]     fn, out = cudagraphify(model, inputs, new_static_input_idxs, *args, **kwargs)
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]   File "/kaggle/working/vllm/.venv/lib/python3.12/site-packages/torch/_inductor/cudagraph_trees.py", line 448, in cudagraphify
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]     return manager.add_function(
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]   File "/kaggle/working/vllm/.venv/lib/python3.12/site-packages/torch/_inductor/cudagraph_trees.py", line 2308, in add_function
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]     return fn, fn(inputs)
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]                ^^^^^^^^^^
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]   File "/kaggle/working/vllm/.venv/lib/python3.12/site-packages/torch/_inductor/cudagraph_trees.py", line 1997, in run
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]     out = self._run(new_inputs, function_id)
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]   File "/kaggle/working/vllm/.venv/lib/python3.12/site-packages/torch/_inductor/cudagraph_trees.py", line 2104, in _run
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]     out = self.run_eager(new_inputs, function_id)
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]   File "/kaggle/working/vllm/.venv/lib/python3.12/site-packages/torch/_inductor/cudagraph_trees.py", line 2269, in run_eager
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]     return node.run(new_inputs)
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]            ^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]   File "/kaggle/working/vllm/.venv/lib/python3.12/site-packages/torch/_inductor/cudagraph_trees.py", line 668, in run
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]     out = self.wrapped_function.model(new_inputs)
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]   File "/root/.cache/vllm/torch_compile_cache/26b5568570/rank_0_0/inductor_cache/4o/c4osf7wcdszj5dy7kaxakhrrucni4ac5aiyysa63j3fmz37p6jxn.py", line 561, in call
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]     triton_per_fused__to_copy_sum_7.run(buf18, buf22, 5718, triton_per_fused__to_copy_sum_7_r0_numel, stream=stream0)
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]   File "/kaggle/working/vllm/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 909, in run
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]     self.autotune_to_one_config(*args, **kwargs)
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]   File "/kaggle/working/vllm/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 763, in autotune_to_one_config
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]     timings = self.benchmark_all_configs(*args, **kwargs)
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]   File "/kaggle/working/vllm/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 738, in benchmark_all_configs
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]     launcher: self.bench(launcher, *args, **kwargs)
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]   File "/kaggle/working/vllm/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 616, in bench
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]     return benchmarker.benchmark_gpu(kernel_call, rep=40)
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]   File "/kaggle/working/vllm/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/benchmarking.py", line 39, in wrapper
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]     return fn(self, *args, **kwargs)
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]            ^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]   File "/kaggle/working/vllm/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/benchmarking.py", line 243, in benchmark_gpu
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]     _callable()
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]   File "/kaggle/working/vllm/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 601, in kernel_call
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]     launcher(
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]   File "<string>", line 5, in launcher
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]   File "/kaggle/working/vllm/.venv/lib/python3.12/site-packages/triton/backends/nvidia/driver.py", line 444, in __call__
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]     self.launch(*args, **kwargs)
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527] RuntimeError: Triton Error [CUDA]: an illegal memory access was encountered

@drisspg Any idea about this error?

Collaborator

cc: @zou3519 for this torch.compile-related issue.

Contributor

Hey, so I actually just noticed this too; this was not the case until pretty recently. Going to create an issue + tracking for this.

create_block_mask_fn = (create_block_mask_compiled
if get_tensor_model_parallel_world_size() == 1
else create_block_mask)
return create_block_mask_fn(
self.mask_mod,
None,
None,
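
For context outside the fragmented diff view, a minimal standalone sketch of the pattern this hunk applies (illustrative names and signatures, not the exact vLLM code):

# Sketch: use the torch.compile'd create_block_mask only when TP == 1; with TP > 1
# the compiled variant currently triggers the illegal memory access, so fall back
# to the eager create_block_mask.
import torch
from torch.nn.attention.flex_attention import BlockMask, create_block_mask

create_block_mask_compiled = torch.compile(create_block_mask, fullgraph=True)

def build_block_mask(mask_mod, q_len: int, kv_len: int,
                     tp_world_size: int, device: str = "cuda") -> BlockMask:
    create_block_mask_fn = (create_block_mask_compiled
                            if tp_world_size == 1 else create_block_mask)
    # B=None and H=None broadcast the mask over batch and heads.
    return create_block_mask_fn(mask_mod, None, None, q_len, kv_len, device=device)
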
2 changes: 2 additions & 0 deletions vllm/v1/engine/core.py
@@ -84,6 +84,8 @@ def __init__(self,

vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks
self.collective_rpc("initialize_cache",
Collaborator

wondering why only TP + FlexAttention needs this?

Collaborator Author

Because FlexAttention needs num_gpu_blocks for its calculations, while the other attention backends don't.

Not sure if this is intended, but in V1 only the engine core's cache_config gets the updated num_gpu_blocks; workers living in different processes (the TP case) never see the update unless it is pushed to them via collective_rpc.

Therefore, in distributed inference the workers' num_gpu_blocks is still None, which caused the error in the PR description.
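
To illustrate the gap, here is a minimal self-contained sketch (toy classes, not vLLM's actual executor or worker) of the propagation this PR adds via collective_rpc:

from dataclasses import dataclass
from typing import Optional

@dataclass
class CacheConfig:
    num_gpu_blocks: Optional[int] = None   # stays None on TP workers without the RPC
    num_cpu_blocks: Optional[int] = None

class Worker:
    def __init__(self) -> None:
        self.cache_config = CacheConfig()

    def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks

class Executor:
    def __init__(self, workers: list) -> None:
        self.workers = workers

    def collective_rpc(self, method: str, args: tuple = ()) -> list:
        # In MultiprocExecutor this call crosses process boundaries; here it is
        # just a loop over in-process workers to show the effect on cache_config.
        return [getattr(w, method)(*args) for w in self.workers]

executor = Executor([Worker(), Worker()])              # e.g. TP=2
executor.collective_rpc("initialize_cache", args=(8192, 0))
assert all(w.cache_config.num_gpu_blocks == 8192 for w in executor.workers)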

Collaborator

I see. Could you check whether we need to add a condition so that this function is only called when TP > 1?

Collaborator Author

For single-process execution we use UniProcExecutor instead of MultiprocExecutor:

elif distributed_executor_backend == "mp":
from vllm.v1.executor.multiproc_executor import MultiprocExecutor
executor_class = MultiprocExecutor
elif distributed_executor_backend == "uni":
executor_class = UniProcExecutor

Given that UniProcExecutor also has collective_rpc implemented properly, it's safe to call collective_rpc there as well, especially since we only update cache_config here (which, in the uni-process case, the preceding lines have already done):

def collective_rpc(self,
method: Union[str, Callable],
timeout: Optional[float] = None,
args: Tuple = (),
kwargs: Optional[Dict] = None) -> List[Any]:
if kwargs is None:
kwargs = {}
answer = run_method(self.driver_worker, method, args, kwargs)
return [answer]

I've checked that TP=1 still works.

args=(num_gpu_blocks, num_cpu_blocks))

self.structured_output_manager = StructuredOutputManager(vllm_config)

5 changes: 5 additions & 0 deletions vllm/v1/worker/gpu_worker.py
@@ -112,6 +112,11 @@ def wake_up(self, tags: Optional[list[str]] = None) -> None:
buffer.data.copy_(self._sleep_saved_buffers[name].data)
self._sleep_saved_buffers = {}

def initialize_cache(self, num_gpu_blocks: int,
Collaborator

this sounds more like "setting_cache_size" instead of initialize_cache?

Collaborator Author

Hmmm, cache_config's num_gpu_blocks and num_cpu_blocks are updated in initialize_cache for the v0 worker, and it is a base-class method:

def initialize_cache(self, num_gpu_blocks: int,
num_cpu_blocks: int) -> None:
"""Initialize the KV cache with the given size in blocks.
"""
raise NotImplementedError

vllm/vllm/worker/worker.py, lines 312 to 325 in 3d330c4:

def initialize_cache(self, num_gpu_blocks: int,
num_cpu_blocks: int) -> None:
"""Allocate GPU and CPU KV cache with the specified number of blocks.
This also warms up the model, which may record CUDA graphs.
"""
raise_if_cache_size_invalid(
num_gpu_blocks, self.cache_config.block_size,
self.cache_config.is_attention_free,
self.model_config.max_model_len,
self.parallel_config.pipeline_parallel_size)
self.cache_config.num_gpu_blocks = num_gpu_blocks
self.cache_config.num_cpu_blocks = num_cpu_blocks

Although this method was not used by v1 before this PR, I think reusing the method shared with v0 keeps the worker implementation consistent.

num_cpu_blocks: int) -> None:
self.cache_config.num_gpu_blocks = num_gpu_blocks
self.cache_config.num_cpu_blocks = num_cpu_blocks

def init_device(self):
if self.device_config.device.type == "cuda":
# torch.distributed.all_reduce does not free the input tensor until
5 changes: 5 additions & 0 deletions vllm/v1/worker/tpu_worker.py
@@ -93,6 +93,11 @@ def __init__(
if self.model_config.seed is None:
self.model_config.seed = 0

def initialize_cache(self, num_gpu_blocks: int,
num_cpu_blocks: int) -> None:
self.cache_config.num_gpu_blocks = num_gpu_blocks
self.cache_config.num_cpu_blocks = num_cpu_blocks

def init_device(self):
os.environ["PJRT_DEVICE"] = "TPU"
# Note: Currently the XLA compiler wrongly uses 2D ring strategy on 1D