@@ -57,7 +57,7 @@ def generate_attention_metadata(num_tokens, mesh) -> AttentionMetadata:
 
 def generate_kv_caches(num_kv_heads, head_size, mesh, dtype):
     cache_shape = get_kv_cache_shape_with_mesh(mesh, 1024, 16, num_kv_heads,
-                                               head_size, dtype)
+                                               head_size, t2j_dtype(dtype))
     sharding = NamedSharding(mesh, PartitionSpec())
 
     def _allocate():
@@ -138,15 +138,16 @@ def test_jax_attention(mesh, num_heads, head_size, num_kv_heads, num_tokens):
     vllm_model_wrapper_context = get_vllm_model_wrapper_context()
     kv_cache = vllm_model_wrapper_context.kv_caches[0]
 
-    ref_output = ref_ragged_paged_attention(q,
-                                            k,
-                                            v,
-                                            kv_cache,
-                                            md.seq_lens,
-                                            md.block_tables,
-                                            md.query_start_loc,
-                                            md.request_distribution,
-                                            sm_scale=scale)
+    ref_output, _ = ref_ragged_paged_attention(
+        q,
+        jax.device_put(t2j(k), NamedSharding(mesh, P())),
+        jax.device_put(t2j(v), NamedSharding(mesh, P())),
+        kv_cache,
+        md.seq_lens,
+        md.block_tables,
+        md.query_start_loc,
+        md.request_distribution,
+        sm_scale=scale)
     ref_output = j2t(ref_output.astype(jnp.float32)).to(dtype)
 
     torch.testing.assert_close(ref_output, jax_output, atol=1e-2, rtol=1e-5)
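For context on the new conversion calls (`t2j_dtype(dtype)` in the cache-shape helper and `jax.device_put(t2j(k), NamedSharding(mesh, P()))` in the reference call), here is a minimal, self-contained sketch of that torch-to-JAX handoff. It assumes `t2j` and `t2j_dtype` are plain tensor/dtype converters; the stand-in implementations and the mesh setup below are illustrative, not the library's actual code.

```python
# Hedged sketch of the torch -> JAX handoff used in the updated test.
# The converters and the single-axis mesh below are assumptions.
import jax
import jax.numpy as jnp
import numpy as np
import torch
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumed stand-in for t2j_dtype: map a torch dtype to its JAX counterpart.
_TORCH_TO_JAX_DTYPE = {
    torch.float32: jnp.float32,
    torch.bfloat16: jnp.bfloat16,
    torch.float16: jnp.float16,
}

def t2j_dtype(torch_dtype):
    return _TORCH_TO_JAX_DTYPE[torch_dtype]

# Assumed stand-in for t2j: detach to CPU, round-trip through numpy
# (via float32, since numpy has no native bfloat16), then restore the dtype.
def t2j(tensor: torch.Tensor) -> jax.Array:
    arr = jnp.asarray(tensor.detach().to(torch.float32).cpu().numpy())
    return arr.astype(t2j_dtype(tensor.dtype))

mesh = Mesh(np.array(jax.devices()), axis_names=("x",))
k = torch.randn(8, 4, 128, dtype=torch.bfloat16)

# Replicate the converted array across the mesh, mirroring
# jax.device_put(t2j(k), NamedSharding(mesh, P())) from the diff above.
k_jax = jax.device_put(t2j(k), NamedSharding(mesh, P()))
print(k_jax.dtype, k_jax.sharding)  # bfloat16, fully replicated NamedSharding
```

An empty `PartitionSpec` (`P()`) means no dimension is partitioned, so the reference inputs end up replicated on every device in the mesh, matching the replicated `NamedSharding(mesh, PartitionSpec())` used for the KV cache allocation.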