@@ -50,20 +50,15 @@ def copy_blocks(
 
 
 @dataclass
-class TorchSDPAMetadata(AttentionMetadataPerStage, PagedAttentionMetadata):
+class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata,
+                        AttentionMetadataPerStage):
     """Metadata for TorchSDPABackend.
     """
     # Currently, input sequences can only contain all prompts
     # or all decoding. True if all sequences are prompts.
     is_prompt: bool
+    slot_mapping: torch.Tensor
     prompt_lens: Optional[List[int]]
-    prompt_lens_tensor: Optional[torch.Tensor]
-
-    max_subquery_len: Optional[int] = None
-    max_prompt_len: Optional[int] = None
-    subquery_start_loc: Optional[torch.Tensor] = None
-    seq_start_loc: Optional[torch.Tensor] = None
-    use_cuda_graph: bool = False
 
     def __post_init__(self):
         # Set during the execution of the first attention op.
@@ -111,7 +106,7 @@ def forward(
         key: torch.Tensor,
         value: torch.Tensor,
         kv_cache: Optional[torch.Tensor],
-        attn_metadata: AttentionMetadata[TorchSDPAMetadata],
+        attn_metadata: TorchSDPAMetadata,
         kv_scale: float,
     ) -> torch.Tensor:
         """Forward pass with torch SDPA and PagedAttention.
@@ -140,51 +135,36 @@ def forward(
                                                 attn_metadata.kv_cache_dtype,
                                                 kv_scale)
 
-        num_prefill_tokens = attn_metadata.num_prefill_tokens
-        num_decode_tokens = attn_metadata.num_decode_tokens
-        assert key.shape[0] == num_prefill_tokens + num_decode_tokens
-        assert value.shape[0] == num_prefill_tokens + num_decode_tokens
-
-        output = torch.empty_like(query)
-        # Query for decode. KV is not needed because it is already cached.
-        decode_query = query[num_prefill_tokens:]
-        # QKV for prefill.
-        query = query[:num_prefill_tokens]
-        key = key[:num_prefill_tokens]
-        value = value[:num_prefill_tokens]
-
-        assert query.shape[0] == num_prefill_tokens
-        assert decode_query.shape[0] == num_decode_tokens
-
-        if prefill_meta := attn_metadata.prefill_metadata:
-            if (kv_cache is None or prefill_meta.block_tables.numel() == 0):
+        if attn_metadata.is_prompt:
+            if (kv_cache is None or attn_metadata.block_tables.numel() == 0):
                 if self.num_kv_heads != self.num_heads:
                     key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
                     value = value.repeat_interleave(self.num_queries_per_kv,
                                                     dim=1)
 
-                if prefill_meta.attn_bias is None:
+                if attn_metadata.attn_bias is None:
                     if self.alibi_slopes is not None:
                         att_masks = _make_alibi_bias(
                             self.alibi_slopes, query.dtype,
-                            prefill_meta.prompt_lens)  # type: ignore
+                            attn_metadata.prompt_lens)  # type: ignore
                     elif self.sliding_window is not None:
                         att_masks = _make_sliding_window_bias(
-                            prefill_meta.prompt_lens, self.sliding_window,
+                            attn_metadata.prompt_lens, self.sliding_window,
                             query.dtype)  # type: ignore
                     else:
-                        att_masks = [None] * len(prefill_meta.prompt_lens)
-                    prefill_meta.attn_bias = att_masks
+                        att_masks = [None] * len(attn_metadata.prompt_lens)
+                    attn_metadata.attn_bias = att_masks
 
                 query = query.movedim(0, query.dim() - 2)
                 key = key.movedim(0, key.dim() - 2)
                 value = value.movedim(0, value.dim() - 2)
 
                 start = 0
-                out = torch.empty((num_tokens, self.num_heads, self.head_size),
-                                  dtype=query.dtype)
-                for prompt_len, mask in zip(prefill_meta.prompt_lens,
-                                            prefill_meta.attn_bias):
+                output = torch.empty(
+                    (num_tokens, self.num_heads, self.head_size),
+                    dtype=query.dtype)
+                for prompt_len, mask in zip(attn_metadata.prompt_lens,
+                                            attn_metadata.attn_bias):
                     end = start + prompt_len
                     sub_out = scaled_dot_product_attention(
                         query[:, start:end, :],
@@ -194,32 +174,28 @@ def forward(
                         dropout_p=0.0,
                         is_causal=not self.need_mask,
                         scale=self.scale).movedim(query.dim() - 2, 0)
-                    out[start:end, :, :] = sub_out
+                    output[start:end, :, :] = sub_out
                     start = end
-                assert out.shape == output[:num_prefill_tokens].shape
-                output[:num_prefill_tokens] = out
             else:
                 # prefix-enabled attention
                 raise RuntimeError(
                     "Torch SDPA backend doesn't support prefix decoding.")
 
-        if decode_meta := attn_metadata.decode_metadata:
+        else:
            # Decoding run.
-            out = PagedAttention.forward_decode(
-                decode_query,
+            output = PagedAttention.forward_decode(
+                query,
                 key_cache,
                 value_cache,
-                decode_meta.block_tables,
-                decode_meta.context_lens,
-                decode_meta.max_context_len,
+                attn_metadata.block_tables,
+                attn_metadata.context_lens,
+                attn_metadata.max_context_len,
                 attn_metadata.kv_cache_dtype,
                 self.num_kv_heads,
                 self.scale,
                 self.alibi_slopes,
                 kv_scale,
             )
-        assert out.shape == output[num_prefill_tokens:].shape
-        output[num_prefill_tokens:]
 
         # Reshape the output tensor.
         return output.view(-1, self.num_heads * self.head_size)
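Note: the prefill branch above treats the head dimension as the batch dimension for torch's scaled_dot_product_attention, which is why the query/key/value tensors are rearranged with movedim before the per-prompt loop. Below is a minimal, self-contained sketch of that call pattern; the shapes, prompt lengths, and the use of is_causal=True are made-up illustrative choices, not taken from this change (the real code passes per-prompt ALiBi or sliding-window masks when needed).

```python
import torch
from torch.nn.functional import scaled_dot_product_attention

# Made-up sizes: two prompts of length 3 and 4 packed into one token dimension.
prompt_lens = [3, 4]
num_tokens, num_heads, head_size = sum(prompt_lens), 4, 64

q = torch.randn(num_tokens, num_heads, head_size)
k = torch.randn(num_tokens, num_heads, head_size)
v = torch.randn(num_tokens, num_heads, head_size)

# SDPA attends over the last two dims (seq_len, head_size), so move the
# token dimension next to last; the head dimension then acts as the batch.
q, k, v = (t.movedim(0, t.dim() - 2) for t in (q, k, v))

out = torch.empty((num_tokens, num_heads, head_size), dtype=q.dtype)
start = 0
for prompt_len in prompt_lens:
    end = start + prompt_len
    sub = scaled_dot_product_attention(
        q[:, start:end, :],
        k[:, start:end, :],
        v[:, start:end, :],
        attn_mask=None,
        dropout_p=0.0,
        is_causal=True,  # stand-in for the per-prompt masks used in the diff
    )
    # Move the sequence dim back to the front before writing the slice.
    out[start:end] = sub.movedim(sub.dim() - 2, 0)
    start = end

print(out.shape)  # torch.Size([7, 4, 64])
```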
@@ -241,7 +217,7 @@ def _make_alibi_bias(
        bias = bias[None, :] - bias[:, None]
 
        num_heads = alibi_slopes.shape[0]
-       bias = bias[None, :].expand(num_heads, prompt_len, prompt_len)
+       bias = bias[None, :].repeat((num_heads, 1, 1))
        bias.mul_(alibi_slopes[:, None, None])
        inf_mask = torch.empty(
            (1, prompt_len, prompt_len),
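Note on the expand-to-repeat change in _make_alibi_bias: expand returns a stride-0 view that shares one copy of the bias across all heads, so the in-place bias.mul_(alibi_slopes[:, None, None]) that follows would write overlapping memory and PyTorch rejects it; repeat materializes an independent copy per head. A standalone sketch with made-up sizes (not part of the change):

```python
import torch

num_heads, prompt_len = 4, 3
slopes = torch.rand(num_heads)

bias = torch.arange(prompt_len, dtype=torch.float32)
bias = bias[None, :] - bias[:, None]  # (prompt_len, prompt_len) relative positions

shared = bias[None, :].expand(num_heads, prompt_len, prompt_len)
# shared.mul_(slopes[:, None, None])  # RuntimeError: more than one element of the
#                                     # written-to tensor refers to a single memory
#                                     # location (stride-0 view over the head dim)

per_head = bias[None, :].repeat((num_heads, 1, 1))  # real copy for each head
per_head.mul_(slopes[:, None, None])                # in-place per-head scaling works
print(per_head.shape)  # torch.Size([4, 3, 3])
```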