@@ -4183,7 +4183,10 @@ def _update_layer_kwargs(
         layer_kwargs: dict[str, Any],
         *,
         all_layer_outputs: list[BaseTransformerLayer.Output],
+        external_self_attention_kv_state: Optional[KVState] = None,
     ):
+        del external_self_attention_kv_state
+
         layer_index = len(all_layer_outputs)
         if layer_index == 1:
             layer_kwargs["self_attention_kv_state"] = all_layer_outputs[-1].self_attention_kv_state
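The hook above keeps its existing behavior by immediately `del`-ing the new keyword; the argument is added presumably so the stack can thread an externally supplied KV state to overrides that want it. A subclass that actually consumes it might look like the sketch below (the class name `ExternalKVStack` and its wiring are illustrative assumptions, not part of this change):

```python
from typing import Any, Optional


class ExternalKVStack(StackedTransformerLayer):
    """Hypothetical stack that feeds an external KV state to every layer."""

    def _update_layer_kwargs(
        self,
        layer_kwargs: dict[str, Any],
        *,
        all_layer_outputs: list[BaseTransformerLayer.Output],
        external_self_attention_kv_state: Optional[KVState] = None,
    ):
        del all_layer_outputs
        # Instead of discarding the external KV state, route it into the
        # per-layer kwargs so each layer attends over the same keys/values.
        if external_self_attention_kv_state is not None:
            layer_kwargs["self_attention_kv_state"] = external_self_attention_kv_state
```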
@@ -4586,6 +4589,51 @@ def test_skip_connection(self):
             0.0,
         )
 
+    def test_passthrough_update_layer_kwargs(self):
+        num_heads = 2
+        input_dim = 4
+        hidden_dim = 8
+        num_layers = 3
+
+        cfg = StackedTransformerLayer.default_config().set(name="test")
+        cfg.input_dim = input_dim
+        cfg.num_layers = num_layers
+
+        transformer_cfg = TransformerLayer.default_config()
+        transformer_cfg.self_attention.attention.num_heads = num_heads
+        transformer_cfg.feed_forward.hidden_dim = hidden_dim
+        cfg.layer = transformer_cfg
+
+        layer: StackedTransformerLayer = cfg.instantiate(parent=None)
+        state = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(123))
+
+        input_all_layer_outputs = [BaseTransformerLayer.Output(data=jnp.ones([2, 3]))]
+        expected_all_layer_outputs = [BaseTransformerLayer.Output(data=jnp.ones([2, 3]))]
+        k_proj = jnp.zeros([3, 3])
+        v_proj = jnp.ones([3, 3])
+        input_self_attention_kv_state = KVState(k_proj=k_proj, v_proj=v_proj)
+        expected_self_attention_kv_state = KVState(k_proj=k_proj, v_proj=v_proj)
+        F(
+            layer,
+            prng_key=jax.random.PRNGKey(0),
+            state=state,
+            inputs=dict(
+                layer_kwargs={},
+                all_layer_outputs=input_all_layer_outputs,
+                external_self_attention_kv_state=input_self_attention_kv_state,
+            ),
+            method="_update_layer_kwargs",
+            is_training=True,
+        )
+        self.assertNestedAllClose(
+            input_all_layer_outputs,
+            expected_all_layer_outputs,
+        )
+        self.assertNestedAllClose(
+            input_self_attention_kv_state,
+            expected_self_attention_kv_state,
+        )
+
     def test_update_layer_kwargs(self):
         batch_size = 2
         seq_len = 6
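Outside the AXLearn test harness, the property the new test exercises reduces to a small standalone check: a call to `_update_layer_kwargs` that ignores the external KV state must not mutate or inject anything into the inputs it receives. A minimal sketch under that assumption (`KVStateLike` and `noop_update` are illustrative names, not from this change):

```python
from typing import Any, NamedTuple, Optional

import jax.numpy as jnp
import numpy as np


class KVStateLike(NamedTuple):
    """Stand-in for KVState: a pair of projected key/value tensors."""

    k_proj: jnp.ndarray
    v_proj: jnp.ndarray


def noop_update(
    layer_kwargs: dict[str, Any],
    *,
    all_layer_outputs: list,
    external_self_attention_kv_state: Optional[KVStateLike] = None,
):
    # Mirrors a passthrough hook: accept the external KV state, then ignore it.
    del all_layer_outputs, external_self_attention_kv_state


kv = KVStateLike(k_proj=jnp.zeros([3, 3]), v_proj=jnp.ones([3, 3]))
layer_kwargs: dict[str, Any] = {}
noop_update(layer_kwargs, all_layer_outputs=[], external_self_attention_kv_state=kv)

# Passthrough means nothing was mutated or injected.
assert layer_kwargs == {}
np.testing.assert_allclose(kv.k_proj, jnp.zeros([3, 3]))
np.testing.assert_allclose(kv.v_proj, jnp.ones([3, 3]))
```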