Commit 33a59a3

bigPYJ1151 authored and andy-neuma committed
[Hardware][Intel] Isolate CPUModelRunner and ModelRunner for better maintenance (vllm-project#3824)
1 parent f05fb52 commit 33a59a3

File tree

5 files changed: +443 -61 lines changed

vllm/attention/backends/torch_sdpa.py

Lines changed: 24 additions & 48 deletions

@@ -50,20 +50,15 @@ def copy_blocks(
 
 
 @dataclass
-class TorchSDPAMetadata(AttentionMetadataPerStage, PagedAttentionMetadata):
+class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata,
+                        AttentionMetadataPerStage):
     """Metadata for TorchSDPABackend.
     """
     # Currently, input sequences can only contain all prompts
     # or all decoding. True if all sequences are prompts.
     is_prompt: bool
+    slot_mapping: torch.Tensor
     prompt_lens: Optional[List[int]]
-    prompt_lens_tensor: Optional[torch.Tensor]
-
-    max_subquery_len: Optional[int] = None
-    max_prompt_len: Optional[int] = None
-    subquery_start_loc: Optional[torch.Tensor] = None
-    seq_start_loc: Optional[torch.Tensor] = None
-    use_cuda_graph: bool = False
 
     def __post_init__(self):
         # Set during the execution of the first attention op.

@@ -111,7 +106,7 @@ def forward(
         key: torch.Tensor,
         value: torch.Tensor,
         kv_cache: Optional[torch.Tensor],
-        attn_metadata: AttentionMetadata[TorchSDPAMetadata],
+        attn_metadata: TorchSDPAMetadata,
         kv_scale: float,
     ) -> torch.Tensor:
         """Forward pass with torch SDPA and PagedAttention.

@@ -140,51 +135,36 @@ def forward(
                                                 attn_metadata.kv_cache_dtype,
                                                 kv_scale)
 
-        num_prefill_tokens = attn_metadata.num_prefill_tokens
-        num_decode_tokens = attn_metadata.num_decode_tokens
-        assert key.shape[0] == num_prefill_tokens + num_decode_tokens
-        assert value.shape[0] == num_prefill_tokens + num_decode_tokens
-
-        output = torch.empty_like(query)
-        # Query for decode. KV is not needed because it is already cached.
-        decode_query = query[num_prefill_tokens:]
-        # QKV for prefill.
-        query = query[:num_prefill_tokens]
-        key = key[:num_prefill_tokens]
-        value = value[:num_prefill_tokens]
-
-        assert query.shape[0] == num_prefill_tokens
-        assert decode_query.shape[0] == num_decode_tokens
-
-        if prefill_meta := attn_metadata.prefill_metadata:
-            if (kv_cache is None or prefill_meta.block_tables.numel() == 0):
+        if attn_metadata.is_prompt:
+            if (kv_cache is None or attn_metadata.block_tables.numel() == 0):
                 if self.num_kv_heads != self.num_heads:
                     key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
                     value = value.repeat_interleave(self.num_queries_per_kv,
                                                     dim=1)
 
-                if prefill_meta.attn_bias is None:
+                if attn_metadata.attn_bias is None:
                     if self.alibi_slopes is not None:
                         att_masks = _make_alibi_bias(
                             self.alibi_slopes, query.dtype,
-                            prefill_meta.prompt_lens)  # type: ignore
+                            attn_metadata.prompt_lens)  # type: ignore
                     elif self.sliding_window is not None:
                         att_masks = _make_sliding_window_bias(
-                            prefill_meta.prompt_lens, self.sliding_window,
+                            attn_metadata.prompt_lens, self.sliding_window,
                             query.dtype)  # type: ignore
                     else:
-                        att_masks = [None] * len(prefill_meta.prompt_lens)
-                    prefill_meta.attn_bias = att_masks
+                        att_masks = [None] * len(attn_metadata.prompt_lens)
+                    attn_metadata.attn_bias = att_masks
 
                 query = query.movedim(0, query.dim() - 2)
                 key = key.movedim(0, key.dim() - 2)
                 value = value.movedim(0, value.dim() - 2)
 
                 start = 0
-                out = torch.empty((num_tokens, self.num_heads, self.head_size),
-                                  dtype=query.dtype)
-                for prompt_len, mask in zip(prefill_meta.prompt_lens,
-                                            prefill_meta.attn_bias):
+                output = torch.empty(
+                    (num_tokens, self.num_heads, self.head_size),
+                    dtype=query.dtype)
+                for prompt_len, mask in zip(attn_metadata.prompt_lens,
+                                            attn_metadata.attn_bias):
                     end = start + prompt_len
                     sub_out = scaled_dot_product_attention(
                         query[:, start:end, :],

@@ -194,32 +174,28 @@ def forward(
                        dropout_p=0.0,
                        is_causal=not self.need_mask,
                        scale=self.scale).movedim(query.dim() - 2, 0)
-                    out[start:end, :, :] = sub_out
+                    output[start:end, :, :] = sub_out
                    start = end
-                assert out.shape == output[:num_prefill_tokens].shape
-                output[:num_prefill_tokens] = out
            else:
                # prefix-enabled attention
                raise RuntimeError(
                    "Torch SDPA backend doesn't support prefix decoding.")
 
-        if decode_meta := attn_metadata.decode_metadata:
+        else:
            # Decoding run.
-            out = PagedAttention.forward_decode(
-                decode_query,
+            output = PagedAttention.forward_decode(
+                query,
                key_cache,
                value_cache,
-                decode_meta.block_tables,
-                decode_meta.context_lens,
-                decode_meta.max_context_len,
+                attn_metadata.block_tables,
+                attn_metadata.context_lens,
+                attn_metadata.max_context_len,
                attn_metadata.kv_cache_dtype,
                self.num_kv_heads,
                self.scale,
                self.alibi_slopes,
                kv_scale,
            )
-            assert out.shape == output[num_prefill_tokens:].shape
-            output[num_prefill_tokens:]
 
        # Reshape the output tensor.
        return output.view(-1, self.num_heads * self.head_size)

@@ -241,7 +217,7 @@ def _make_alibi_bias(
        bias = bias[None, :] - bias[:, None]
 
        num_heads = alibi_slopes.shape[0]
-        bias = bias[None, :].expand(num_heads, prompt_len, prompt_len)
+        bias = bias[None, :].repeat((num_heads, 1, 1))
        bias.mul_(alibi_slopes[:, None, None])
        inf_mask = torch.empty(
            (1, prompt_len, prompt_len),
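
Note on the last hunk: it swaps expand for repeat when building the ALiBi bias. The sketch below uses toy sizes that are not taken from the commit; it illustrates the likely motivation. expand returns a broadcasted view whose head dimension has stride 0, so the in-place bias.mul_(alibi_slopes[:, None, None]) would write per-head values into aliased memory, which PyTorch rejects for in-place ops, while repeat materializes an independent copy per head.

import torch

num_heads, prompt_len = 4, 8  # toy sizes for illustration only
slopes = torch.rand(num_heads)
bias = torch.arange(prompt_len, dtype=torch.float32)
bias = bias[None, :] - bias[:, None]  # (prompt_len, prompt_len) relative positions

# expand() broadcasts without copying: every "head" aliases the same buffer.
expanded = bias[None, :].expand(num_heads, prompt_len, prompt_len)
try:
    expanded.mul_(slopes[:, None, None])  # in-place write into aliased memory
except RuntimeError as err:
    print(f"expand + mul_ rejected: {err}")

# repeat() allocates a separate copy per head, so the per-head in-place
# scaling done in _make_alibi_bias is well defined.
repeated = bias[None, :].repeat((num_heads, 1, 1))
repeated.mul_(slopes[:, None, None])
print(repeated.shape)  # torch.Size([4, 8, 8])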

vllm/executor/cpu_executor.py

Lines changed: 10 additions & 0 deletions

@@ -25,6 +25,7 @@ def __init__(self, model_config: ModelConfig, cache_config: CacheConfig,
         assert lora_config is None, "cpu backend doesn't support LoRA"
         model_config = _verify_and_get_model_config(model_config)
         cache_config = _verify_and_get_cache_config(cache_config)
+        scheduler_config = _verify_and_get_scheduler_config(scheduler_config)
 
         self.model_config = model_config
         self.cache_config = cache_config

@@ -116,6 +117,15 @@ def _verify_and_get_model_config(config: ModelConfig) -> ModelConfig:
     return config
 
 
+def _verify_and_get_scheduler_config(
+        config: SchedulerConfig) -> SchedulerConfig:
+    if config.chunked_prefill_enabled:
+        logger.warning("Chunked prefill is not supported on CPU, disable it.")
+        config.chunked_prefill_enabled = False
+
+    return config
+
+
 def _verify_and_get_cache_config(config: CacheConfig) -> CacheConfig:
     _GB = 1 << 30
     if config.enable_prefix_caching:
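
For context, here is a self-contained sketch of the verify-and-normalize pattern the new helper follows. DummySchedulerConfig is a stand-in dataclass that models only the field the diff touches, not the real vllm SchedulerConfig; the point is that the CPU executor warns and falls back to supported behavior instead of raising when an unsupported feature is requested.

import logging
from dataclasses import dataclass

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger("cpu_executor_sketch")

@dataclass
class DummySchedulerConfig:
    # Stand-in for vllm's SchedulerConfig; only the relevant field is modeled.
    chunked_prefill_enabled: bool = False

def verify_and_get_scheduler_config(
        config: DummySchedulerConfig) -> DummySchedulerConfig:
    # Mirrors the shape of _verify_and_get_scheduler_config in the diff:
    # warn, flip the flag off, and return the (mutated) config.
    if config.chunked_prefill_enabled:
        logger.warning("Chunked prefill is not supported on CPU, disable it.")
        config.chunked_prefill_enabled = False
    return config

cfg = verify_and_get_scheduler_config(
    DummySchedulerConfig(chunked_prefill_enabled=True))
assert cfg.chunked_prefill_enabled is False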

vllm/utils.py

Lines changed: 0 additions & 1 deletion

@@ -372,7 +372,6 @@ def is_pin_memory_available() -> bool:
         print_warning_once("Pin memory is not supported on Neuron.")
         return False
     elif is_cpu():
-        print_warning_once("Pin memory is not supported on CPU.")
         return False
     return True
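
The utils.py change only drops the warning; is_pin_memory_available() still returns False on the CPU backend. A hypothetical call site, just to show how that flag is typically consumed when allocating host buffers:

import torch

pin = False  # what is_pin_memory_available() reports for the CPU backend
# Pinned (page-locked) memory only helps async host-to-GPU copies, so a
# CPU-only deployment allocates ordinary pageable memory instead.
buf = torch.empty(1024, dtype=torch.float32, pin_memory=pin)
print(buf.is_pinned())  # False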
