@@ -133,23 +133,26 @@ def __init__(

    def forward(
        self,
-        input_pos: torch.Tensor,
+        input_pos: Optional[torch.Tensor],
        q: torch.Tensor,  # Already have rotary embeddings. (bs, n_local_heads, seqlen, head_dim)
        k: torch.Tensor,  # Already have rotary embeddings. (bs, n_local_kv_heads, seqlen, head_dim)
        v: torch.Tensor,  # (bs, n_local_kv_heads, seqlen, head_dim)
        bsz,
        seqlen,
        mask: torch.Tensor,
    ) -> torch.Tensor:
-        if self.enable_dynamic_shape:
-            start_pos = input_pos[-1].item()
-            torch._check_is_size(start_pos)
-            torch._check(start_pos < self.max_context_len)
-            seq_length = q.size(2)
-            # pyre-ignore: Incompatible parameter type [6]
-            attn_mask = mask.narrow(0, start_pos, seq_length)
+        if input_pos is None:  # No kv cache
+            attn_mask = mask[:seqlen, :seqlen]
        else:
-            attn_mask = mask[None, None, input_pos]
+            if self.enable_dynamic_shape:
+                start_pos = input_pos[-1].item()
+                torch._check_is_size(start_pos)
+                torch._check(start_pos < self.max_context_len)
+                seq_length = q.size(2)
+                # pyre-ignore: Incompatible parameter type [6]
+                attn_mask = mask.narrow(0, start_pos, seq_length)
+            else:
+                attn_mask = mask[None, None, input_pos]

        # TODO(kimishpatel): This should not be necessary because scaled_dot_product_attention
        # can natively support GQA now. But needs enable_gqa=True
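The net effect of this hunk is that mask selection now has three paths: no KV cache (input_pos is None), KV cache with dynamic shapes, and KV cache with static shapes. A minimal standalone sketch of that branching and the shapes it produces; the helper name and the mask construction are illustrative, assuming a 2-D boolean causal mask of size (max_context_len, max_context_len).

from typing import Optional

import torch


def select_attn_mask(
    mask: torch.Tensor,  # (max_context_len, max_context_len) causal mask
    input_pos: Optional[torch.Tensor],
    seqlen: int,
    enable_dynamic_shape: bool,
    max_context_len: int,
) -> torch.Tensor:
    # Hypothetical helper mirroring the branching added to SDPA.forward above.
    if input_pos is None:
        # No KV cache: square causal mask over the current sequence.
        return mask[:seqlen, :seqlen]
    if enable_dynamic_shape:
        # Export-friendly path: narrow the mask by the current start position.
        start_pos = input_pos[-1].item()
        torch._check_is_size(start_pos)
        torch._check(start_pos < max_context_len)
        return mask.narrow(0, start_pos, seqlen)
    # Static path: select the mask rows for the given positions.
    return mask[None, None, input_pos]


causal = torch.tril(torch.ones(8, 8, dtype=torch.bool))

# Prefill without a KV cache: (seqlen, seqlen).
print(select_attn_mask(causal, None, seqlen=4, enable_dynamic_shape=False, max_context_len=8).shape)

# Single decode step with a KV cache at position 5: (1, 1, 1, max_context_len).
print(select_attn_mask(causal, torch.tensor([5]), seqlen=1, enable_dynamic_shape=False, max_context_len=8).shape)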
@@ -218,13 +221,13 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
                self.head_dim,
                args.enable_dynamic_shape,
            )
-            self.SDPA = SDPA(
-                dim=self.n_local_heads * self.head_dim,
-                head_dim=self.head_dim,
-                n_rep=self.n_rep,
-                max_context_len=self.max_context_len,
-                enable_dynamic_shape=args.enable_dynamic_shape,
-            )
+        self.SDPA = SDPA(
+            dim=self.n_local_heads * self.head_dim,
+            head_dim=self.head_dim,
+            n_rep=self.n_rep,
+            max_context_len=self.max_context_len,
+            enable_dynamic_shape=args.enable_dynamic_shape,
+        )

    def forward(
        self,
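Since the forward path in the next hunk routes every call through self.SDPA, the submodule has to exist even when the KV cache is disabled, which is presumably why its construction moves out of the cache-guarded block here. One of the arguments it keeps, n_rep, is the grouped-query replication factor that the TODO in the first hunk says could eventually go away. A self-contained sketch of the two formulations, with made-up shapes; enable_gqa needs a sufficiently recent PyTorch.

import torch
import torch.nn.functional as F

bsz, seqlen, head_dim = 1, 4, 8
n_local_heads, n_local_kv_heads = 8, 2
n_rep = n_local_heads // n_local_kv_heads  # replication factor (4 here)

q = torch.randn(bsz, n_local_heads, seqlen, head_dim)
k = torch.randn(bsz, n_local_kv_heads, seqlen, head_dim)
v = torch.randn(bsz, n_local_kv_heads, seqlen, head_dim)

# Current approach: expand k/v so each query head gets its own k/v head.
out_repeat = F.scaled_dot_product_attention(
    q,
    k.repeat_interleave(n_rep, dim=1),
    v.repeat_interleave(n_rep, dim=1),
    is_causal=True,
)

# Per the TODO: let scaled_dot_product_attention do the grouping itself.
out_native = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=True)

torch.testing.assert_close(out_repeat, out_native)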
@@ -257,21 +260,5 @@ def forward(
        if self.use_kv_cache:
            assert input_pos is not None
            k, v = self.kv_cache.update(input_pos, k, v)
-            output = self.SDPA(input_pos, q, k, v, bsz, seqlen, self.mask)
-            return self.wo(output), None
-
-        # grouped multiquery attention: expand out keys and values
-        k = k.repeat_interleave(self.n_rep, dim=1)
-        v = v.repeat_interleave(self.n_rep, dim=1)
-
-        assert hasattr(self, "mask")
-
-        mask = self.mask[:seqlen, :seqlen]
-
-        output = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
-
-        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
-
-        output = self.wo(output)
-
-        return output, None
+        output = self.SDPA(input_pos, q, k, v, bsz, seqlen, self.mask)
+        return self.wo(output), None
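For reference, the deleted lines implemented the cache-free attention inline; with this change that case is expected to be covered by SDPA when input_pos is None (square mask slice as in the first hunk, GQA expansion, SDPA call, reshape, output projection). A runnable reimplementation of the removed branch, with made-up shapes and a hypothetical function name, kept only as a behavioral reference.

import torch
import torch.nn.functional as F


def no_cache_attention(q, k, v, causal_mask, n_rep, wo):
    # Mirrors the removed branch: GQA expansion, square mask slice, SDPA, reshape, out-proj.
    bsz, _, seqlen, _ = q.shape
    k = k.repeat_interleave(n_rep, dim=1)
    v = v.repeat_interleave(n_rep, dim=1)
    mask = causal_mask[:seqlen, :seqlen]
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
    out = out.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
    return wo(out)


bsz, seqlen, n_heads, n_kv_heads, head_dim, max_ctx = 1, 4, 8, 2, 16, 32
q = torch.randn(bsz, n_heads, seqlen, head_dim)
k = torch.randn(bsz, n_kv_heads, seqlen, head_dim)
v = torch.randn(bsz, n_kv_heads, seqlen, head_dim)
causal_mask = torch.tril(torch.ones(max_ctx, max_ctx, dtype=torch.bool))
wo = torch.nn.Linear(n_heads * head_dim, n_heads * head_dim, bias=False)

print(no_cache_attention(q, k, v, causal_mask, n_heads // n_kv_heads, wo).shape)  # (1, 4, 128)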