@@ -808,7 +808,7 @@ def extend_step(
         Args:
             cached_states: A `NestedTensor` object containing tensors which are the results of
                 previous attentions, and index used for fast decoding. Contains "key" and "value" of
-                shape [batch, num_heads, per_head_dim, target_length], and a Tensor "time_step" of
+                shape [batch, source_length, num_heads, per_head_dim], and a Tensor "time_step" of
                 shape [batch].
             query: Tensor of shape [batch, steps, target_dim] corresponding to query vector starting
                 at "time_step" indices.
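
For context, a minimal sketch of the cache layout the corrected docstring describes. The sizes and dtype below are illustrative assumptions, not values from this change:

```python
import jax.numpy as jnp

# Hypothetical sizes for illustration only.
batch, source_length, num_heads, per_head_dim = 2, 16, 4, 8
cached_states = dict(
    # "key"/"value" are laid out as [batch, source_length, num_heads, per_head_dim].
    key=jnp.zeros((batch, source_length, num_heads, per_head_dim), dtype=jnp.bfloat16),
    value=jnp.zeros((batch, source_length, num_heads, per_head_dim), dtype=jnp.bfloat16),
    # "time_step" tracks the next write position per batch element.
    time_step=jnp.zeros((batch,), dtype=jnp.int32),
)
```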
@@ -842,25 +842,25 @@ def extend_step(
         q_proj, k_proj, v_proj = self.forward(query, **kv_kwargs, query_positions=query_positions)
         updated_state = dict(time_step=time_step + num_query_steps)
         if kv_state is None:
-            # Update the cache via dynamic slice. [B, S, N, H].
+            # Update the cache via one-hot broadcast and addition.
+            # NB: Cache updates can also be done via dynamic slice update. However, it was observed
+            # that RLHF training got stuck in some cases.
+            # TODO(ds-hwang): Investigate the root cause.
             cached_key = cached_states["key"]
             cached_value = cached_states["value"]

-            # Ensure that we accumulate using the original dtype.
-            k_proj = k_proj.astype(cached_key.dtype)
-            v_proj = v_proj.astype(cached_value.dtype)
-
-            # TODO(dhwang2): jax.lax.dynamic_update_slice_in_dim is generally faster than advanced
-            # indexing, but an unusual slowdown was observed, with RLHF sampling taking up to
-            # 3 hours per run. Investigate and fix it.
-            # Note: All X_idx are small, so generating them on-demand is not costly.
-            b, _, n, h = cached_key.shape
-            b_idx = jnp.arange(b)[:, None, None, None]
-            t_idx = (jnp.arange(k_proj.shape[1])[None] + time_step[:, None])[:, :, None, None]
-            n_idx = jnp.arange(n)[None, None, :, None]
-            h_idx = jnp.arange(h)[None, None, None, :]
-            k_proj = cached_key.at[b_idx, t_idx, n_idx, h_idx].set(k_proj)
-            v_proj = cached_value.at[b_idx, t_idx, n_idx, h_idx].set(v_proj)
+            source_len = cached_key.shape[1]
+
+            # Create a dispatch matrix of shape [B, T=step, S].
+            oh_indices = jax.nn.one_hot(
+                time_step[:, None] + jnp.arange(num_query_steps), source_len, dtype=k_proj.dtype
+            )
+            # Create a mask of shape [B, S, 1, 1].
+            negated_oh_indices = (1 - oh_indices.sum(axis=1))[..., None, None]
+            k_proj = jnp.einsum("bt...,bts->bs...", k_proj, oh_indices)
+            v_proj = jnp.einsum("bt...,bts->bs...", v_proj, oh_indices)
+            k_proj = cached_key * negated_oh_indices + k_proj
+            v_proj = cached_value * negated_oh_indices + v_proj

         updated_state.update(key=k_proj, value=v_proj)
         return updated_state, self.Output(query=q_proj, key=k_proj, value=v_proj)
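
To make the new update path concrete, here is a self-contained sketch of the same one-hot dispatch pattern, checked against a per-example `jax.lax.dynamic_update_slice`. The helper name `one_hot_cache_update` and all sizes are hypothetical, not part of this change:

```python
import jax
import jax.numpy as jnp

def one_hot_cache_update(cached, proj, time_step):
    """Hypothetical helper: scatter proj [B, T, N, H] into cached [B, S, N, H] at time_step [B]."""
    source_len = cached.shape[1]
    num_steps = proj.shape[1]
    # Dispatch matrix [B, T, S]: row t is one-hot at position time_step + t.
    oh_indices = jax.nn.one_hot(
        time_step[:, None] + jnp.arange(num_steps), source_len, dtype=proj.dtype
    )
    # Mask [B, S, 1, 1]: zero where the new steps land, one elsewhere.
    negated = (1 - oh_indices.sum(axis=1))[..., None, None]
    # Route each new step to its source position, then merge with the kept cache entries.
    dispatched = jnp.einsum("bt...,bts->bs...", proj, oh_indices)
    return cached * negated + dispatched

# Equivalence check against a dynamic slice update (the previous approach).
b, s, t, n, h = 2, 8, 3, 4, 5
cached = jax.random.normal(jax.random.PRNGKey(0), (b, s, n, h))
proj = jax.random.normal(jax.random.PRNGKey(1), (b, t, n, h))
time_step = jnp.array([0, 4])

got = one_hot_cache_update(cached, proj, time_step)
want = jax.vmap(
    lambda c, p, ts: jax.lax.dynamic_update_slice(c, p, (ts, 0, 0))
)(cached, proj, time_step)
assert jnp.allclose(got, want)
```

Both paths produce the same cache for in-bounds time steps; the one-hot form trades extra FLOPs for a pure matmul-and-add that avoids the scatter which appeared to stall RLHF runs.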