
Commit 19b0b89

Remove sliding window attention from Mistral's attention layer
JAX complains about dynamic slicing when compiled with XLA. This is unavoidable since, at runtime, the slice of the key/value array to use for a given iteration is determined by `cache_update_index`, which is itself a JAX tracer. Any workaround would lead to dynamic shapes at some point. Hence, I had to remove the sliding window logic and use vanilla caching instead for now. For some reason, TensorFlow doesn't complain with XLA; I think this might be because TensorFlow isn't as stringent about static shapes as JAX. In any case, adding sliding window attention that is XLA-compatible is a story for the future.
1 parent 2e2e2e5 commit 19b0b89
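For reference, here is a minimal standalone sketch of the constraint described above (not code from this repository; shapes and function names are made up). Under `jax.jit`, slicing with a bound that is a traced value fails because the result shape would be dynamic, while writing a fixed-shape update at a traced start index compiles fine, which is comparable to what the `ops.slice_update` call in the new code path relies on:

    import jax
    import jax.numpy as jnp

    cache = jnp.zeros((1, 8, 4))    # (batch, max_seq_len, head_dim)
    update = jnp.ones((1, 1, 4))    # one new timestep

    @jax.jit
    def dynamic_length_read(cache, index):
        # The output shape would depend on the runtime value of `index`,
        # so JAX rejects this under jit: slice bounds must be static.
        return cache[:, :index, :]

    @jax.jit
    def fixed_shape_write(cache, update, index):
        # The output always has the same shape as `cache`, so a traced
        # start index is fine under XLA.
        return jax.lax.dynamic_update_slice(cache, update, (0, index, 0))

    fixed_shape_write(cache, update, 3)   # compiles and runs
    # dynamic_length_read(cache, 3)       # raises a non-static slice error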


keras_nlp/models/mistral/mistral_attention.py

Lines changed: 17 additions & 61 deletions
@@ -136,7 +136,6 @@ def call(
         cache_update_index=None,
         training=None,
     ):
-        seq_len = ops.shape(hidden_states)[1]
         start_index = (
             cache_update_index if cache_update_index is not None else 0
         )
@@ -170,67 +169,24 @@ def _compute_key_value(x):
             return key, value

         if cache is not None:
-            cache_k = cache[:, 0, ...]
-            cache_v = cache[:, 1, ...]
-
+            key_cache = cache[:, 0, ...]
+            value_cache = cache[:, 1, ...]
+            if cache_update_index is None:
+                key = key_cache
+                value = value_cache
+            else:
+                key_update, value_update = _compute_key_value(hidden_states)
+                start = [0, cache_update_index, 0, 0]
+                key = ops.slice_update(key_cache, start, key_update)
+                value = ops.slice_update(value_cache, start, value_update)
+                cache = ops.stack((key, value), axis=1)
+        else:
             if cache_update_index is not None:
-                # Compute the new keys and values
-                key, value = _compute_key_value(hidden_states)
-
-                # Cache is a rotating buffer, we want to warp around if
-                # the sequence length exceeds the sliding window.
-                update_end_index = (
-                    cache_update_index + seq_len - 1
-                ) % self._sliding_window + 1
-                update_end_index = ops.cast(update_end_index, "int32")
-                cache_update_index = cache_update_index % self._sliding_window
-                update_start_index = ops.cond(
-                    update_end_index > cache_update_index,
-                    lambda: ops.cast(cache_update_index, "int32"),
-                    lambda: ops.cast(0, "int32"),
-                )
-                # Also note that the update step below assumes that the
-                # sequence length is always one when `cache_update_index != 0`.
-                # This is necessary to support XLA compilation. Ideally, we
-                # would want to use
-                # `key[:, -(update_end_index - update_start_index):, ...]`
-                # as the update but updating using a dynamic slice gives an
-                # XLA compilation error in TensorFlow.
-                # Passing a sequence of length > 1 with cache update might give
-                # incorrect results (since there is no way to determine how
-                # many most recent tokens are to be saved if the tokens exceed
-                # the sliding window length).
-                cache_k = ops.slice_update(
-                    cache_k,
-                    [0, update_start_index, 0, 0],
-                    # We slice the keys and values since if the user has passed
-                    # a sequence of length > `self._sliding_window`. We want to
-                    # prefill the cache using just the most recent values in the
-                    # sliding window.
-                    ops.cast(
-                        key[:, -self._sliding_window :, ...], cache_k.dtype
-                    ),
-                )
-                cache_v = ops.slice_update(
-                    cache_v,
-                    [0, update_start_index, 0, 0],
-                    ops.cast(
-                        value[:, -self._sliding_window :, ...], cache_v.dtype
-                    ),
+                raise ValueError(
+                    "`cache_update_index` should not be set if `cache` is "
+                    f"`None`. Received: cache={cache}, "
+                    f"cache_update_index={cache_update_index}"
                 )
-                cache = ops.stack([cache_k, cache_v], axis=1)
-
-                # Get the required keys and values from the cache.
-                # Since we expect the user to pass a fixed-size cache, we just
-                # pick the first few slices up-to and including the newly computed
-                # keys and values.
-                cache_k = cache_k[:, :update_end_index, ...]
-                cache_v = cache_v[:, :update_end_index, ...]
-
-                key = ops.cast(cache_k, dtype=self.compute_dtype)
-                value = ops.cast(cache_v, dtype=self.compute_dtype)
-            else:
-                # Compute keys and values
             key, value = _compute_key_value(hidden_states)

         # [batch_shape, seq_len, num_key_value_heads, head_dim]
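The replacement above is plain fixed-size caching: the cache keeps one slot per position, and each step writes its new key/value at `cache_update_index`. A rough standalone sketch of that pattern with `keras.ops` follows (shapes and the helper name are illustrative, not part of the layer's API):

    from keras import ops

    # Illustrative shapes only.
    batch, max_len, num_kv_heads, head_dim = 1, 8, 2, 4

    # Combined cache; axis 1 holds [key, value], as in the layer above.
    cache = ops.zeros((batch, 2, max_len, num_kv_heads, head_dim))

    def update_cache(cache, key_update, value_update, cache_update_index):
        # Hypothetical helper mirroring the new code path: write this step's
        # key/value into a fixed-size buffer. The output shape never changes,
        # which keeps the op XLA-friendly.
        key_cache = cache[:, 0, ...]
        value_cache = cache[:, 1, ...]
        start = [0, cache_update_index, 0, 0]
        key = ops.slice_update(key_cache, start, key_update)
        value = ops.slice_update(value_cache, start, value_update)
        return ops.stack((key, value), axis=1)

    # One decode step: write a single token's key/value at position 3.
    key_step = ops.ones((batch, 1, num_kv_heads, head_dim))
    value_step = ops.ones((batch, 1, num_kv_heads, head_dim))
    cache = update_cache(cache, key_step, value_step, 3)

The trade-off, as the commit message notes, is that the cache now spans the full sequence length instead of being capped at the sliding window.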
@@ -260,7 +216,7 @@ def _masked_softmax(self, attention_scores, attention_mask=None):
         return self._softmax(attention_scores)

     def _compute_attention(self, query, key, value, attention_mask=None):
-        attention_scores = ops.einsum(self._dot_product_equation, key, query)
+        attention_scores = ops.einsum(self._dot_product_equation, query, key)

         norm_factor = ops.sqrt(ops.cast(self._head_dim, self.compute_dtype))
