Expected Behavior
Qwen2 cached decoding should update KV cache tensors using values whose dtype matches the cache dtype.
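A minimal sketch of the expected behavior, assuming the fix is simply to cast the update operand to the cache dtype before the write. The helper name update_cache and the (0, 0, 0, 0) start indices are hypothetical and used only for illustration; they are not taken from tunix.

```python
import jax
import jax.numpy as jnp


def update_cache(cache_arr, update, start_indices):
    """Hypothetical helper (not part of tunix): write `update` into `cache_arr`."""
    # Cast the projection to the cache dtype so both operands agree.
    update = update.astype(cache_arr.dtype)
    return jax.lax.dynamic_update_slice(cache_arr, update, start_indices)


# Shapes mirror the reproduction in this report: (batch, cache_len, kv_heads, head_dim).
cache_v = jnp.zeros((1, 4, 1, 4), dtype=jnp.bfloat16)
value_proj = jnp.ones((1, 1, 1, 4), dtype=jnp.float32)

# Assumed start indices for illustration; the model computes its own slice_indices.
new_v = update_cache(cache_v, value_proj, (0, 0, 0, 0))
print(new_v.dtype, new_v.shape)
```

With a cast like this, the cache keeps its own dtype (bfloat16 here) regardless of the dtype the projections are computed in.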
Actual Behavior
In tunix/models/qwen2/model.py, the Qwen2 cached decoding path writes value_proj and key_proj into cache['v'] / cache['k'] via jax.lax.dynamic_update_slice(...) without first casting them to the cache dtype:
value_proj = jax.lax.dynamic_update_slice(
    cache['v'],
    value_proj,
    slice_indices,
)
key_proj = jax.lax.dynamic_update_slice(
    cache['k'], key_proj, slice_indices
)
I reproduced this cache-update failure in the Qwen2 attention cached-decoding path with jax==0.8.1 and flax==0.12.1 after locally applying only the RNG compatibility change so execution can reach the cache update line:
PYTHONPATH=$PWD python - <<'PY'
from flax import nnx
import jax.numpy as jnp
from tunix.models.qwen2.model import Attention, ModelConfig, ShardingConfig
cfg = ModelConfig(
    num_layers=1,
    vocab_size=32,
    embed_dim=8,
    hidden_dim=16,
    num_heads=2,
    head_dim=4,
    num_kv_heads=1,
    rope_theta=10000,
    norm_eps=1e-6,
    shd_config=ShardingConfig.get_default_sharding(),
    dtype=jnp.float32,
    param_dtype=jnp.float32,
    use_flash_attention=False,
)
attn = Attention(cfg, rngs=nnx.Rngs(0))
x = jnp.zeros((1, 1, 8), dtype=jnp.float32)
sin = jnp.zeros((1, 1, 2), dtype=jnp.float32)
cos = jnp.ones((1, 1, 2), dtype=jnp.float32)
cache = {
    'end_index': jnp.array([0], dtype=jnp.int32),
    'v': jnp.zeros((1, 4, 1, 4), dtype=jnp.bfloat16),
    'k': jnp.zeros((1, 4, 1, 4), dtype=jnp.bfloat16),
}
attn.block(x, cache, None, sin, cos)
PY
Observed output:
Traceback (most recent call last):
File "<stdin>", line 31, in <module>
File ".../tunix/models/qwen2/model.py", line 476, in block
value_proj = jax.lax.dynamic_update_slice(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: lax.dynamic_update_slice requires arguments to have the same dtypes, got bfloat16, float32.
Steps to Reproduce the Problem
- Create an environment with jax==0.8.1 and flax==0.12.1.
- Run the Qwen2 Attention.block(...) cached-decoding path with a cache whose dtype differs from the projection dtype.
- Observe TypeError: lax.dynamic_update_slice requires arguments to have the same dtypes, got bfloat16, float32. from the cache update line in tunix/models/qwen2/model.py.
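The dtype mismatch can also be isolated from tunix entirely. The sketch below uses only JAX, with array shapes borrowed from the reproduction above and assumed start indices of (0, 0, 0, 0). On the versions listed in this report it should raise the same TypeError.

```python
import jax
import jax.numpy as jnp

# bfloat16 cache and float32 projection, as in the reproduction above.
cache_k = jnp.zeros((1, 4, 1, 4), dtype=jnp.bfloat16)
key_proj = jnp.zeros((1, 1, 1, 4), dtype=jnp.float32)

try:
    # No cast is applied, so the operand dtypes disagree.
    jax.lax.dynamic_update_slice(cache_k, key_proj, (0, 0, 0, 0))
    raised = False
except TypeError as err:
    raised = True
    print(err)
```

This suggests the failure comes from the mismatched operands themselves rather than from anything specific to the Attention module.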
Environment
- OS: Ubuntu 22.04.5 LTS
- Package version: 0.1.7
- Python: 3.12.13
- JAX: 0.8.1
- jaxlib: 0.8.1
- Flax: 0.12.1
Checklist
Would you like to help us fix it?
Yes.