
Commit a040fac

Separate model runner
1 parent f3d0bf7 commit a040fac

5 files changed: +441 -64 lines changed

vllm/attention/backends/torch_sdpa.py

Lines changed: 25 additions & 51 deletions
@@ -7,8 +7,7 @@
 from torch.nn.functional import scaled_dot_product_attention

 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
-                                              AttentionMetadata,
-                                              AttentionMetadataPerStage)
+                                              AttentionMetadata, AttentionMetadataPerStage)
 from vllm.attention.ops.paged_attn import (PagedAttention,
                                            PagedAttentionMetadata)

@@ -50,20 +49,14 @@ def copy_blocks(


 @dataclass
-class TorchSDPAMetadata(AttentionMetadataPerStage, PagedAttentionMetadata):
+class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata, AttentionMetadataPerStage):
     """Metadata for TorchSDPABackend.
     """
     # Currently, input sequences can only contain all prompts
     # or all decoding. True if all sequences are prompts.
     is_prompt: bool
+    slot_mapping: torch.Tensor
     prompt_lens: Optional[List[int]]
-    prompt_lens_tensor: Optional[torch.Tensor]
-
-    max_subquery_len: Optional[int] = None
-    max_prompt_len: Optional[int] = None
-    subquery_start_loc: Optional[torch.Tensor] = None
-    seq_start_loc: Optional[torch.Tensor] = None
-    use_cuda_graph: bool = False

     def __post_init__(self):
         # Set during the execution of the first attention op.
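
Note on the new base-class combination: with dataclass inheritance, fields from every base are merged into a single constructor in reverse-MRO order, which is what lets the backend keep only its own fields here. A self-contained sketch with toy stand-ins (not vLLM's real metadata classes):

    from dataclasses import dataclass, fields
    from typing import Optional

    import torch


    @dataclass
    class ToyPagedAttentionMetadata:
        # Stand-in for PagedAttentionMetadata: paging info, no defaults.
        block_tables: Optional[torch.Tensor]
        context_lens: Optional[torch.Tensor]


    @dataclass
    class ToyAttentionMetadataPerStage:
        # Stand-in for AttentionMetadataPerStage: declares a default.
        attn_bias: Optional[list] = None


    @dataclass
    class ToyTorchSDPAMetadata(ToyAttentionMetadataPerStage,
                               ToyPagedAttentionMetadata):
        # Backend-specific fields; they need defaults because a base field
        # (attn_bias) already has one.
        is_prompt: bool = True
        slot_mapping: Optional[torch.Tensor] = None


    # Fields from all bases land in one constructor, base fields first.
    print([f.name for f in fields(ToyTorchSDPAMetadata)])
    # ['block_tables', 'context_lens', 'attn_bias', 'is_prompt', 'slot_mapping']

    # With the separated model runner a batch is all-prompt or all-decode,
    # so a single flag describes the whole batch (illustrative values):
    meta = ToyTorchSDPAMetadata(block_tables=None, context_lens=None,
                                is_prompt=True)
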
@@ -111,7 +104,7 @@ def forward(
         key: torch.Tensor,
         value: torch.Tensor,
         kv_cache: Optional[torch.Tensor],
-        attn_metadata: AttentionMetadata[TorchSDPAMetadata],
+        attn_metadata: TorchSDPAMetadata,
         kv_scale: float,
     ) -> torch.Tensor:
         """Forward pass with torch SDPA and PagedAttention.
@@ -140,51 +133,36 @@ def forward(
                                                 attn_metadata.kv_cache_dtype,
                                                 kv_scale)

-        num_prefill_tokens = attn_metadata.num_prefill_tokens
-        num_decode_tokens = attn_metadata.num_decode_tokens
-        assert key.shape[0] == num_prefill_tokens + num_decode_tokens
-        assert value.shape[0] == num_prefill_tokens + num_decode_tokens
-
-        output = torch.empty_like(query)
-        # Query for decode. KV is not needed because it is already cached.
-        decode_query = query[num_prefill_tokens:]
-        # QKV for prefill.
-        query = query[:num_prefill_tokens]
-        key = key[:num_prefill_tokens]
-        value = value[:num_prefill_tokens]
-
-        assert query.shape[0] == num_prefill_tokens
-        assert decode_query.shape[0] == num_decode_tokens
-
-        if prefill_meta := attn_metadata.prefill_metadata:
-            if (kv_cache is None or prefill_meta.block_tables.numel() == 0):
+        if attn_metadata.is_prompt:
+            if (kv_cache is None or attn_metadata.block_tables.numel() == 0):
                 if self.num_kv_heads != self.num_heads:
                     key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
                     value = value.repeat_interleave(self.num_queries_per_kv,
                                                     dim=1)

-                if prefill_meta.attn_bias is None:
+                if attn_metadata.attn_bias is None:
                     if self.alibi_slopes is not None:
                         att_masks = _make_alibi_bias(
                             self.alibi_slopes, query.dtype,
-                            prefill_meta.prompt_lens)  # type: ignore
+                            attn_metadata.prompt_lens)  # type: ignore
                     elif self.sliding_window is not None:
                         att_masks = _make_sliding_window_bias(
-                            prefill_meta.prompt_lens, self.sliding_window,
+                            attn_metadata.prompt_lens, self.sliding_window,
                             query.dtype)  # type: ignore
                     else:
-                        att_masks = [None] * len(prefill_meta.prompt_lens)
-                    prefill_meta.attn_bias = att_masks
+                        att_masks = [None] * len(attn_metadata.prompt_lens)
+                    attn_metadata.attn_bias = att_masks

                 query = query.movedim(0, query.dim() - 2)
                 key = key.movedim(0, key.dim() - 2)
                 value = value.movedim(0, value.dim() - 2)

                 start = 0
-                out = torch.empty((num_tokens, self.num_heads, self.head_size),
-                                  dtype=query.dtype)
-                for prompt_len, mask in zip(prefill_meta.prompt_lens,
-                                            prefill_meta.attn_bias):
+                output = torch.empty(
+                    (num_tokens, self.num_heads, self.head_size),
+                    dtype=query.dtype)
+                for prompt_len, mask in zip(attn_metadata.prompt_lens,
+                                            attn_metadata.attn_bias):
                     end = start + prompt_len
                     sub_out = scaled_dot_product_attention(
                         query[:, start:end, :],
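
For reference, the prefill branch above packs every prompt's tokens into one flat tensor and calls torch's scaled_dot_product_attention once per prompt. A stripped-down sketch of that packing and looping pattern, with illustrative shapes rather than vLLM's real inputs:

    import torch
    from torch.nn.functional import scaled_dot_product_attention

    # Illustrative sizes only.
    num_heads, head_size = 4, 8
    prompt_lens = [5, 3]          # two prompts packed back to back
    num_tokens = sum(prompt_lens)

    # Flat [num_tokens, num_heads, head_size] layout, as in the backend.
    query = torch.randn(num_tokens, num_heads, head_size)
    key = torch.randn(num_tokens, num_heads, head_size)
    value = torch.randn(num_tokens, num_heads, head_size)

    # Move the token dim inward so heads act as the batch dimension,
    # mirroring the movedim calls in the diff.
    query = query.movedim(0, query.dim() - 2)  # [num_heads, num_tokens, head_size]
    key = key.movedim(0, key.dim() - 2)
    value = value.movedim(0, value.dim() - 2)

    output = torch.empty(num_tokens, num_heads, head_size, dtype=query.dtype)
    start = 0
    for prompt_len in prompt_lens:
        end = start + prompt_len
        # Each prompt attends only to itself; plain causal masking stands in
        # for the per-prompt attn_bias handling in the real code.
        sub_out = scaled_dot_product_attention(
            query[:, start:end, :],
            key[:, start:end, :],
            value[:, start:end, :],
            is_causal=True).movedim(query.dim() - 2, 0)
        output[start:end, :, :] = sub_out
        start = end

    print(output.shape)  # torch.Size([8, 4, 8])
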
@@ -194,32 +172,28 @@ def forward(
                         dropout_p=0.0,
                         is_causal=not self.need_mask,
                         scale=self.scale).movedim(query.dim() - 2, 0)
-                    out[start:end, :, :] = sub_out
+                    output[start:end, :, :] = sub_out
                     start = end
-                assert out.shape == output[:num_prefill_tokens].shape
-                output[:num_prefill_tokens] = out
             else:
                 # prefix-enabled attention
                 raise RuntimeError(
                     "Torch SDPA backend doesn't support prefix decoding.")

-        if decode_meta := attn_metadata.decode_metadata:
+        else:
             # Decoding run.
-            out = PagedAttention.forward_decode(
-                decode_query,
+            output = PagedAttention.forward_decode(
+                query,
                 key_cache,
                 value_cache,
-                decode_meta.block_tables,
-                decode_meta.context_lens,
-                decode_meta.max_context_len,
+                attn_metadata.block_tables,
+                attn_metadata.context_lens,
+                attn_metadata.max_context_len,
                 attn_metadata.kv_cache_dtype,
                 self.num_kv_heads,
                 self.scale,
                 self.alibi_slopes,
                 kv_scale,
             )
-            assert out.shape == output[num_prefill_tokens:].shape
-            output[num_prefill_tokens:]

         # Reshape the output tensor.
         return output.view(-1, self.num_heads * self.head_size)
@@ -241,7 +215,7 @@ def _make_alibi_bias(
         bias = bias[None, :] - bias[:, None]

         num_heads = alibi_slopes.shape[0]
-        bias = bias[None, :].expand(num_heads, prompt_len, prompt_len)
+        bias = bias[None, :].repeat((num_heads, 1, 1))
         bias.mul_(alibi_slopes[:, None, None])
         inf_mask = torch.empty(
             (1, prompt_len, prompt_len),
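
A likely reason for switching from expand to repeat (not stated in the commit): expand returns a stride-0 view in which all heads alias the same memory, so the in-place bias.mul_(alibi_slopes[:, None, None]) that follows would have every head write to shared storage, which PyTorch rejects. A small self-contained illustration:

    import torch

    num_heads, prompt_len = 2, 3
    alibi_slopes = torch.tensor([0.5, 0.25])

    pos = torch.arange(prompt_len, dtype=torch.float32)
    bias = pos[None, :] - pos[:, None]              # (prompt_len, prompt_len)

    # expand() gives a view where all heads share one copy of the data; an
    # in-place mul_ on it raises "more than one element of the written-to
    # tensor refers to a single memory location" in current PyTorch.
    expanded = bias[None, :].expand(num_heads, prompt_len, prompt_len)
    # expanded.mul_(alibi_slopes[:, None, None])    # would raise RuntimeError

    # repeat() materializes a real copy per head, so each head can be scaled
    # in place by its own slope, as the new code does.
    repeated = bias[None, :].repeat((num_heads, 1, 1))
    repeated.mul_(alibi_slopes[:, None, None])
    print(repeated[0, 0, 1].item(), repeated[1, 0, 1].item())  # 0.5 0.25
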
@@ -270,4 +244,4 @@ def _make_sliding_window_bias(
         mask = torch.log(mask)
         attn_biases.append(mask.to(dtype))

-    return attn_biases
+    return attn_biases

vllm/executor/cpu_executor.py

Lines changed: 7 additions & 0 deletions
@@ -25,6 +25,7 @@ def __init__(self, model_config: ModelConfig, cache_config: CacheConfig,
         assert lora_config is None, "cpu backend doesn't support LoRA"
         model_config = _verify_and_get_model_config(model_config)
         cache_config = _verify_and_get_cache_config(cache_config)
+        scheduler_config = _verify_and_get_scheduler_config(scheduler_config)

         self.model_config = model_config
         self.cache_config = cache_config
@@ -115,6 +116,12 @@ def _verify_and_get_model_config(config: ModelConfig) -> ModelConfig:
         config.enforce_eager = True
     return config

+def _verify_and_get_scheduler_config(config: SchedulerConfig) -> SchedulerConfig:
+    if config.chunked_prefill_enabled:
+        logger.warning("Chunked prefill is not supported on CPU, disable it.")
+        config.chunked_prefill_enabled = False
+
+    return config

 def _verify_and_get_cache_config(config: CacheConfig) -> CacheConfig:
     _GB = 1 << 30
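
The new scheduler-config check follows the same verify-and-normalize pattern as the existing helpers: disable the unsupported flag and warn, rather than error out. A minimal standalone sketch of that behavior, using a toy stand-in rather than vLLM's SchedulerConfig:

    import logging
    from dataclasses import dataclass

    logger = logging.getLogger("cpu_executor_sketch")


    @dataclass
    class ToySchedulerConfig:
        chunked_prefill_enabled: bool = False


    def verify_and_get_scheduler_config(
            config: ToySchedulerConfig) -> ToySchedulerConfig:
        if config.chunked_prefill_enabled:
            logger.warning("Chunked prefill is not supported on CPU, disable it.")
            config.chunked_prefill_enabled = False
        return config


    cfg = verify_and_get_scheduler_config(
        ToySchedulerConfig(chunked_prefill_enabled=True))
    assert cfg.chunked_prefill_enabled is False
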

vllm/utils.py

Lines changed: 0 additions & 1 deletion
@@ -372,7 +372,6 @@ def is_pin_memory_available() -> bool:
         print_warning_once("Pin memory is not supported on Neuron.")
         return False
     elif is_cpu():
-        print_warning_once("Pin memory is not supported on CPU.")
         return False
     return True
