Commit ae855ed

Author: Mark Lee
Ensures that cache_dtype is respected. (#977)
1 parent cfef38b commit ae855ed
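Why the cast matters (context added here, not part of the commit message): under JAX's default type promotion, adding a float32 projection into a bfloat16 cache silently promotes the result to float32, so a configured cache_dtype would be lost on the first extend_step. A minimal sketch of that behavior, with made-up shapes:

import jax.numpy as jnp

# Hypothetical tensors, chosen only to show the dtype promotion the commit guards against.
cached_key = jnp.zeros((2, 4, 1, 8), dtype=jnp.bfloat16)  # KV cache stored in cache_dtype
k_proj = jnp.ones((2, 4, 1, 8), dtype=jnp.float32)        # incoming projection in compute dtype

promoted = cached_key + k_proj                        # bfloat16 + float32 -> float32
assert promoted.dtype == jnp.float32

kept = cached_key + k_proj.astype(cached_key.dtype)   # cast first, accumulate in bfloat16
assert kept.dtype == jnp.bfloat16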

File tree: 2 files changed, +34 -19 lines


axlearn/common/attention.py

Lines changed: 5 additions & 5 deletions
@@ -853,14 +853,15 @@ def extend_step(

         # Create a dispatch matrix of shape [B, T=step, S].
         oh_indices = jax.nn.one_hot(
-            time_step[:, None] + jnp.arange(num_query_steps), source_len, dtype=k_proj.dtype
+            time_step[:, None] + jnp.arange(num_query_steps), source_len, dtype=cached_key.dtype
         )
         # Create a mask of shape [B, S, 1, 1].
         negated_oh_indices = (1 - oh_indices.sum(axis=1))[..., None, None]
         k_proj = jnp.einsum("bt...,bts->bs...", k_proj, oh_indices)
         v_proj = jnp.einsum("bt...,bts->bs...", v_proj, oh_indices)
-        k_proj = cached_key * negated_oh_indices + k_proj
-        v_proj = cached_value * negated_oh_indices + v_proj
+        # Ensure that we accumulate using the original dtype.
+        k_proj = cached_key * negated_oh_indices + k_proj.astype(cached_key.dtype)
+        v_proj = cached_value * negated_oh_indices + v_proj.astype(cached_value.dtype)

         updated_state.update(key=k_proj, value=v_proj)
         return updated_state, self.Output(query=q_proj, key=k_proj, value=v_proj)
@@ -1750,8 +1751,7 @@ def _forward_for_mode(
         # Validate key & value combination.
         if (key is None) != (value is None):
             raise ValueError(
-                "key and value must be both None or both set, "
-                f"key:{type(key)}, value:{type(value)}"
+                f"key and value must be both None or both set, key:{type(key)}, value:{type(value)}"
             )
         if kv_state is not None:
             if key is not None or value is not None:
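For readers outside the codebase, the dispatch pattern touched in the first hunk can be sketched on its own (hypothetical shapes and values, not the axlearn API): a one-hot matrix scatters the current step's projection into the source axis of the cache, and building that matrix in the cache dtype plus casting the scattered update keeps the whole accumulation in cache_dtype.

import jax
import jax.numpy as jnp

# Hypothetical sizes for illustration only.
batch, step, source_len, num_heads, head_dim = 2, 1, 6, 1, 4
cache_dtype = jnp.bfloat16

cached_key = jnp.zeros((batch, source_len, num_heads, head_dim), dtype=cache_dtype)
k_proj = jnp.ones((batch, step, num_heads, head_dim), dtype=jnp.float32)
time_step = jnp.array([3, 0])

# Dispatch matrix of shape [B, T=step, S], built in the cache dtype.
oh_indices = jax.nn.one_hot(
    time_step[:, None] + jnp.arange(step), source_len, dtype=cached_key.dtype
)
# Mask of shape [B, S, 1, 1] that zeroes out the positions being overwritten.
negated_oh_indices = (1 - oh_indices.sum(axis=1))[..., None, None]
# Scatter the new step into the source axis, then accumulate in the cache dtype.
scattered = jnp.einsum("bt...,bts->bs...", k_proj, oh_indices)
updated_key = cached_key * negated_oh_indices + scattered.astype(cached_key.dtype)
assert updated_key.dtype == cache_dtype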

axlearn/common/attention_test.py

Lines changed: 29 additions & 14 deletions
@@ -122,6 +122,7 @@
     TensorSpec,
     VDict,
     as_tensor,
+    cast_floats,
     flatten_items,
     save_and_offload_only_these_names_regex,
     shapes,
@@ -1562,19 +1563,27 @@ def test_qlinear(self, base_cfg, test_cfg):
         # Check that the outputs are close for all pairs.
         self.assertNestedAllClose(outputs[layer_a], outputs[layer_b])

-    @parameterized.parameters(
-        (attention.QKVLinear, 1),
-        (attention.FusedQKVLinear, 1),
-        (attention.GroupedQKVLinear, 1),
-        (attention.FusedGroupedQKVLinear, 1),
-        (attention.RoFormerQKVLinear, 1),
-        (attention.QKVLinear, 2),
-        (attention.FusedQKVLinear, 3),
-        (attention.GroupedQKVLinear, 4),
-        (attention.FusedGroupedQKVLinear, 3),
-        (attention.RoFormerQKVLinear, 2),
+    @parameterized.product(
+        [
+            dict(layer_cls=attention.QKVLinear, extend_step_len=1),
+            dict(layer_cls=attention.FusedQKVLinear, extend_step_len=1),
+            dict(layer_cls=attention.GroupedQKVLinear, extend_step_len=1),
+            dict(layer_cls=attention.FusedGroupedQKVLinear, extend_step_len=1),
+            dict(layer_cls=attention.RoFormerQKVLinear, extend_step_len=1),
+            dict(layer_cls=attention.QKVLinear, extend_step_len=2),
+            dict(layer_cls=attention.FusedQKVLinear, extend_step_len=3),
+            dict(layer_cls=attention.GroupedQKVLinear, extend_step_len=4),
+            dict(layer_cls=attention.FusedGroupedQKVLinear, extend_step_len=3),
+            dict(layer_cls=attention.RoFormerQKVLinear, extend_step_len=2),
+        ],
+        cache_dtype=[None, jnp.bfloat16],
     )
-    def test_repeated_extend_step(self, layer_cls: type[attention.BaseQKVLinear], extend_step_len):
+    def test_repeated_extend_step(
+        self,
+        layer_cls: type[attention.BaseQKVLinear],
+        extend_step_len: int,
+        cache_dtype: Optional[jnp.dtype],
+    ):
         """Tests that calling QKVLinear.extend_step() multiple times with the
         same time_step results in the same output."""
         model_dim = 8
@@ -1586,10 +1595,12 @@ def test_repeated_extend_step(self, layer_cls: type[attention.BaseQKVLinear], ex
             value_dim=model_dim,
             num_heads=num_heads,
             per_head_dim=per_head_dim,
+            cache_dtype=cache_dtype,
         )
         cfg = layer_cls.default_config().set(**layer_kwargs)
         maybe_set_config(cfg, num_kv_heads=num_heads, rotary_value=False)
         layer = cfg.set(name="test").instantiate(parent=None)
+        expect_dtype = cache_dtype or layer.dtype()

         # Construct base layer state.
         layer_state = layer.initialize_parameters_recursively(jax.random.PRNGKey(0))
@@ -1609,6 +1620,8 @@ def test_repeated_extend_step(self, layer_cls: type[attention.BaseQKVLinear], ex
         cache_state, init_output = layer.init_states(
             time_step=None, query=TensorSpec([batch_size, tgt_len])
         )
+        self.assertEqual(cache_state["key"].dtype, expect_dtype)
+        self.assertEqual(cache_state["value"].dtype, expect_dtype)
         self.assertIsNone(init_output)
         step_querys = []
         step_keys = step_values = None
@@ -1624,10 +1637,12 @@ def test_repeated_extend_step(self, layer_cls: type[attention.BaseQKVLinear], ex
             step_querys.append(step_output.query)
             step_keys = step_output.key
             step_values = step_output.value
+            self.assertEqual(cache_state["key"].dtype, expect_dtype)
+            self.assertEqual(cache_state["value"].dtype, expect_dtype)

         self.assertNestedAllClose(fwd_output.query, jnp.concat(step_querys, axis=1))
-        self.assertNestedAllClose(fwd_output.key, step_keys)
-        self.assertNestedAllClose(fwd_output.value, step_values)
+        self.assertNestedAllClose(cast_floats(fwd_output.key, cache_dtype), step_keys)
+        self.assertNestedAllClose(cast_floats(fwd_output.value, cache_dtype), step_values)

     @parameterized.parameters(jnp.float32, jnp.float16, jnp.bfloat16)
     def test_dtypes_inherited_from_parent(self, dtype: jnp.dtype):
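A note on the test change, offered as a reading of the diff rather than documentation: parameterized.product takes a positional sequence of keyword dicts plus keyword lists and runs the cross product, so each of the ten (layer_cls, extend_step_len) cases above now runs with both cache_dtype=None and cache_dtype=jnp.bfloat16. The final assertions apply cast_floats to the forward outputs because extend_step now returns keys and values in the cache dtype. A minimal, self-contained sketch of the decorator pattern, with hypothetical names:

from absl.testing import absltest, parameterized


class ProductExampleTest(parameterized.TestCase):
    @parameterized.product(
        [
            dict(layer_name="qkv", extend_step_len=1),
            dict(layer_name="fused_qkv", extend_step_len=3),
        ],
        cache_dtype=[None, "bfloat16"],
    )
    def test_cross_product(self, layer_name: str, extend_step_len: int, cache_dtype):
        # 2 dict cases x 2 cache_dtype values = 4 generated test cases.
        self.assertIn(layer_name, ("qkv", "fused_qkv"))


if __name__ == "__main__":
    absltest.main()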
