@@ -741,13 +741,17 @@ def init_states(
         batch, max_len = key.shape[:2]
         chex.assert_equal_shape((key, value))
 
+        # NB: key and value in init_state are transposed so that source_length is in the last
+        # dimension as a TPU fusion optimization.
+        # Reference:
+        # https://github.com/google-research/t5x/blob/4d94d8bf41230d492e15e255c9888b5bfd9a5ee8/t5x/examples/t5/layers.py#L215
         init_state.update(
             key=jnp.zeros(
-                shape=(batch, max_len, self.num_kv_heads, cfg.per_head_dim),
+                shape=(batch, self.num_kv_heads, cfg.per_head_dim, max_len),
                 dtype=dtype,
             ),
             value=jnp.zeros(
-                shape=(batch, max_len, self.num_kv_heads, cfg.per_head_dim),
+                shape=(batch, self.num_kv_heads, cfg.per_head_dim, max_len),
                 dtype=dtype,
             ),
         )
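As a rough illustration of the layout change above (this is not code from the PR; all sizes and names are placeholders), the new cache simply moves source_length to the trailing dimension, and converting between the old and new layouts is a pure transpose:

import jax.numpy as jnp

# Illustrative sizes only; the real values come from the layer config.
batch, max_len, num_kv_heads, per_head_dim = 2, 16, 4, 8

# Old layout: source_length (time) in dimension 1.
old_key = jnp.zeros((batch, max_len, num_kv_heads, per_head_dim))

# New layout: source_length moved to the last dimension, as in the diff above.
new_key = jnp.zeros((batch, num_kv_heads, per_head_dim, max_len))

# The two layouts differ only by a transpose.
assert jnp.einsum("btnh->bnht", old_key).shape == new_key.shape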
@@ -808,7 +812,7 @@ def extend_step(
         Args:
             cached_states: A `NestedTensor` object containing tensors which are the results of
                 previous attentions, and index used for fast decoding. Contains "key" and "value" of
-                shape [batch, source_length, num_heads, per_head_dim], and a Tensor "time_step" of
+                shape [batch, num_heads, per_head_dim, source_length], and a Tensor "time_step" of
                 shape [batch].
             query: Tensor of shape [batch, steps, target_dim] corresponding to query vector starting
                 at "time_step" indices.
@@ -841,6 +845,7 @@ def extend_step(
         # Project inputs to key, value and query. Each has shape [B, steps, N, H].
         q_proj, k_proj, v_proj = self.forward(query, **kv_kwargs, query_positions=query_positions)
         updated_state = dict(time_step=time_step + num_query_steps)
+
         if kv_state is None:
             # Update the cache via one-hot broadcast and addition.
             # NB: Cache updates can also be done via dynamic slice update. However it was observed
@@ -849,21 +854,32 @@ def extend_step(
             cached_key = cached_states["key"]
             cached_value = cached_states["value"]
 
-            source_len = cached_key.shape[1]
+            source_len = cached_key.shape[-1]
+
+            # [B, T, N, H] --> [B, N, H, T].
+            k_proj = jnp.einsum("btnh->bnht", k_proj)
+            v_proj = jnp.einsum("btnh->bnht", v_proj)
 
             # Create a dispatch matrix of shape [B, T=step, S].
             oh_indices = jax.nn.one_hot(
                 time_step[:, None] + jnp.arange(num_query_steps), source_len, dtype=cached_key.dtype
             )
-            # Create a mask of shape [B, S, 1, 1].
-            negated_oh_indices = (1 - oh_indices.sum(axis=1))[..., None, None]
-            k_proj = jnp.einsum("bt...,bts->bs...", k_proj, oh_indices)
-            v_proj = jnp.einsum("bt...,bts->bs...", v_proj, oh_indices)
+            # Create a mask of shape [B, 1, 1, S].
+            negated_oh_indices = (1 - oh_indices.sum(axis=1))[:, None, None, :]
+
+            k_proj = jnp.einsum("b...t,bts->b...s", k_proj, oh_indices)
+            v_proj = jnp.einsum("b...t,bts->b...s", v_proj, oh_indices)
+
             # Ensure that we accumulate using the original dtype.
-            k_proj = cached_key * negated_oh_indices + k_proj.astype(cached_key.dtype)
-            v_proj = cached_value * negated_oh_indices + v_proj.astype(cached_value.dtype)
+            cached_key = cached_key * negated_oh_indices + k_proj.astype(cached_key.dtype)
+            cached_value = cached_value * negated_oh_indices + v_proj.astype(cached_value.dtype)
+
+            updated_state.update(key=cached_key, value=cached_value)
+
+            # [B, S, N, H]
+            k_proj = jnp.einsum("bnhs->bsnh", cached_key)
+            v_proj = jnp.einsum("bnhs->bsnh", cached_value)
 
-            updated_state.update(key=k_proj, value=v_proj)
         return updated_state, self.Output(query=q_proj, key=k_proj, value=v_proj)
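To make the shape bookkeeping in this hunk easier to follow, here is a minimal, self-contained sketch of the one-hot dispatch update in the transposed [B, N, H, S] layout. It is not code from the PR; all sizes and variable names are illustrative placeholders.

import jax
import jax.numpy as jnp

# Illustrative sizes only.
batch, num_steps, num_heads, per_head_dim, source_len = 2, 3, 4, 8, 16

# Cache in the transposed [B, N, H, S] layout; new projections arrive as [B, T, N, H].
cached_key = jnp.zeros((batch, num_heads, per_head_dim, source_len))
k_proj = jnp.ones((batch, num_steps, num_heads, per_head_dim))
time_step = jnp.array([0, 5])  # Per-example decode position.

# Move time to the last dimension: [B, T, N, H] -> [B, N, H, T].
k_proj = jnp.einsum("btnh->bnht", k_proj)

# Dispatch matrix of shape [B, T, S]: one-hot time slots being written this step.
oh_indices = jax.nn.one_hot(
    time_step[:, None] + jnp.arange(num_steps), source_len, dtype=cached_key.dtype
)
# Mask of shape [B, 1, 1, S] that zeroes out the slots about to be overwritten.
negated_oh_indices = (1 - oh_indices.sum(axis=1))[:, None, None, :]

# Scatter the new keys into their time slots: [B, N, H, T] x [B, T, S] -> [B, N, H, S].
k_update = jnp.einsum("b...t,bts->b...s", k_proj, oh_indices)
cached_key = cached_key * negated_oh_indices + k_update.astype(cached_key.dtype)

# Attention-facing view back in [B, S, N, H].
k_for_attention = jnp.einsum("bnhs->bsnh", cached_key)
print(k_for_attention.shape)  # (2, 16, 4, 8)

This keeps the cache write as a dense multiply-add over the trailing source_length dimension; the hunk's NB comment indicates a dynamic slice update was considered as the alternative.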