Skip to content

Commit b130416

Browse files
authored
Transpose kv cache for better decode performance (apple#979)
1 parent 48bf488 commit b130416

File tree

2 files changed

+28
-12
lines changed

2 files changed

+28
-12
lines changed

axlearn/common/attention.py

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -741,13 +741,17 @@ def init_states(
741741
batch, max_len = key.shape[:2]
742742
chex.assert_equal_shape((key, value))
743743

744+
# NB: key and value in init_state are transposed so that source_length is in the last
745+
# dimension as a TPU fusion optimization.
746+
# Reference:
747+
# https://github.com/google-research/t5x/blob/4d94d8bf41230d492e15e255c9888b5bfd9a5ee8/t5x/examples/t5/layers.py#L215
744748
init_state.update(
745749
key=jnp.zeros(
746-
shape=(batch, max_len, self.num_kv_heads, cfg.per_head_dim),
750+
shape=(batch, self.num_kv_heads, cfg.per_head_dim, max_len),
747751
dtype=dtype,
748752
),
749753
value=jnp.zeros(
750-
shape=(batch, max_len, self.num_kv_heads, cfg.per_head_dim),
754+
shape=(batch, self.num_kv_heads, cfg.per_head_dim, max_len),
751755
dtype=dtype,
752756
),
753757
)
@@ -808,7 +812,7 @@ def extend_step(
808812
Args:
809813
cached_states: A `NestedTensor` object containing tensors which are the results of
810814
previous attentions, and index used for fast decoding. Contains "key" and "value" of
811-
shape [batch, source_length, num_heads, per_head_dim], and a Tensor "time_step" of
815+
shape [batch, num_kv_heads, per_head_dim, source_length], and a Tensor "time_step" of
812816
shape [batch].
813817
query: Tensor of shape [batch, steps, target_dim] corresponding to query vector starting
814818
at "time_step" indices.
@@ -841,6 +845,7 @@ def extend_step(
841845
# Project inputs to key, value and query. Each has shape [B, steps, N, H].
842846
q_proj, k_proj, v_proj = self.forward(query, **kv_kwargs, query_positions=query_positions)
843847
updated_state = dict(time_step=time_step + num_query_steps)
848+
844849
if kv_state is None:
845850
# Update the cache via one-hot broadcast and addition.
846851
# NB: Cache updates can also be done via dynamic slice update. However it was observed
@@ -849,21 +854,32 @@ def extend_step(
849854
cached_key = cached_states["key"]
850855
cached_value = cached_states["value"]
851856

852-
source_len = cached_key.shape[1]
857+
source_len = cached_key.shape[-1]
858+
859+
# [B, T, N, H] --> [B, N, H, T].
860+
k_proj = jnp.einsum("btnh->bnht", k_proj)
861+
v_proj = jnp.einsum("btnh->bnht", v_proj)
853862

854863
# Create a dispatch matrix of shape [B, T=step, S].
855864
oh_indices = jax.nn.one_hot(
856865
time_step[:, None] + jnp.arange(num_query_steps), source_len, dtype=cached_key.dtype
857866
)
858-
# Create a mask of shape [B, S, 1, 1].
859-
negated_oh_indices = (1 - oh_indices.sum(axis=1))[..., None, None]
860-
k_proj = jnp.einsum("bt...,bts->bs...", k_proj, oh_indices)
861-
v_proj = jnp.einsum("bt...,bts->bs...", v_proj, oh_indices)
867+
# Create a mask of shape [B, 1, 1, S].
868+
negated_oh_indices = (1 - oh_indices.sum(axis=1))[:, None, None, :]
869+
870+
k_proj = jnp.einsum("b...t,bts->b...s", k_proj, oh_indices)
871+
v_proj = jnp.einsum("b...t,bts->b...s", v_proj, oh_indices)
872+
862873
# Ensure that we accumulate using the original dtype.
863-
k_proj = cached_key * negated_oh_indices + k_proj.astype(cached_key.dtype)
864-
v_proj = cached_value * negated_oh_indices + v_proj.astype(cached_value.dtype)
874+
cached_key = cached_key * negated_oh_indices + k_proj.astype(cached_key.dtype)
875+
cached_value = cached_value * negated_oh_indices + v_proj.astype(cached_value.dtype)
876+
877+
updated_state.update(key=cached_key, value=cached_value)
878+
879+
# [B, N, H, S] --> [B, S, N, H].
880+
k_proj = jnp.einsum("bnhs->bsnh", cached_key)
881+
v_proj = jnp.einsum("bnhs->bsnh", cached_value)
865882

866-
updated_state.update(key=k_proj, value=v_proj)
867883
return updated_state, self.Output(query=q_proj, key=k_proj, value=v_proj)
868884

869885

axlearn/common/attention_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2769,7 +2769,7 @@ def _test_prefill_states(
27692769
self.assertTrue(jnp.all(time_step == initial_states["i_proj"]["time_step"]))
27702770
for proj in ["key", "value"]:
27712771
self.assertEqual(
2772-
(batch_size, tgt_len, num_kv_heads or num_heads, model_dim // num_heads),
2772+
(batch_size, num_kv_heads or num_heads, model_dim // num_heads, tgt_len),
27732773
initial_states["i_proj"][proj].shape,
27742774
)
27752775
self.assertEqual(

0 commit comments

Comments
 (0)