122
122
TensorSpec ,
123
123
VDict ,
124
124
as_tensor ,
125
+ cast_floats ,
125
126
flatten_items ,
126
127
save_and_offload_only_these_names_regex ,
127
128
shapes ,
@@ -1562,19 +1563,27 @@ def test_qlinear(self, base_cfg, test_cfg):
1562
1563
# Check that the outputs are close for all pairs.
1563
1564
self .assertNestedAllClose (outputs [layer_a ], outputs [layer_b ])
1564
1565
1565
- @parameterized .parameters (
1566
- (attention .QKVLinear , 1 ),
1567
- (attention .FusedQKVLinear , 1 ),
1568
- (attention .GroupedQKVLinear , 1 ),
1569
- (attention .FusedGroupedQKVLinear , 1 ),
1570
- (attention .RoFormerQKVLinear , 1 ),
1571
- (attention .QKVLinear , 2 ),
1572
- (attention .FusedQKVLinear , 3 ),
1573
- (attention .GroupedQKVLinear , 4 ),
1574
- (attention .FusedGroupedQKVLinear , 3 ),
1575
- (attention .RoFormerQKVLinear , 2 ),
1566
+ @parameterized .product (
1567
+ [
1568
+ dict (layer_cls = attention .QKVLinear , extend_step_len = 1 ),
1569
+ dict (layer_cls = attention .FusedQKVLinear , extend_step_len = 1 ),
1570
+ dict (layer_cls = attention .GroupedQKVLinear , extend_step_len = 1 ),
1571
+ dict (layer_cls = attention .FusedGroupedQKVLinear , extend_step_len = 1 ),
1572
+ dict (layer_cls = attention .RoFormerQKVLinear , extend_step_len = 1 ),
1573
+ dict (layer_cls = attention .QKVLinear , extend_step_len = 2 ),
1574
+ dict (layer_cls = attention .FusedQKVLinear , extend_step_len = 3 ),
1575
+ dict (layer_cls = attention .GroupedQKVLinear , extend_step_len = 4 ),
1576
+ dict (layer_cls = attention .FusedGroupedQKVLinear , extend_step_len = 3 ),
1577
+ dict (layer_cls = attention .RoFormerQKVLinear , extend_step_len = 2 ),
1578
+ ],
1579
+ cache_dtype = [None , jnp .bfloat16 ],
1576
1580
)
1577
- def test_repeated_extend_step (self , layer_cls : type [attention .BaseQKVLinear ], extend_step_len ):
1581
+ def test_repeated_extend_step (
1582
+ self ,
1583
+ layer_cls : type [attention .BaseQKVLinear ],
1584
+ extend_step_len : int ,
1585
+ cache_dtype : Optional [jnp .dtype ],
1586
+ ):
1578
1587
"""Tests that calling QKVLinear.extend_step() multiple times with the
1579
1588
same time_step results in the same output."""
1580
1589
model_dim = 8
@@ -1586,10 +1595,12 @@ def test_repeated_extend_step(self, layer_cls: type[attention.BaseQKVLinear], ex
1586
1595
value_dim = model_dim ,
1587
1596
num_heads = num_heads ,
1588
1597
per_head_dim = per_head_dim ,
1598
+ cache_dtype = cache_dtype ,
1589
1599
)
1590
1600
cfg = layer_cls .default_config ().set (** layer_kwargs )
1591
1601
maybe_set_config (cfg , num_kv_heads = num_heads , rotary_value = False )
1592
1602
layer = cfg .set (name = "test" ).instantiate (parent = None )
1603
+ expect_dtype = cache_dtype or layer .dtype ()
1593
1604
1594
1605
# Construct base layer state.
1595
1606
layer_state = layer .initialize_parameters_recursively (jax .random .PRNGKey (0 ))
@@ -1609,6 +1620,8 @@ def test_repeated_extend_step(self, layer_cls: type[attention.BaseQKVLinear], ex
1609
1620
cache_state , init_output = layer .init_states (
1610
1621
time_step = None , query = TensorSpec ([batch_size , tgt_len ])
1611
1622
)
1623
+ self .assertEqual (cache_state ["key" ].dtype , expect_dtype )
1624
+ self .assertEqual (cache_state ["value" ].dtype , expect_dtype )
1612
1625
self .assertIsNone (init_output )
1613
1626
step_querys = []
1614
1627
step_keys = step_values = None
@@ -1624,10 +1637,12 @@ def test_repeated_extend_step(self, layer_cls: type[attention.BaseQKVLinear], ex
1624
1637
step_querys .append (step_output .query )
1625
1638
step_keys = step_output .key
1626
1639
step_values = step_output .value
1640
+ self .assertEqual (cache_state ["key" ].dtype , expect_dtype )
1641
+ self .assertEqual (cache_state ["value" ].dtype , expect_dtype )
1627
1642
1628
1643
self .assertNestedAllClose (fwd_output .query , jnp .concat (step_querys , axis = 1 ))
1629
- self .assertNestedAllClose (fwd_output .key , step_keys )
1630
- self .assertNestedAllClose (fwd_output .value , step_values )
1644
+ self .assertNestedAllClose (cast_floats ( fwd_output .key , cache_dtype ) , step_keys )
1645
+ self .assertNestedAllClose (cast_floats ( fwd_output .value , cache_dtype ) , step_values )
1631
1646
1632
1647
@parameterized .parameters (jnp .float32 , jnp .float16 , jnp .bfloat16 )
1633
1648
def test_dtypes_inherited_from_parent (self , dtype : jnp .dtype ):
0 commit comments