
Qwen2 KV-cache updates do not cast projections to cache dtype before dynamic_update_slice #1321

@skwh54

Description


Expected Behavior

Qwen2 cached decoding should update KV cache tensors using values whose dtype matches the cache dtype.

Actual Behavior

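In tunix/models/qwen2/model.py, the Qwen2 cached-decoding path writes value_proj and key_proj into cache['v'] / cache['k'] via jax.lax.dynamic_update_slice(...) without first casting them to the cache dtype. Because lax.dynamic_update_slice requires the operand and the update to have the same dtype, the update fails whenever the cache dtype differs from the projection dtype (for example, a bfloat16 cache with float32 projections):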

value_proj = jax.lax.dynamic_update_slice(
    cache['v'],
    value_proj,
    slice_indices,
)
key_proj = jax.lax.dynamic_update_slice(
    cache['k'], key_proj, slice_indices
)
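One possible fix is to cast each projection to the dtype of the cache buffer it is written into before calling dynamic_update_slice. The sketch below is a standalone illustration under that assumption, not the tunix code: update_kv_cache is a hypothetical helper, and the shapes simply match the cache used in the reproduction script.

```python
import jax
import jax.numpy as jnp


def update_kv_cache(cache, key_proj, value_proj, slice_indices):
    """Hypothetical helper: writes new K/V projections into the cache.

    Each projection is cast to the dtype of the cache buffer it is written
    into, because lax.dynamic_update_slice requires the operand and the
    update to share a dtype.
    """
    value_proj = jax.lax.dynamic_update_slice(
        cache['v'],
        value_proj.astype(cache['v'].dtype),  # e.g. float32 -> bfloat16
        slice_indices,
    )
    key_proj = jax.lax.dynamic_update_slice(
        cache['k'],
        key_proj.astype(cache['k'].dtype),
        slice_indices,
    )
    return key_proj, value_proj


# bfloat16 cache with float32 projections, shaped like the cache in the
# reproduction script.
cache = {
    'v': jnp.zeros((1, 4, 1, 4), dtype=jnp.bfloat16),
    'k': jnp.zeros((1, 4, 1, 4), dtype=jnp.bfloat16),
}
key_proj = jnp.ones((1, 1, 1, 4), dtype=jnp.float32)
value_proj = jnp.ones((1, 1, 1, 4), dtype=jnp.float32)
slice_indices = (0, 0, 0, 0)

new_k, new_v = update_kv_cache(cache, key_proj, value_proj, slice_indices)
print(new_k.dtype, new_v.dtype)
```

Casting the projections down to the cache dtype, rather than promoting the cache, keeps the cache at its configured precision and memory footprint; the trade-off is that newly written K/V entries are rounded to that precision.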

I reproduced this cache-update failure in the Qwen2 attention cached-decoding path with jax==0.8.1 and flax==0.12.1, after locally applying only the RNG compatibility change, which is needed for execution to reach the cache-update line:

PYTHONPATH=$PWD python - <<'PY'
from flax import nnx
import jax.numpy as jnp
from tunix.models.qwen2.model import Attention, ModelConfig, ShardingConfig

cfg = ModelConfig(
    num_layers=1,
    vocab_size=32,
    embed_dim=8,
    hidden_dim=16,
    num_heads=2,
    head_dim=4,
    num_kv_heads=1,
    rope_theta=10000,
    norm_eps=1e-6,
    shd_config=ShardingConfig.get_default_sharding(),
    dtype=jnp.float32,
    param_dtype=jnp.float32,
    use_flash_attention=False,
)

attn = Attention(cfg, rngs=nnx.Rngs(0))
x = jnp.zeros((1, 1, 8), dtype=jnp.float32)
sin = jnp.zeros((1, 1, 2), dtype=jnp.float32)
cos = jnp.ones((1, 1, 2), dtype=jnp.float32)
cache = {
    'end_index': jnp.array([0], dtype=jnp.int32),
    'v': jnp.zeros((1, 4, 1, 4), dtype=jnp.bfloat16),
    'k': jnp.zeros((1, 4, 1, 4), dtype=jnp.bfloat16),
}

attn.block(x, cache, None, sin, cos)
PY

Observed output:

Traceback (most recent call last):
  File "<stdin>", line 31, in <module>
  File ".../tunix/models/qwen2/model.py", line 476, in block
    value_proj = jax.lax.dynamic_update_slice(
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: lax.dynamic_update_slice requires arguments to have the same dtypes, got bfloat16, float32.
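The underlying constraint can also be reproduced with jax.lax alone, independent of tunix. In this sketch, try_update is a hypothetical helper; the bfloat16 buffer and float32 update stand in for cache['v'] and value_proj from the script above, and the second call shows that an explicit astype to the cache dtype avoids the error.

```python
import jax
import jax.numpy as jnp


def try_update(cache_v, update):
    """Hypothetical helper: returns the TypeError message, or None on success."""
    try:
        jax.lax.dynamic_update_slice(cache_v, update, (0, 0, 0, 0))
    except TypeError as e:
        return str(e)
    return None


cache_v = jnp.zeros((1, 4, 1, 4), dtype=jnp.bfloat16)     # stands in for cache['v']
value_proj = jnp.zeros((1, 1, 1, 4), dtype=jnp.float32)  # stands in for value_proj

mismatched = try_update(cache_v, value_proj)                      # dtypes differ
matched = try_update(cache_v, value_proj.astype(cache_v.dtype))  # dtypes agree
print(mismatched)
print(matched)
```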

Steps to Reproduce the Problem

  1. Create an environment with jax==0.8.1 and flax==0.12.1.
  2. Run the Qwen2 Attention.block(...) cached-decoding path with a cache whose dtype differs from the projection dtype.
  3. Observe the TypeError "lax.dynamic_update_slice requires arguments to have the same dtypes, got bfloat16, float32." raised from the cache-update line in tunix/models/qwen2/model.py.

Environment

  • OS: Ubuntu 22.04.5 LTS
  • Package version: 0.1.7
  • Python: 3.12.13
  • JAX: 0.8.1
  • jaxlib: 0.8.1
  • Flax: 0.12.1

Checklist

  • I have searched the existing issues for a similar bug report.
  • I have provided all the required information in the "Environment" section.
  • I have provided a minimal, reproducible example.

Would you like to help us fix it?

Yes.

Metadata

Assignees

Labels

type:bug (Something isn't working)

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests
