@@ -741,13 +741,17 @@ def init_states(
         batch, max_len = key.shape[:2]
         chex.assert_equal_shape((key, value))

+        # NB: key and value in init_state are transposed so that source_length is in the last
+        # dimension as a TPU fusion optimization.
+        # Reference:
+        # https://github.com/google-research/t5x/blob/4d94d8bf41230d492e15e255c9888b5bfd9a5ee8/t5x/examples/t5/layers.py#L215
         init_state.update(
             key=jnp.zeros(
-                shape=(batch, max_len, self.num_kv_heads, cfg.per_head_dim),
+                shape=(batch, self.num_kv_heads, cfg.per_head_dim, max_len),
                 dtype=dtype,
             ),
             value=jnp.zeros(
-                shape=(batch, max_len, self.num_kv_heads, cfg.per_head_dim),
+                shape=(batch, self.num_kv_heads, cfg.per_head_dim, max_len),
                 dtype=dtype,
             ),
         )
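For intuition, here is a minimal sketch of the two cache layouts (not from the patch; sizes are made up, and `num_kv_heads`/`per_head_dim` stand in for the config values above). The change only permutes axes so that `source_length` lands last, matching the referenced T5X layout; total storage is unchanged.

```python
import jax.numpy as jnp

# Hypothetical sizes for illustration only.
batch, max_len, num_kv_heads, per_head_dim = 2, 16, 4, 8

# Old layout: [batch, source_length, num_kv_heads, per_head_dim].
cache_old = jnp.zeros((batch, max_len, num_kv_heads, per_head_dim))
# New layout: [batch, num_kv_heads, per_head_dim, source_length], so per-step
# decode updates write along the last (innermost) axis.
cache_new = jnp.zeros((batch, num_kv_heads, per_head_dim, max_len))

assert cache_old.size == cache_new.size  # Same storage, different axis order.
```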
@@ -808,7 +812,7 @@ def extend_step(
         Args:
             cached_states: A `NestedTensor` object containing tensors which are the results of
                 previous attentions, and index used for fast decoding. Contains "key" and "value" of
-                shape [batch, source_length, num_heads, per_head_dim], and a Tensor "time_step" of
+                shape [batch, num_heads, per_head_dim, source_length], and a Tensor "time_step" of
                 shape [batch].
             query: Tensor of shape [batch, steps, target_dim] corresponding to query vector starting
                 at "time_step" indices.
@@ -841,6 +845,7 @@ def extend_step(
         # Project inputs to key, value and query. Each has shape [B, steps, N, H].
         q_proj, k_proj, v_proj = self.forward(query, **kv_kwargs, query_positions=query_positions)
         updated_state = dict(time_step=time_step + num_query_steps)
+
         if kv_state is None:
             # Update the cache via one-hot broadcast and addition.
             # NB: Cache updates can also be done via dynamic slice update. However it was observed
@@ -849,21 +854,32 @@ def extend_step(
             cached_key = cached_states["key"]
             cached_value = cached_states["value"]

-            source_len = cached_key.shape[1]
+            source_len = cached_key.shape[-1]
+
+            # [B, T, N, H] --> [B, N, H, T].
+            k_proj = jnp.einsum("btnh->bnht", k_proj)
+            v_proj = jnp.einsum("btnh->bnht", v_proj)

             # Create a dispatch matrix of shape [B, T=step, S].
             oh_indices = jax.nn.one_hot(
                 time_step[:, None] + jnp.arange(num_query_steps), source_len, dtype=cached_key.dtype
             )
-            # Create a mask of shape [B, S, 1, 1].
-            negated_oh_indices = (1 - oh_indices.sum(axis=1))[..., None, None]
-            k_proj = jnp.einsum("bt...,bts->bs...", k_proj, oh_indices)
-            v_proj = jnp.einsum("bt...,bts->bs...", v_proj, oh_indices)
+            # Create a mask of shape [B, 1, 1, S].
+            negated_oh_indices = (1 - oh_indices.sum(axis=1))[:, None, None, :]
+
+            k_proj = jnp.einsum("b...t,bts->b...s", k_proj, oh_indices)
+            v_proj = jnp.einsum("b...t,bts->b...s", v_proj, oh_indices)
+
             # Ensure that we accumulate using the original dtype.
-            k_proj = cached_key * negated_oh_indices + k_proj.astype(cached_key.dtype)
-            v_proj = cached_value * negated_oh_indices + v_proj.astype(cached_value.dtype)
+            cached_key = cached_key * negated_oh_indices + k_proj.astype(cached_key.dtype)
+            cached_value = cached_value * negated_oh_indices + v_proj.astype(cached_value.dtype)
+
+            updated_state.update(key=cached_key, value=cached_value)
+
+            # [B, S, N, H]
+            k_proj = jnp.einsum("bnhs->bsnh", cached_key)
+            v_proj = jnp.einsum("bnhs->bsnh", cached_value)

-            updated_state.update(key=k_proj, value=v_proj)
             return updated_state, self.Output(query=q_proj, key=k_proj, value=v_proj)
869885
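To check that the one-hot dispatch above really scatters new entries into the cache at `time_step`, here is a self-contained sketch with made-up sizes. It compares against `jax.lax.dynamic_update_slice_in_dim`, the alternative named in the NB comment; only the einsum equations and the mask construction are taken from the patch, everything else is illustrative.

```python
import jax
import jax.numpy as jnp

# Hypothetical sizes for illustration only.
batch, num_heads, per_head_dim, source_len = 2, 4, 8, 16
num_query_steps = 3

cached_key = jnp.zeros((batch, num_heads, per_head_dim, source_len))
# New keys for this step, already transposed to [B, N, H, T] as in the patch.
k_proj = jax.random.normal(
    jax.random.PRNGKey(0), (batch, num_heads, per_head_dim, num_query_steps)
)
time_step = jnp.array([5, 0])  # Per-example decode positions.

# Dispatch matrix [B, T, S]: row t is one-hot at position time_step + t.
oh_indices = jax.nn.one_hot(
    time_step[:, None] + jnp.arange(num_query_steps), source_len, dtype=cached_key.dtype
)
# Mask [B, 1, 1, S] that zeroes exactly the cache slots being overwritten.
negated_oh_indices = (1 - oh_indices.sum(axis=1))[:, None, None, :]
# Scatter the new keys to their slots; all other positions are zero.
scattered = jnp.einsum("b...t,bts->b...s", k_proj, oh_indices)
updated = cached_key * negated_oh_indices + scattered

# Reference result via the dynamic slice update mentioned in the NB comment.
# Axis 2 is the source_len axis once batch is vmapped out.
ref = jax.vmap(
    lambda c, k, t: jax.lax.dynamic_update_slice_in_dim(c, k, t, axis=2)
)(cached_key, k_proj, time_step)
assert jnp.allclose(updated, ref)
```

Both forms write the same cache contents; the patch prefers the one-hot broadcast-and-add (see the NB comment above).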