     Nested,
     PartitionSpec,
     Tensor,
+    TensorSpec,
     VDict,
     as_tensor,
     flatten_items,
@@ -1472,7 +1473,10 @@ def test_repeated_extend_step(self, layer_cls: type[attention.BaseQKVLinear], ex
             inputs=dict(query=query),
         )

-        cache_state = layer.init_states(target_batch_size=batch_size, target_max_len=tgt_len)
+        cache_state, init_output = layer.init_states(
+            time_step=None, query=TensorSpec([batch_size, tgt_len])
+        )
+        self.assertIsNone(init_output)
         step_querys = []
         step_keys = step_values = None
         for t in range(0, tgt_len, extend_step_len):
@@ -1531,18 +1535,19 @@ def __init__(self, cfg: Config, *, parent: Module):
         qkv_linear = parent.qkv_linear
         state = qkv_linear.initialize_parameters_recursively(jax.random.PRNGKey(0))

-        # Check dtypes from init_states
-        cache, _ = F(
+        # Check dtypes from init_states.
+        (cache, init_output), _ = F(
             qkv_linear,
             prng_key=jax.random.PRNGKey(0),
             state=state,
             inputs=dict(
-                target_batch_size=target_batch_size,
-                target_max_len=target_max_len,
+                time_step=None,
+                query=TensorSpec([target_batch_size, target_max_len]),
             ),
             method="init_states",
             is_training=False,
         )
+        self.assertIsNone(init_output)
         self.assertEqual(cache["key"].dtype, dtype)
         self.assertEqual(cache["value"].dtype, dtype)

@@ -1562,7 +1567,7 @@ def __init__(self, cfg: Config, *, parent: Module):
             prng_key=jax.random.PRNGKey(0),
             state=state,
             inputs=dict(time_step=time_step, query=query),
-            method="prefill_states",
+            method="init_states",
             is_training=False,
         )
         self.assertEqual(init_state["key"].dtype, dtype)
@@ -2448,9 +2453,14 @@ def _test_extend_step(
             inputs=inputs,
         )

-        initial_state = layer.init_states(
-            target_batch_size=batch_size, target_max_len=tgt_len, kv_state=kv_state
+        initial_state, initial_output = layer.init_states(
+            time_step=None,
+            query=TensorSpec([batch_size, tgt_len]),
+            kv_state=kv_state,
+            # This is unused for initializing state from scratch.
+            attention_logit_biases=None,
         )
+        self.assertIsNone(initial_output)
         if kv_state is None:
             for k in ["key", "value"]:
                 # Check that the cache dtype is inferred as the layer dtype.
@@ -2619,7 +2629,7 @@ def _test_prefill_states(
                 attention_logit_biases=attention_logit_biases,
                 return_aux=return_aux,
             ),
-            method="prefill_states",
+            method="init_states",
         )

         # Check time_step and shapes of state.
@@ -3227,6 +3237,96 @@ def test_multihead_attention_xl(self):
         )


