
The compute_attn_1rowblock_splitkv function does not work properly #47

@LoserCheems

Description


There are issues with the data loading and offset calculation in the compute_attn_1rowblock_splitkv function, which cause the split-KV branch to produce results that are not equivalent to the Python reference implementation.
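For reference, the block partitioning that the `SplitKV:` lines in the log below reflect can be sketched in Python. This is a minimal sketch: the function name `split_kv_block_range`, the tile size of 64, and `num_splits = 4` are assumptions inferred from the logged values (key_len=512, n_block_max=8, four splits of two blocks each), not the kernel's actual identifiers.

```python
def split_kv_block_range(n_block_max: int, num_splits: int, split_idx: int):
    """Return the [n_block_min, n_block_max) half-open range handled by one split.

    Blocks are divided as evenly as possible across splits (ceiling division),
    mirroring the ranges printed in the kernel's "SplitKV:" debug lines.
    """
    blocks_per_split = (n_block_max + num_splits - 1) // num_splits
    n_block_min = split_idx * blocks_per_split
    n_block_hi = min(n_block_max, n_block_min + blocks_per_split)
    return n_block_min, n_block_hi

# With key_len=512 and an assumed kBlockN of 64, n_block_max = 8 and the four
# splits cover blocks [0,2), [2,4), [4,6), [6,8) — matching the log below.
for split_idx in range(4):
    print(split_idx, split_kv_block_range(8, 4, split_idx))
```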

✅ Successfully imported flash_dma_cuda
🧬==============================================================================🧬
🔬 Dynamic Mask Attention Forward Pass Equivalence Test Suite 🔬
🧬==============================================================================🧬
🐍 PyTorch version: 2.8.0a0+5228986c39.nv25.05
🔥 Device: cuda
🎮 CUDA device: NVIDIA GeForce RTX 4090
🎲 Random seed: 42
📊 Test type: all
🎯 Accuracy threshold: 95.0%

📍 Starting Standard Forward Pass Tests 📍

🚀============================================================================🚀
🔬 Testing Forward Pass Equivalence: Python Prototype vs CUDA Implementation 🔬
🚀============================================================================🚀
🔥 Using device: cuda

🧪 Test configuration 1/1 [█]
  📊 batch_size=1, num_heads=2, num_kv_heads=1
  📏 query_len=512, key_len=512, head_dim=128
  🔒 is_causal=False
  🎯 Accuracy threshold: 95.0%
Split = 1, Append_KV = 0, Is_causal = 0, IsEvenMNConst = 1, IsEvenKConst = 1, Is_softcap = 0
smem_size = 65536, CTAs per SM = 1
Before copy_ZOH: n_block=3, seqlen_q_offset=192, seqlen_k_offset=320
Before copy_ZOH: n_block=1, seqlen_q_offset=384, seqlen_k_offset=448
Before copy_ZOH: n_block=3, seqlen_q_offset=128, seqlen_k_offset=320
Before copy_ZOH: n_block=7, seqlen_q_offset=128, seqlen_k_offset=64
Before copy_ZOH: n_block=1, seqlen_q_offset=320, seqlen_k_offset=448
Before copy_ZOH: n_block=5, seqlen_q_offset=320, seqlen_k_offset=192
Before copy_ZOH: n_block=7, seqlen_q_offset=448, seqlen_k_offset=64
Before copy_ZOH: n_block=5, seqlen_q_offset=256, seqlen_k_offset=192
Before copy_ZOH: n_block=3, seqlen_q_offset=448, seqlen_k_offset=320
Before copy_ZOH: n_block=7, seqlen_q_offset=64, seqlen_k_offset=64
Before copy_ZOH: n_block=5, seqlen_q_offset=64, seqlen_k_offset=192
Before copy_ZOH: n_block=3, seqlen_q_offset=256, seqlen_k_offset=320
Before copy_ZOH: n_block=1, seqlen_q_offset=448, seqlen_k_offset=448
Before copy_ZOH: n_block=7, seqlen_q_offset=384, seqlen_k_offset=64
Before copy_ZOH: n_block=3, seqlen_q_offset=64, seqlen_k_offset=320
Before copy_ZOH: n_block=1, seqlen_q_offset=256, seqlen_k_offset=448
Before copy_ZOH: n_block=7, seqlen_q_offset=192, seqlen_k_offset=64
Before copy_ZOH: n_block=5, seqlen_q_offset=384, seqlen_k_offset=192
Before copy_ZOH: n_block=1, seqlen_q_offset=64, seqlen_k_offset=448
SplitKV: m_block=0, n_block_max=8, n_block_min=6, n_split_idx=3
Before copy_ZOH: n_block=7, seqlen_q_offset=320, seqlen_k_offset=64
Before copy_ZOH: n_block=1, seqlen_q_offset=192, seqlen_k_offset=448
Before copy_ZOH: n_block=5, seqlen_q_offset=128, seqlen_k_offset=192
Before copy_ZOH: n_block=3, seqlen_q_offset=320, seqlen_k_offset=320
Before copy_ZOH: n_block=5, seqlen_q_offset=192, seqlen_k_offset=192
Before copy_ZOH: n_block=3, seqlen_q_offset=384, seqlen_k_offset=320
Before copy_ZOH: n_block=7, seqlen_q_offset=256, seqlen_k_offset=64
Before copy_ZOH: n_block=5, seqlen_q_offset=448, seqlen_k_offset=192
Before copy_ZOH: n_block=1, seqlen_q_offset=128, seqlen_k_offset=448
SplitKV: m_block=0, n_block_max=6, n_block_min=4, n_split_idx=2
SplitKV: m_block=0, n_block_max=4, n_block_min=2, n_split_idx=1
SplitKV: m_block=0, n_block_max=2, n_block_min=0, n_split_idx=0
col_offset_zoh=64, row_offset=0, zoh_ptr=0xb07000000, final_ptr=0xb07000080
col_offset_zoh=448, row_offset=0, zoh_ptr=0xb07000000, final_ptr=0xb07000380
col_offset_zoh=320, row_offset=0, zoh_ptr=0xb07000000, final_ptr=0xb07000280
col_offset_zoh=192, row_offset=0, zoh_ptr=0xb07000000, final_ptr=0xb07000180
Before copy_ZOH: n_block=1, seqlen_q_offset=512, seqlen_k_offset=448
Before copy_ZOH: n_block=7, seqlen_q_offset=512, seqlen_k_offset=64
Before copy_ZOH: n_block=3, seqlen_q_offset=512, seqlen_k_offset=320
Before copy_ZOH: n_block=5, seqlen_q_offset=512, seqlen_k_offset=192
📋 Original result: torch.Size([1, 512, 2, 128]), torch.bfloat16
⚡ CUDA result: torch.Size([1, 512, 2, 128]), torch.bfloat16

🔍 Debugging info:
  📈 Original result range: [-1.015625, 1.101562]
  ⚡ CUDA result range: [-0.832031, 0.980469]
  ✅ Original result contains NaN: False, Inf: False
  ✅ CUDA result contains NaN: False, Inf: False

📊 Result analysis:
  📌 Maximum absolute difference: 0.64843750
  📍 Mean absolute difference: 0.03759766
  📍 Position of maximum difference: (tensor(0, device='cuda:0'), tensor(507, device='cuda:0'), tensor(0, device='cuda:0'), tensor(71, device='cuda:0'))
  📋 Original value at position: -0.18750000
  ⚡ CUDA value at position: 0.46093750
  📏 Maximum relative difference: 25344.00000000
  📏 Mean relative difference: 1.67968750
  ⚠️ Elements within tolerance ratio: 0.3482 (45641/131072)
  ❌ Accuracy threshold (95.0%): Fail
  ❌ Strict allclose (bfloat16 tolerance: rtol=0.01, atol=0.01): No

⚡ Performance comparison:
    🐍 Python implementation: 223.84 ms
    🚀 CUDA implementation:   1.97 ms
    📈 Speedup:               113.81x

❌ Test result: Failed
  ⚠️ Difference too large, stopping subsequent tests.

🏁============================================================================🏁
😞 Forward Equivalence Test Summary: Some Tests Failed
🏁============================================================================🏁

🏆==============================================================================🏆
🔬 FINAL TEST SUMMARY 🔬
🏆==============================================================================🏆
  ❌ FWD          : FAILED

😞 OVERALL RESULT: SOME TESTS FAILED
🏆==============================================================================🏆
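As a side observation on the log above: the `seqlen_k_offset` values printed by the `Before copy_ZOH:` lines track `key_len - n_block * 64` rather than `n_block * 64`, i.e. they appear to be anchored at the end of the key sequence instead of the start. This is a quick arithmetic check on the logged values only (the tile size of 64 is an assumption; this is not a confirmed diagnosis of the kernel source):

```python
KEY_LEN = 512     # key_len from the test configuration above
K_BLOCK_N = 64    # assumed key-dimension tile size, inferred from the log

# n_block -> seqlen_k_offset, copied from the "Before copy_ZOH:" lines
logged = {7: 64, 5: 192, 3: 320, 1: 448}

for n_block, offset in logged.items():
    # Every logged offset equals key_len - n_block * kBlockN ...
    assert offset == KEY_LEN - n_block * K_BLOCK_N
    # ... and none equals the start-anchored n_block * kBlockN.
    assert offset != n_block * K_BLOCK_N
```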

Metadata

Labels: bug (Something isn't working)
