@@ -136,7 +136,6 @@ def call(
         cache_update_index=None,
         training=None,
     ):
-        seq_len = ops.shape(hidden_states)[1]
         start_index = (
             cache_update_index if cache_update_index is not None else 0
         )
@@ -170,67 +169,24 @@ def _compute_key_value(x):
             return key, value

         if cache is not None:
-            cache_k = cache[:, 0, ...]
-            cache_v = cache[:, 1, ...]
-
+            key_cache = cache[:, 0, ...]
+            value_cache = cache[:, 1, ...]
+            if cache_update_index is None:
+                key = key_cache
+                value = value_cache
+            else:
+                key_update, value_update = _compute_key_value(hidden_states)
+                start = [0, cache_update_index, 0, 0]
+                key = ops.slice_update(key_cache, start, key_update)
+                value = ops.slice_update(value_cache, start, value_update)
+                cache = ops.stack((key, value), axis=1)
+        else:
             if cache_update_index is not None:
-                # Compute the new keys and values
-                key, value = _compute_key_value(hidden_states)
-
-                # The cache is a rotating buffer; we want to wrap around if
-                # the sequence length exceeds the sliding window.
-                update_end_index = (
-                    cache_update_index + seq_len - 1
-                ) % self._sliding_window + 1
-                update_end_index = ops.cast(update_end_index, "int32")
-                cache_update_index = cache_update_index % self._sliding_window
-                update_start_index = ops.cond(
-                    update_end_index > cache_update_index,
-                    lambda: ops.cast(cache_update_index, "int32"),
-                    lambda: ops.cast(0, "int32"),
-                )
-                # Also note that the update step below assumes that the
-                # sequence length is always one when `cache_update_index != 0`.
-                # This is necessary to support XLA compilation. Ideally, we
-                # would want to use
-                # `key[:, -(update_end_index - update_start_index):, ...]`
-                # as the update, but updating with a dynamic slice gives an
-                # XLA compilation error in TensorFlow.
-                # Passing a sequence of length > 1 with a cache update might
-                # give incorrect results (since there is no way to determine
-                # how many of the most recent tokens should be saved if the
-                # tokens exceed the sliding window length).
-                cache_k = ops.slice_update(
-                    cache_k,
-                    [0, update_start_index, 0, 0],
-                    # We slice the keys and values in case the user has passed
-                    # a sequence of length > `self._sliding_window`; we want to
-                    # prefill the cache using just the most recent values in
-                    # the sliding window.
-                    ops.cast(
-                        key[:, -self._sliding_window :, ...], cache_k.dtype
-                    ),
-                )
-                cache_v = ops.slice_update(
-                    cache_v,
-                    [0, update_start_index, 0, 0],
-                    ops.cast(
-                        value[:, -self._sliding_window :, ...], cache_v.dtype
-                    ),
+                raise ValueError(
+                    "`cache_update_index` should not be set if `cache` is "
+                    f"`None`. Received: cache={cache}, "
+                    f"cache_update_index={cache_update_index}"
                 )
-                cache = ops.stack([cache_k, cache_v], axis=1)
-
-            # Get the required keys and values from the cache.
-            # Since we expect the user to pass a fixed-size cache, we just
-            # pick the first few slices up to and including the newly computed
-            # keys and values.
-            cache_k = cache_k[:, :update_end_index, ...]
-            cache_v = cache_v[:, :update_end_index, ...]
-
-            key = ops.cast(cache_k, dtype=self.compute_dtype)
-            value = ops.cast(cache_v, dtype=self.compute_dtype)
-        else:
-            # Compute keys and values
             key, value = _compute_key_value(hidden_states)

         # [batch_shape, seq_len, num_key_value_heads, head_dim]
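
For readers skimming the change: below is a minimal, standalone sketch of how the new cache path behaves, assuming a combined cache laid out as `[batch, 2, max_len, num_key_value_heads, head_dim]` (keys at index 0, values at index 1 along axis 1) and Keras 3 `ops`. The sizes and values are made up for illustration and are not part of the diff.

```python
from keras import ops

# Toy sizes, chosen only for illustration.
batch, max_len, num_kv_heads, head_dim = 1, 8, 2, 4

# Combined cache: [batch, 2, max_len, num_kv_heads, head_dim].
cache = ops.zeros((batch, 2, max_len, num_kv_heads, head_dim))
key_cache = cache[:, 0, ...]
value_cache = cache[:, 1, ...]

# Stand-ins for the projected key/value of a single decoded token.
key_update = ops.ones((batch, 1, num_kv_heads, head_dim))
value_update = ops.ones((batch, 1, num_kv_heads, head_dim)) * 2.0

# Write the one-token update at `cache_update_index` along the sequence
# axis, then re-stack keys and values into a single cache tensor, the same
# sequence of ops used in the new branch above.
cache_update_index = 3
start = [0, cache_update_index, 0, 0]
key = ops.slice_update(key_cache, start, key_update)
value = ops.slice_update(value_cache, start, value_update)
cache = ops.stack((key, value), axis=1)

# Slot 3 of the key cache now holds the update; other slots stay zero.
print(ops.convert_to_numpy(key)[0, 3, 0])  # [1. 1. 1. 1.]
print(ops.convert_to_numpy(key)[0, 0, 0])  # [0. 0. 0. 0.]
```

Unlike the removed rotating-buffer logic, nothing here wraps around `self._sliding_window`: the cache is treated as a fixed-size buffer indexed directly by `cache_update_index`.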
@@ -260,7 +216,7 @@ def _masked_softmax(self, attention_scores, attention_mask=None):
         return self._softmax(attention_scores)

     def _compute_attention(self, query, key, value, attention_mask=None):
-        attention_scores = ops.einsum(self._dot_product_equation, key, query)
+        attention_scores = ops.einsum(self._dot_product_equation, query, key)

         norm_factor = ops.sqrt(ops.cast(self._head_dim, self.compute_dtype))

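
A small, self-contained illustration of why the operand order in the last hunk matters: `ops.einsum` pairs operands with the subscript groups in the equation positionally, so `query` must be the first argument when the first group describes the query. The equation string below is a hypothetical stand-in, not necessarily the `self._dot_product_equation` the layer actually builds.

```python
import numpy as np
from keras import ops

# Hypothetical stand-in for `self._dot_product_equation`: the first subscript
# group ("bquh") describes the query, the second ("bkuh") the key.
dot_product_equation = "bquh,bkuh->buqk"

batch, q_len, k_len, num_heads, head_dim = 1, 2, 5, 4, 8
query = np.random.rand(batch, q_len, num_heads, head_dim).astype("float32")
key = np.random.rand(batch, k_len, num_heads, head_dim).astype("float32")

# Operands in the order the equation expects: query first, then key.
scores = ops.einsum(dot_product_equation, query, key)
print(ops.shape(scores))  # (1, 4, 2, 5): [batch, num_heads, q_len, k_len]

# Scale by sqrt(head_dim), mirroring `_compute_attention`.
scores = scores / ops.sqrt(ops.cast(head_dim, "float32"))

# Passing (key, query) instead, as before this commit, makes einsum read the
# key as if it were the query: with this equation you get transposed scores
# of shape [batch, num_heads, k_len, q_len], which is wrong whenever the
# query length and the cached key length differ.
swapped = ops.einsum(dot_product_equation, key, query)
print(ops.shape(swapped))  # (1, 4, 5, 2)
```

With the layer's real equation the swapped order may fail differently (for example a shape mismatch when query and key head counts differ), but either way the operand order has to follow the equation string.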