Skip to content

Commit e55a404

Browse files
authored
Use broadcasting trick for KV update (apple#972)
* Use vmap and dynamic_update_slice for KV update
* Broadcasting trick
* Simplify the impl per @markblee's suggestion
* Comments
1 parent b955187 commit e55a404

File tree

1 file changed

+17
-17
lines changed

1 file changed

+17
-17
lines changed

axlearn/common/attention.py

Lines changed: 17 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -808,7 +808,7 @@ def extend_step(
808808
Args:
809809
cached_states: A `NestedTensor` object containing tensors which are the results of
810810
previous attentions, and index used for fast decoding. Contains "key" and "value" of
811-
shape [batch, num_heads, per_head_dim, target_length], and a Tensor "time_step" of
811+
shape [batch, source_length, num_heads, per_head_dim], and a Tensor "time_step" of
812812
shape [batch].
813813
query: Tensor of shape [batch, steps, target_dim] corresponding to query vector starting
814814
at "time_step" indices.
@@ -842,25 +842,25 @@ def extend_step(
842842
q_proj, k_proj, v_proj = self.forward(query, **kv_kwargs, query_positions=query_positions)
843843
updated_state = dict(time_step=time_step + num_query_steps)
844844
if kv_state is None:
845-
# Update the cache via dynamic slice. [B, S, N, H].
845+
# Update the cache via one-hot broadcast and addition.
846+
# NB: Cache updates can also be done via dynamic slice update. However it was observed
847+
# that RLHF training got stuck in some cases.
848+
# TODO(ds-hwang): Investigate the root cause.
846849
cached_key = cached_states["key"]
847850
cached_value = cached_states["value"]
848851

849-
# Ensure that we accumulate using the original dtype.
850-
k_proj = k_proj.astype(cached_key.dtype)
851-
v_proj = v_proj.astype(cached_value.dtype)
852-
853-
# TODO(dhwang2): jax.lax.dynamic_update_slice_in_dim is generally faster than advanced
854-
# indexing, but an unusual slowdown was observed, with RLHF sampling taking up to
855-
# 3 hours per run. Investigate and fix it.
856-
# Note: All X_idx are small, so generating them on-demand is not costly.
857-
b, _, n, h = cached_key.shape
858-
b_idx = jnp.arange(b)[:, None, None, None]
859-
t_idx = (jnp.arange(k_proj.shape[1])[None] + time_step[:, None])[:, :, None, None]
860-
n_idx = jnp.arange(n)[None, None, :, None]
861-
h_idx = jnp.arange(h)[None, None, None, :]
862-
k_proj = cached_key.at[b_idx, t_idx, n_idx, h_idx].set(k_proj)
863-
v_proj = cached_value.at[b_idx, t_idx, n_idx, h_idx].set(v_proj)
852+
source_len = cached_key.shape[1]
853+
854+
# Create a dispatch matrix of shape [B, T=step, S].
855+
oh_indices = jax.nn.one_hot(
856+
time_step[:, None] + jnp.arange(num_query_steps), source_len, dtype=k_proj.dtype
857+
)
858+
# Create a mask of shape [B, S, 1, 1].
859+
negated_oh_indices = (1 - oh_indices.sum(axis=1))[..., None, None]
860+
k_proj = jnp.einsum("bt...,bts->bs...", k_proj, oh_indices)
861+
v_proj = jnp.einsum("bt...,bts->bs...", v_proj, oh_indices)
862+
k_proj = cached_key * negated_oh_indices + k_proj
863+
v_proj = cached_value * negated_oh_indices + v_proj
864864

865865
updated_state.update(key=k_proj, value=v_proj)
866866
return updated_state, self.Output(query=q_proj, key=k_proj, value=v_proj)

0 commit comments

Comments (0)