@@ -7,8 +7,7 @@
 from torch.nn.functional import scaled_dot_product_attention
 
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
-                                              AttentionMetadata,
-                                              AttentionMetadataPerStage)
+                                              AttentionMetadata, AttentionMetadataPerStage)
 from vllm.attention.ops.paged_attn import (PagedAttention,
                                            PagedAttentionMetadata)
 
@@ -50,20 +49,14 @@ def copy_blocks(
 
 
 @dataclass
-class TorchSDPAMetadata(AttentionMetadataPerStage, PagedAttentionMetadata):
+class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata, AttentionMetadataPerStage):
     """Metadata for TorchSDPABackend.
     """
     # Currently, input sequences can only contain all prompts
     # or all decoding. True if all sequences are prompts.
     is_prompt: bool
+    slot_mapping: torch.Tensor
     prompt_lens: Optional[List[int]]
-    prompt_lens_tensor: Optional[torch.Tensor]
-
-    max_subquery_len: Optional[int] = None
-    max_prompt_len: Optional[int] = None
-    subquery_start_loc: Optional[torch.Tensor] = None
-    seq_start_loc: Optional[torch.Tensor] = None
-    use_cuda_graph: bool = False
 
     def __post_init__(self):
         # Set during the execution of the first attention op.
@@ -111,7 +104,7 @@ def forward(
         key: torch.Tensor,
         value: torch.Tensor,
         kv_cache: Optional[torch.Tensor],
-        attn_metadata: AttentionMetadata[TorchSDPAMetadata],
+        attn_metadata: TorchSDPAMetadata,
         kv_scale: float,
     ) -> torch.Tensor:
         """Forward pass with torch SDPA and PagedAttention.
@@ -140,51 +133,36 @@ def forward(
                                                 attn_metadata.kv_cache_dtype,
                                                 kv_scale)
 
-        num_prefill_tokens = attn_metadata.num_prefill_tokens
-        num_decode_tokens = attn_metadata.num_decode_tokens
-        assert key.shape[0] == num_prefill_tokens + num_decode_tokens
-        assert value.shape[0] == num_prefill_tokens + num_decode_tokens
-
-        output = torch.empty_like(query)
-        # Query for decode. KV is not needed because it is already cached.
-        decode_query = query[num_prefill_tokens:]
-        # QKV for prefill.
-        query = query[:num_prefill_tokens]
-        key = key[:num_prefill_tokens]
-        value = value[:num_prefill_tokens]
-
-        assert query.shape[0] == num_prefill_tokens
-        assert decode_query.shape[0] == num_decode_tokens
-
-        if prefill_meta := attn_metadata.prefill_metadata:
-            if (kv_cache is None or prefill_meta.block_tables.numel() == 0):
+        if attn_metadata.is_prompt:
+            if (kv_cache is None or attn_metadata.block_tables.numel() == 0):
                 if self.num_kv_heads != self.num_heads:
                     key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
                     value = value.repeat_interleave(self.num_queries_per_kv,
                                                     dim=1)
 
-                if prefill_meta.attn_bias is None:
+                if attn_metadata.attn_bias is None:
                     if self.alibi_slopes is not None:
                         att_masks = _make_alibi_bias(
                             self.alibi_slopes, query.dtype,
-                            prefill_meta.prompt_lens)  # type: ignore
+                            attn_metadata.prompt_lens)  # type: ignore
                     elif self.sliding_window is not None:
                         att_masks = _make_sliding_window_bias(
-                            prefill_meta.prompt_lens, self.sliding_window,
+                            attn_metadata.prompt_lens, self.sliding_window,
                             query.dtype)  # type: ignore
                     else:
-                        att_masks = [None] * len(prefill_meta.prompt_lens)
-                    prefill_meta.attn_bias = att_masks
+                        att_masks = [None] * len(attn_metadata.prompt_lens)
+                    attn_metadata.attn_bias = att_masks
 
                 query = query.movedim(0, query.dim() - 2)
                 key = key.movedim(0, key.dim() - 2)
                 value = value.movedim(0, value.dim() - 2)
 
                 start = 0
-                out = torch.empty((num_tokens, self.num_heads, self.head_size),
-                                  dtype=query.dtype)
-                for prompt_len, mask in zip(prefill_meta.prompt_lens,
-                                            prefill_meta.attn_bias):
+                output = torch.empty(
+                    (num_tokens, self.num_heads, self.head_size),
+                    dtype=query.dtype)
+                for prompt_len, mask in zip(attn_metadata.prompt_lens,
+                                            attn_metadata.attn_bias):
                     end = start + prompt_len
                     sub_out = scaled_dot_product_attention(
                         query[:, start:end, :],
@@ -194,32 +172,28 @@ def forward(
                         dropout_p=0.0,
                         is_causal=not self.need_mask,
                         scale=self.scale).movedim(query.dim() - 2, 0)
-                    out[start:end, :, :] = sub_out
+                    output[start:end, :, :] = sub_out
                     start = end
-                assert out.shape == output[:num_prefill_tokens].shape
-                output[:num_prefill_tokens] = out
             else:
                 # prefix-enabled attention
                 raise RuntimeError(
                     "Torch SDPA backend doesn't support prefix decoding.")
 
-        if decode_meta := attn_metadata.decode_metadata:
+        else:
             # Decoding run.
-            out = PagedAttention.forward_decode(
-                decode_query,
+            output = PagedAttention.forward_decode(
+                query,
                 key_cache,
                 value_cache,
-                decode_meta.block_tables,
-                decode_meta.context_lens,
-                decode_meta.max_context_len,
+                attn_metadata.block_tables,
+                attn_metadata.context_lens,
+                attn_metadata.max_context_len,
                 attn_metadata.kv_cache_dtype,
                 self.num_kv_heads,
                 self.scale,
                 self.alibi_slopes,
                 kv_scale,
             )
-        assert out.shape == output[num_prefill_tokens:].shape
-        output[num_prefill_tokens:]
 
         # Reshape the output tensor.
         return output.view(-1, self.num_heads * self.head_size)
@@ -241,7 +215,7 @@ def _make_alibi_bias(
         bias = bias[None, :] - bias[:, None]
 
         num_heads = alibi_slopes.shape[0]
-        bias = bias[None, :].expand(num_heads, prompt_len, prompt_len)
+        bias = bias[None, :].repeat((num_heads, 1, 1))
         bias.mul_(alibi_slopes[:, None, None])
         inf_mask = torch.empty(
             (1, prompt_len, prompt_len),
@@ -270,4 +244,4 @@ def _make_sliding_window_bias(
         mask = torch.log(mask)
         attn_biases.append(mask.to(dtype))
 
-    return attn_biases
+    return attn_biases
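
A note on the `_make_alibi_bias` hunk above: the patch swaps `expand()` for `repeat()` when building the per-head ALiBi bias. The sketch below is a standalone illustration with made-up head count and prompt length, not part of the patch: `expand()` returns a broadcasted view whose head dimension has stride 0, so the subsequent in-place `bias.mul_(alibi_slopes[:, None, None])` would write different per-head values into shared memory, which PyTorch refuses, while `repeat()` materializes a copy that each head can safely scale in place.

```python
import torch

# Hypothetical sizes, chosen only for illustration.
num_heads, prompt_len = 4, 8
alibi_slopes = torch.rand(num_heads)

bias = torch.arange(prompt_len, dtype=torch.float32)
bias = bias[None, :] - bias[:, None]            # (prompt_len, prompt_len)

# expand(): a stride-0 view; the in-place per-head scaling is rejected.
expanded = bias[None, :].expand(num_heads, prompt_len, prompt_len)
try:
    expanded.mul_(alibi_slopes[:, None, None])
except RuntimeError as err:
    print(f"expand() view cannot be scaled in place: {err}")

# repeat(): each head owns its memory, so the in-place scaling works.
repeated = bias[None, :].repeat((num_heads, 1, 1))
repeated.mul_(alibi_slopes[:, None, None])
print(repeated.shape)                           # torch.Size([4, 8, 8])
```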