+class TransformerAttentionLayerTest(TestCase):
+    @parameterized.parameters([False, True])
+    def test_forward_vs_extend_step(self, with_source: bool):
+        init_prng, target_prng, source_prng = jax.random.split(jax.random.PRNGKey(0), 3)
+
+        model_dim = 8
+        layer_kwargs = dict(target_dim=model_dim, source_dim=model_dim)
+        cfg: TransformerAttentionLayer.Config = TransformerAttentionLayer.default_config().set(
+            **layer_kwargs
+        )
+        cfg.attention.set(num_heads=2, mask=causal_mask)
+        layer: TransformerAttentionLayer = cfg.set(name="test").instantiate(parent=None)
+        layer_params = layer.initialize_parameters_recursively(prng_key=init_prng)
+
+        batch, decode_len = 2, 6
+        target = jax.random.uniform(target_prng, shape=[batch, decode_len, model_dim])
+        input_kwargs = {}
+
+        if with_source:
+            input_kwargs.update(
+                source=jax.random.uniform(source_prng, shape=[batch, decode_len, model_dim])
+            )
+
+        forward_outputs, _ = F(
+            layer,
+            inputs=dict(target=jnp.asarray(target), **input_kwargs),
+            state=layer_params,
+            is_training=True,
+            prng_key=jax.random.PRNGKey(0),
+        )
+
+        for start_time_step in (-1, 0, 2, decode_len):
+            if start_time_step < 0:
+                (cached_states, init_outputs), _ = F(
+                    layer,
+                    inputs=dict(
+                        time_step=None,
+                        target=TensorSpec(target.shape, target.dtype),
+                        **input_kwargs,
+                    ),
+                    state=layer_params,
+                    is_training=True,
+                    prng_key=jax.random.PRNGKey(0),
+                    method="init_states",
+                )
+                self.assertIsNone(init_outputs)
+                data = jnp.zeros([batch, decode_len, model_dim])
+                start_time_step = 0
+            else:
+                (cached_states, prefill_outputs), _ = F(
+                    layer,
+                    inputs=dict(
+                        time_step=jnp.array([start_time_step] * batch, dtype=jnp.int32),
+                        target=target,
+                        **input_kwargs,
+                    ),
+                    state=layer_params,
+                    is_training=True,
+                    prng_key=jax.random.PRNGKey(0),
+                    method="init_states",
+                )
+                data = prefill_outputs.data
+
+            data = jnp.einsum("btd->tbd", data)
+
+            for time_step in range(start_time_step, decode_len):
+                extend_kwargs = {}
+                for k, v in input_kwargs.items():
+                    extend_kwargs[k] = jnp.asarray(v[:, time_step : time_step + 1, :])
+
+                (cached_states, extend_outputs), _ = F(
+                    layer,
+                    inputs=dict(
+                        target=jnp.asarray(target[:, time_step : time_step + 1, :]),
+                        cached_states=cached_states,
+                        **extend_kwargs,
+                    ),
+                    state=layer_params,
+                    is_training=True,
+                    prng_key=jax.random.PRNGKey(0),
+                    method="extend_step",
+                )
+                data = data.at[time_step].set(jnp.squeeze(extend_outputs.data, axis=1))
+
+            data = jnp.einsum("tbd->btd", data)
+
+            # Prefill + extend_step == forward.
+            assert_allclose(forward_outputs.data, data)
+
+
 class TransformerFeedForwardLayerTest(TestCase):
     @parameterized.parameters(
         dict(rms_norm_summary=[]),
@@ -3392,20 +3492,21 @@ def _test_forward_vs_extend_step(
         for start_time_step in (-1, 0, 2, tgt_len):
             if start_time_step > tgt_len:
                 continue
-            print(f"start_time_step={start_time_step}")
+            print(f"start_time_step={start_time_step} layer={type(layer)}")
             if start_time_step < 0:
-                cached_states, _ = F(
+                (cached_states, init_outputs), _ = F(
                     layer,
                     inputs=dict(
-                        target_batch_size=batch_size,
-                        target_max_len=tgt_len,
+                        time_step=None,
+                        data=TensorSpec([batch_size, tgt_len]),
                         **input_kwargs,
                     ),
                     state=layer_params,
                     is_training=True,
                     prng_key=jax.random.PRNGKey(0),
                     method="init_states",
                 )
+                self.assertIsNone(init_outputs)
                 decoder_output = jnp.zeros_like(target)
                 start_time_step = 0
             else:
@@ -3419,7 +3520,7 @@ def _test_forward_vs_extend_step(
                     state=layer_params,
                     is_training=True,
                     prng_key=jax.random.PRNGKey(0),
-                    method="prefill_states",
+                    method="init_states",
                 )
                 decoder_output = prefill_outputs.data
             # Transpose to [tgt_len, batch_size, model_dim].
@@ -3850,7 +3951,7 @@ def test_transformer_extend_step(self, transformer_type, layer_type):
         batch_size, src_len, tgt_len = 10, 4, 6
         num_dec_layers, model_dim, num_heads = 3, 16, 4

-        cfg = transformer_type.default_config().set(
+        cfg: BaseStackedTransformerLayer.Config = transformer_type.default_config().set(
             name="test",
             input_dim=model_dim,
             num_layers=num_dec_layers,
@@ -3872,7 +3973,7 @@ def test_transformer_extend_step(self, transformer_type, layer_type):
         layer_cfg.feed_forward.hidden_dim = model_dim * 4

         # Instantiate transformer stack.
-        layer = cfg.instantiate(parent=None)
+        layer: BaseStackedTransformerLayer = cfg.instantiate(parent=None)
         layer_params = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(123))

         target = jax.random.normal(jax.random.PRNGKey(123), [batch_size, tgt_len, model_dim])
@@ -3897,7 +3998,11 @@ def test_transformer_extend_step(self, transformer_type, layer_type):
             is_training=False,
             prng_key=jax.random.PRNGKey(0),
         )
-        initial_state = layer.init_states(target_batch_size=batch_size, target_max_len=tgt_len)
+        initial_state, initial_output = layer.init_states(
+            time_step=None,
+            data=TensorSpec([batch_size, tgt_len]),
+        )
+        self.assertIsNone(initial_output)
         inputs = dict(
             cached_states=initial_state, cross_attention_data=source, return_aux=return_aux
         )
@@ -4036,7 +4141,7 @@ def test_transformer_prefill_states(self, transformer_type, layer_type):
                 cross_attention_logit_biases=cross_attention_logit_biases,
                 return_aux=return_aux,
             ),
-            method="prefill_states",
+            method="init_states",
         )

         # Zero-out outputs starting from initial time_step, and test that we can recover the full