[CPU] V1 support for the CPU backend #16441
Changes from all commits: 9bfeaa6, d9c89a8, 827b074, bd1e78a, 21bcea9, 98e44da, 660e340, f106d9a, 5b2591f, 1b25629, df61ca2, 413ef08, f44b619, f7de05c, cce8031, cceb5f0, cfd2ce4, deb2f79, 84e48ce
File: vllm/attention/backends/torch_sdpa.py
```diff
@@ -86,10 +86,13 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
     # For chunked prefill only
     max_query_len: Optional[int] = None
     max_kv_len: Optional[int] = None
-    query_start_loc: Optional[torch.Tensor] = None
+    prefill_query_start_loc: Optional[torch.Tensor] = None
     kv_start_loc: Optional[torch.Tensor] = None
     prefill_block_tables: Optional[torch.Tensor] = None
 
+    # For V1 logits index only
+    query_start_loc: Optional[torch.Tensor] = None
+
     # Begin encoder attn & enc/dec cross-attn fields...
     # Encoder sequence lengths representation
     encoder_seq_lens: Optional[List[int]] = None
```

Review thread on the `query_start_loc` rename:

Comment: why does this file require changes?

Reply: It is a naming conflict. The V1 model runner uses `query_start_loc`, so the existing chunked-prefill field is renamed to `prefill_query_start_loc` and `query_start_loc` is kept for the V1 logits index.

Reply: I would try to avoid the renaming in torch sdpa rather than making the global change.
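To make the naming conflict concrete, here is a minimal sketch of the resulting field layout and of how a V1-style runner can derive logits indices from a cumulative `query_start_loc`. The dataclass, its name, and the example values are illustrative only, not vLLM's actual metadata builder:

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class SketchSDPAMetadata:
    # Chunked-prefill bookkeeping, renamed in the diff above.
    prefill_query_start_loc: Optional[torch.Tensor] = None
    kv_start_loc: Optional[torch.Tensor] = None
    # Kept under the name the V1 model runner expects, for logits indexing.
    query_start_loc: Optional[torch.Tensor] = None


# Example: 3 requests contributing 4, 1 and 2 tokens to this batch.
meta = SketchSDPAMetadata(
    query_start_loc=torch.tensor([0, 4, 5, 7], dtype=torch.int32))

# The last token of each request sits at query_start_loc[1:] - 1, which is
# the kind of per-request index a runner needs when gathering logits.
logits_indices = meta.query_start_loc[1:] - 1  # tensor([3, 4, 6])
```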
```diff
@@ -374,7 +377,7 @@ def build(self, seq_lens: List[int], query_lens: List[int],
             seq_lens_tensor=seq_lens_tensor,
             max_query_len=max_query_len,
             max_kv_len=max_kv_len,
-            query_start_loc=query_start_loc,
+            prefill_query_start_loc=query_start_loc,
             kv_start_loc=kv_start_loc,
             max_decode_seq_len=input_data.max_decode_seq_len,
             num_prefills=input_data.num_prefills,
```
```diff
@@ -466,6 +469,11 @@ def forward(
         Returns:
             shape = [num_tokens, num_heads * head_size]
         """
+
+        # For warming-up
+        if attn_metadata is None:
+            return query
+
         attn_type = self.attn_type
         if (attn_type == AttentionType.ENCODER
                 and (not attn_metadata.is_all_encoder_attn_metadata_set)):
```
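The new guard makes the attention layer a pass-through during warm-up, when the runner calls the model with no attention metadata. A standalone sketch of that contract (not the real attention class; shapes are illustrative):

```python
from typing import Optional

import torch


def sdpa_forward_sketch(query: torch.Tensor,
                        attn_metadata: Optional[object]) -> torch.Tensor:
    # Warm-up/profiling passes attn_metadata=None: skip attention entirely and
    # return the query unchanged so downstream shapes still propagate.
    if attn_metadata is None:
        return query
    raise NotImplementedError("real SDPA / paged-attention path elided")


warmup_out = sdpa_forward_sketch(torch.zeros(4, 8 * 64), attn_metadata=None)
assert warmup_out.shape == (4, 8 * 64)  # [num_tokens, num_heads * head_size]
```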
```diff
@@ -533,8 +541,8 @@ def forward(
 
         output = torch.empty_like(query)
         if prefill_meta := attn_metadata.prefill_metadata:
-            assert attn_metadata.seq_lens is not None
             if not prefill_meta.prefill_metadata.chunked_prefill:  # type: ignore
+                assert attn_metadata.seq_lens is not None
                 self._run_sdpa_forward(output,
                                        query,
                                        key,
```
```diff
@@ -551,7 +559,7 @@ def forward(
                     query[:prefill_meta.num_prefill_tokens, :, :],
                     key_cache,
                     value_cache,
-                    prefill_meta.query_start_loc,
+                    prefill_meta.prefill_query_start_loc,
                     prefill_meta.kv_start_loc,
                     prefill_meta.max_query_len,
                     prefill_meta.max_kv_len,
```
File: vllm/platforms/cpu.py
```diff
@@ -56,7 +56,10 @@ def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
             logger.info("Using CPU MLA backend.")
             return "vllm.attention.backends.cpu_mla.CPUMLABackend"
         logger.info("Using Torch SDPA backend.")
-        return "vllm.attention.backends.torch_sdpa.TorchSDPABackend"
+        if use_v1:
+            return "vllm.v1.attention.backends.cpu_attn.TorchSDPABackend"
+        else:
+            return "vllm.attention.backends.torch_sdpa.TorchSDPABackend"
 
     @classmethod
     def get_device_total_memory(cls, device_id: int = 0) -> int:
```
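The returned value is a fully qualified class path that is imported later by the attention selector. A minimal sketch of that kind of resolution, using a hypothetical `resolve_backend_cls` helper (vLLM's real resolver lives in its own utilities; only the string paths are taken from the diff):

```python
import importlib


def resolve_backend_cls(qualname: str):
    """Hypothetical helper: import 'pkg.module.Class' and return the class."""
    module_name, _, class_name = qualname.rpartition(".")
    return getattr(importlib.import_module(module_name), class_name)


# With use_v1=True the CPU platform now hands back the V1 backend path:
# backend_cls = resolve_backend_cls(
#     "vllm.v1.attention.backends.cpu_attn.TorchSDPABackend")
```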
```diff
@@ -80,6 +83,8 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         if not model_config.enforce_eager:
             model_config.enforce_eager = True
 
+        model_config.disable_cascade_attn = True
+
         cache_config = vllm_config.cache_config
 
         ipex_available = find_spec("intel_extension_for_pytorch") is not None
```

Review thread on `disable_cascade_attn`:

Comment: This is new?

Reply: Yes, I think cascade attention is only supported in the flash attention backend, so I disable it here. I noticed …
```diff
@@ -127,7 +132,8 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
                 f" {kv_cache_space}, expect a positive integer value.")
 
         parallel_config = vllm_config.parallel_config
-        if (parallel_config.distributed_executor_backend is not None
+        if (parallel_config.world_size > 1
+                and parallel_config.distributed_executor_backend is not None
                 and parallel_config.distributed_executor_backend != "mp"):
             logger.warning(("%s is not supported on CPU, fallback to mp "
                             "distributed executor backend."),
```
```diff
@@ -140,14 +146,51 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
             parallel_config.sd_worker_cls = \
                 "vllm.worker.cpu_worker.CPUWorker"
         else:
-            parallel_config.worker_cls = "vllm.worker.cpu_worker.CPUWorker"
+            if envs.VLLM_USE_V1:
+                parallel_config.worker_cls = \
+                    "vllm.v1.worker.cpu_worker.CPUWorker"
+            else:
+                parallel_config.worker_cls = \
+                    "vllm.worker.cpu_worker.CPUWorker"
 
+        # Note: workaround for v1 gpu_model_runner
+        from vllm.config import CompilationLevel
+        vllm_config.compilation_config.cudagraph_capture_sizes = []
+
+        compilation_config = vllm_config.compilation_config
+        if (envs.VLLM_USE_V1 and vllm_config.compilation_config.level
+                == CompilationLevel.PIECEWISE):
+            compilation_config.level = CompilationLevel.DYNAMO_ONCE
+            compilation_config.backend = "eager"
+            compilation_config.custom_ops += ["none"]
+            compilation_config.inductor_compile_config.update({
+                "dce": True,
+                "size_asserts": False,
+                "nan_asserts": False,
+                "memory_planning": True,
+                "epilogue_fusion": True,
+            })
+
+            if vllm_config.lora_config is not None:
+                compilation_config.level = CompilationLevel.NO_COMPILATION
+
         assert vllm_config.device_config.device_type == "cpu"
 
         #
         # Environment variables for CPU executor
         #
 
+        os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
+
         # Note: to avoid the error 'nthreads cannot be larger than environment
         # variable "NUMEXPR_MAX_THREADS" (64)'.
         os.environ["NUMEXPR_MAX_THREADS"] = str(len(os.sched_getaffinity(0)))
+
+        # Set default threads num for OpenMP parallel
+        os.environ["OMP_NUM_THREADS"] = str(torch.get_num_threads())
```
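The compilation fallback above can be summarized as a small decision function. This is a sketch only: the `CompilationLevel` member names are taken from the diff, but the numeric values and the helper itself are assumptions, not vLLM code:

```python
from enum import IntEnum


class CompilationLevel(IntEnum):
    # Stand-in for vllm.config.CompilationLevel; values are assumed.
    NO_COMPILATION = 0
    DYNAMO_AS_IS = 1
    DYNAMO_ONCE = 2
    PIECEWISE = 3


def cpu_compilation_level(requested: CompilationLevel,
                          has_lora: bool) -> CompilationLevel:
    # Mirrors the diff: PIECEWISE targets CUDA-graph-style piecewise capture,
    # so the CPU path falls back to a single Dynamo trace with the eager
    # backend, and LoRA disables compilation altogether.
    if requested == CompilationLevel.PIECEWISE:
        if has_lora:
            return CompilationLevel.NO_COMPILATION
        return CompilationLevel.DYNAMO_ONCE
    return requested


assert (cpu_compilation_level(CompilationLevel.PIECEWISE, has_lora=False)
        == CompilationLevel.DYNAMO_ONCE)
```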
```diff
@@ -170,13 +213,6 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
             # To hint IPEX uses shared memory based AllReduce
             os.environ["LOCAL_WORLD_SIZE"] = str(
                 vllm_config.parallel_config.tensor_parallel_size)
-        if sys.platform == "darwin" and \
-                envs.VLLM_WORKER_MULTIPROC_METHOD == "fork":
-            if os.environ.get('VLLM_WORKER_MULTIPROC_METHOD', None) is None:
-                logger.warning(
-                    "Default to spawn method on MacOS. If this is not desired,"
-                    " set VLLM_WORKER_MULTIPROC_METHOD to fork explicitly.")
-            os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'
 
         if vllm_config.model_config and vllm_config.model_config.use_mla:
             logger.info(
```
```diff
@@ -203,3 +239,14 @@ def get_device_communicator_cls(cls) -> str:
         Get device specific communicator class for distributed communication.
         """
         return "vllm.distributed.device_communicators.cpu_communicator.CpuCommunicator"  # noqa
+
+    @classmethod
+    def supports_structured_output(cls) -> bool:
+        return True
+
+    @classmethod
+    def supports_v1(cls, model_config) -> bool:
+        """Returns whether the current platform can support v1 for the supplied
+        model configuration.
+        """
+        return True
```
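With `supports_v1` returning True, the V1 engine can be selected on CPU. A hedged usage sketch, assuming V1 is opted into through the `VLLM_USE_V1` environment variable that the platform code checks via `envs.VLLM_USE_V1` (model name and generation call are placeholders, commented out so the snippet runs without vLLM installed):

```python
import os

# Opt in to the V1 engine before vLLM is imported, so the CPU platform hooks
# above (supports_v1, V1 worker_cls, V1 TorchSDPABackend) take effect.
os.environ["VLLM_USE_V1"] = "1"

# from vllm import LLM
# llm = LLM(model="<your-model>")  # placeholder model name
# print(llm.generate(["Hello"])[0].outputs[0].text)
```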