
Commit 1b9dace

Merge pull request #200 from SmallDoges/optim-triton-version
Optimize Triton version: GQA, mask/bias broadcasting, inactive-tile skipping, and stability fixes
2 parents e3bcf48 + b7deeba commit 1b9dace
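
Of the optimizations named in the title, only the benchmark-side changes appear in the file below; the kernel work lives in the second changed file. For intuition, here is a conceptual sketch of "skip inactive tiles" in plain PyTorch, not the commit's actual Triton kernel (the function name and tiling are illustrative): score tiles whose mask block contains no active position are skipped outright.

import torch

def blockwise_masked_scores(q, k, mask, tile=128):
    """Illustrative tile loop: compute scores block by block and skip
    (query-tile, key-tile) pairs whose mask block is entirely inactive."""
    Lq, Lk = mask.shape
    scores = torch.full((Lq, Lk), float("-inf"))
    for i in range(0, Lq, tile):
        for j in range(0, Lk, tile):
            m = mask[i:i + tile, j:j + tile]
            if not m.any():          # fully masked tile: no matmul, no write
                continue
            s = q[i:i + tile] @ k[j:j + tile].T
            scores[i:i + tile, j:j + tile] = s.masked_fill(~m, float("-inf"))
    return scores

q, k = torch.randn(256, 64), torch.randn(512, 64)
mask = torch.rand(256, 512) > 0.5
scores = blockwise_masked_scores(q, k, mask)

A real kernel makes the same decision per thread block, so a skipped tile costs neither the QK^T matmul nor the softmax update.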

File tree

2 files changed: +871 −506 lines


benchmarks/backward_equivalence.py

Lines changed: 239 additions & 8 deletions
@@ -283,11 +283,9 @@ def dynamic_mask_attention_triton(
     attn_bias = repeat_kv(attn_bias_leaf, num_queries_per_kv)
 
     # Ensure correct data types and memory layout for Triton function
-    query_states = query_states.transpose(1, 2).contiguous()  # [batch, query_len, num_heads, head_dim]
-    key_states = key_states.transpose(1, 2).contiguous()  # [batch, key_len, num_heads, head_dim]
-    value_states = value_states.transpose(1, 2).contiguous()  # [batch, key_len, num_heads, head_dim]
-    attn_mask = attn_mask.contiguous()  # [batch, num_heads, seqlen_q, seqlen_k]
-    attn_bias = attn_bias.contiguous()  # [batch, num_heads, seqlen_q, seqlen_k]
+    query_states = query_states.transpose(1, 2)  # [batch, query_len, num_heads, head_dim]
+    key_states = key_states.transpose(1, 2)  # [batch, key_len, num_heads, head_dim]
+    value_states = value_states.transpose(1, 2)  # [batch, key_len, num_heads, head_dim]
 
     # Call the Triton implementation
     attn_outputs = triton_dmattn_func(
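
The change above removes the .contiguous() copies on the transposed q/k/v views and stops pre-materializing attn_mask/attn_bias, in line with the commit's mask/bias broadcasting: transpose(1, 2) returns a strided view, and a kernel that indexes through explicit strides can read it directly. A stand-alone illustration of that behavior (plain PyTorch, not repository code):

import torch

x = torch.randn(2, 8, 512, 64)   # [batch, num_heads, seq_len, head_dim]
v = x.transpose(1, 2)            # [batch, seq_len, num_heads, head_dim], a view
print(v.is_contiguous())         # False: no data was moved
print(v.stride())                # (262144, 64, 32768, 1): strides describe the layout
# A kernel that takes these strides as arguments (as Triton kernels commonly do)
# can consume `v` as-is, avoiding the allocation and copy that .contiguous() implies.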
@@ -729,6 +727,239 @@ def test_cuda_backward_equivalence(accuracy_threshold=0.95):
     return all_passed
 
 
+def test_triton_backward_equivalence(accuracy_threshold=0.95):
+    """Test backward pass equivalence between Python prototype and Triton implementation."""
+    print("\n" + "🚀" + "=" * 76 + "🚀")
+    print("🔬 Testing Backward Pass Equivalence: Python Prototype vs Triton Implementation")
+    print("🚀" + "=" * 76 + "🚀")
+
+    # Check if the Triton implementation is available
+    if triton_dmattn_func is None:
+        print("❌ Triton implementation not available, skipping test.")
+        return False
+
+    # Set random seed for reproducibility
+    torch.manual_seed(0)
+
+    # Test different parameter configurations.
+    # If you encounter NaN issues when running multiple configurations, try running a single configuration.
+    # (batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, is_causal)
+    test_configs = [
+        # Head dim 32
+        (1, 2, 1, 128, 128, 32, False),
+        (1, 2, 1, 128, 128, 32, True),
+        (1, 2, 1, 256, 256, 32, False),
+        (1, 2, 1, 256, 256, 32, True),
+        (1, 2, 1, 512, 512, 32, False),
+        (1, 2, 1, 512, 512, 32, True),
+        (1, 2, 1, 1024, 1024, 32, False),
+        (1, 2, 1, 1024, 1024, 32, True),
+        (1, 2, 1, 2048, 2048, 32, False),
+        (1, 2, 1, 2048, 2048, 32, True),
+        (1, 2, 1, 4096, 4096, 32, False),
+        (1, 2, 1, 4096, 4096, 32, True),
+
+        # Head dim 64
+        (1, 2, 1, 128, 128, 64, False),
+        (1, 2, 1, 128, 128, 64, True),
+        (1, 2, 1, 256, 256, 64, False),
+        (1, 2, 1, 256, 256, 64, True),
+        (1, 2, 1, 512, 512, 64, False),
+        (1, 2, 1, 512, 512, 64, True),
+        (1, 2, 1, 1024, 1024, 64, False),
+        (1, 2, 1, 1024, 1024, 64, True),
+        (1, 2, 1, 2048, 2048, 64, False),
+        (1, 2, 1, 2048, 2048, 64, True),
+        (1, 2, 1, 4096, 4096, 64, False),
+        (1, 2, 1, 4096, 4096, 64, True),
+
+        # Head dim 96
+        (1, 2, 1, 128, 128, 96, False),
+        (1, 2, 1, 128, 128, 96, True),
+        (1, 2, 1, 256, 256, 96, False),
+        (1, 2, 1, 256, 256, 96, True),
+        (1, 2, 1, 512, 512, 96, False),
+        (1, 2, 1, 512, 512, 96, True),
+        (1, 2, 1, 1024, 1024, 96, False),
+        (1, 2, 1, 1024, 1024, 96, True),
+        (1, 2, 1, 2048, 2048, 96, False),
+        (1, 2, 1, 2048, 2048, 96, True),
+        (1, 2, 1, 4096, 4096, 96, False),
+        (1, 2, 1, 4096, 4096, 96, True),
+
+        # Head dim 128
+        (1, 2, 1, 128, 128, 128, False),
+        (1, 2, 1, 128, 128, 128, True),
+        (1, 2, 1, 256, 256, 128, False),
+        (1, 2, 1, 256, 256, 128, True),
+        (1, 2, 1, 512, 512, 128, False),
+        (1, 2, 1, 512, 512, 128, True),
+        (1, 2, 1, 1024, 1024, 128, False),
+        (1, 2, 1, 1024, 1024, 128, True),
+        (1, 2, 1, 2048, 2048, 128, False),
+        (1, 2, 1, 2048, 2048, 128, True),
+        (1, 2, 1, 4096, 4096, 128, False),
+        (1, 2, 1, 4096, 4096, 128, True),
+
+        # Triton currently supports head dims up to 128
+    ]
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    dtype = torch.bfloat16
+    device_icon = "🔥" if device.type == "cuda" else "💻"
+    print(f"{device_icon} Using device: {device}")
+
+    all_passed = True
+
+    for i, config in enumerate(test_configs):
+        torch.cuda.empty_cache()
+        gc.collect()
+        torch.cuda.synchronize()
+
+        batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, is_causal = config
+
+        # Progress indicator
+        progress_filled = "█" * (i + 1)
+        progress_empty = "░" * (len(test_configs) - i - 1)
+        progress_bar = f"[{progress_filled}{progress_empty}]"
+
+        print(f"\n🧪 Test configuration {i+1}/{len(test_configs)} {progress_bar}")
+        print(f"  📊 batch_size={batch_size}, num_heads={num_heads}, num_kv_heads={num_kv_heads}")
+        print(f"  📏 query_len={query_len}, key_len={key_len}, head_dim={head_dim}")
+        print(f"  🔒 is_causal={is_causal}")
+        print(f"  🎯 Accuracy threshold: {accuracy_threshold*100:.1f}%")
+
+        # Create random input data
+        query_states = torch.randn(
+            batch_size, num_heads, query_len, head_dim,
+            device=device, dtype=dtype, requires_grad=True
+        )
+        key_states = torch.randn(
+            batch_size, num_kv_heads, key_len, head_dim,
+            device=device, dtype=dtype, requires_grad=True
+        )
+        value_states = torch.randn(
+            batch_size, num_kv_heads, key_len, head_dim,
+            device=device, dtype=dtype, requires_grad=True
+        )
+        attn_bias = torch.randn(
+            batch_size, num_kv_heads, query_len, key_len,
+            device=device, dtype=torch.bfloat16
+        )
+        cache_position = torch.arange(key_len - query_len, key_len, device=device)
+        causal_mask = torch.arange(key_len, device=device) <= cache_position.reshape(-1, 1)
+        causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+
+        # Set scaling factor and keep window size
+        scaling = head_dim ** -0.5
+        window_size = 10240
+
+        # Clone inputs for the Python implementation
+        query_python = query_states.clone().detach().requires_grad_(True)
+        key_python = key_states.clone().detach().requires_grad_(True)
+        value_python = value_states.clone().detach().requires_grad_(True)
+        attn_bias_python = attn_bias.clone().detach().requires_grad_(True)
+        causal_mask_python = causal_mask.clone().detach()
+
+        # Run the Python implementation
+        start_time = time.time()
+        attn_outputs_python, dq_python, dk_python, dv_python, dbias_python = dynamic_mask_attention_python(
+            query_python, key_python, value_python,
+            attn_bias_python, causal_mask_python,
+            scaling, window_size, is_causal
+        )
+        torch.cuda.synchronize()
+        py_time = time.time() - start_time
+
+        # Clone inputs for the Triton implementation
+        query_triton = query_states.clone().detach().requires_grad_(True)
+        key_triton = key_states.clone().detach().requires_grad_(True)
+        value_triton = value_states.clone().detach().requires_grad_(True)
+        attn_bias_triton = attn_bias.clone().detach().requires_grad_(True)
+        causal_mask_triton = causal_mask.clone().detach()
+
+        # Run the Triton implementation
+        start_time = time.time()
+        attn_outputs_triton, dq_triton, dk_triton, dv_triton, dbias_triton = dynamic_mask_attention_triton(
+            query_triton, key_triton, value_triton,
+            attn_bias_triton, causal_mask_triton,
+            scaling, window_size, is_causal
+        )
+        torch.cuda.synchronize()
+        triton_time = time.time() - start_time
+
+        # Analyze outputs
+        print(f"\n🔍 Analyzing differences between Python and Triton outputs:")
+        is_attn_output_close, max_attn_output_diff, mean_attn_output_diff = analyze_differences(
+            attn_outputs_python, attn_outputs_triton, accuracy_threshold
+        )
+
+        # Analyze dQ gradients
+        print(f"\n🔍 Analyzing dQ gradients:")
+        is_dq_close, max_dq_diff, mean_dq_diff = analyze_differences(
+            dq_python, dq_triton, accuracy_threshold
+        )
+
+        # Analyze dK gradients
+        print(f"\n🔍 Analyzing dK gradients:")
+        is_dk_close, max_dk_diff, mean_dk_diff = analyze_differences(
+            dk_python, dk_triton, accuracy_threshold
+        )
+
+        # Analyze dV gradients
+        print(f"\n🔍 Analyzing dV gradients:")
+        is_dv_close, max_dv_diff, mean_dv_diff = analyze_differences(
+            dv_python, dv_triton, accuracy_threshold
+        )
+
+        # Analyze dBias gradients
+        print(f"\n🔍 Analyzing dBias gradients:")
+        is_dbias_close, max_dbias_diff, mean_dbias_diff = analyze_differences(
+            dbias_python, dbias_triton, accuracy_threshold
+        )
+
+        # Report the performance difference
+        speedup = py_time / triton_time if triton_time > 0 else float('inf')
+        print(f"\n⚡ Performance comparison:")
+        print(f"  🐍 Python implementation: {py_time*1000:.2f} ms")
+        print(f"  🚀 Triton implementation: {triton_time*1000:.2f} ms")
+        print(f"  📈 Speedup: {speedup:.2f}x")
+
+        # Check whether the output and all gradients pass
+        is_close = (is_attn_output_close and is_dq_close and is_dk_close and is_dv_close and is_dbias_close)
+        test_result = "Passed" if is_close else "Failed"
+        result_icon = "✅" if is_close else "❌"
+        all_passed = all_passed and is_close
+        print(f"\n{result_icon} Test result: {test_result}")
+
+        # If a test fails with a large difference, exit early
+        if not is_close and max_attn_output_diff > 1e-2:
+            print("  ⚠️ Difference too large, stopping subsequent tests.")
+            break
+        if not is_close and max_dq_diff > 1e-2:
+            print("  ⚠️ Difference too large, stopping subsequent tests.")
+            break
+        if not is_close and max_dk_diff > 1e-2:
+            print("  ⚠️ Difference too large, stopping subsequent tests.")
+            break
+        if not is_close and max_dv_diff > 1e-2:
+            print("  ⚠️ Difference too large, stopping subsequent tests.")
+            break
+        if not is_close and max_dbias_diff > 1e-2:
+            print("  ⚠️ Difference too large, stopping subsequent tests.")
+            break
+
+        del query_states, key_states, value_states, attn_bias, causal_mask, cache_position
+        del dq_python, dk_python, dv_python, dbias_python, dq_triton, dk_triton, dv_triton, dbias_triton
+        torch.cuda.empty_cache()
+        gc.collect()
+        torch.cuda.synchronize()
+
+    print("\n" + "🏁" + "=" * 76 + "🏁")
+    summary_icon = "🎉" if all_passed else "😞"
+    print(f"{summary_icon} Backward Equivalence Test Summary: {'All Passed' if all_passed else 'Some Tests Failed'}")
+    print("🏁" + "=" * 76 + "🏁")
+
+    return all_passed
+
 def main():
     """
     Test backward pass equivalence between Python prototype and various implementations
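
In the new test, num_kv_heads < num_heads exercises the GQA path, and the causal mask is derived from cache_position, the absolute positions of the trailing query_len queries: each query may attend to every key at or before its own position. A small worked example of those two lines (hypothetical CPU-sized values):

import torch

key_len, query_len = 6, 3
cache_position = torch.arange(key_len - query_len, key_len)  # tensor([3, 4, 5])
causal_mask = torch.arange(key_len) <= cache_position.reshape(-1, 1)
print(causal_mask.int())
# tensor([[1, 1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 1, 0],
#         [1, 1, 1, 1, 1, 1]], dtype=torch.int32)

When query_len == key_len, as in every configuration above, this reduces to the usual lower-triangular mask.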
@@ -782,9 +1013,9 @@ def main():
         print("\n" + "📍" + " Starting Python vs CUDA Backward Tests " + "📍")
         test_results['cuda'] = test_cuda_backward_equivalence(args.accuracy_threshold)
 
-    # if args.test_type in ['all', 'triton']:
-    #     print("\n" + "🔥" + " Starting Python vs Triton Backward Tests " + "🔥")
-    #     test_results['triton'] = test_triton_backward_equivalence(args.accuracy_threshold)
+    if args.test_type in ['all', 'triton']:
+        print("\n" + "🔥" + " Starting Python vs Triton Backward Tests " + "🔥")
+        test_results['triton'] = test_triton_backward_equivalence(args.accuracy_threshold)
 
     # if args.test_type in ['all', 'flex']:
     #     print("\n" + "🌟" + " Starting Python vs Flex Attention Backward Tests " + "🌟")
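
With the block uncommented, main() now dispatches the Triton comparison whenever the parsed test type is 'all' or 'triton' (the exact CLI flag names come from the script's argparse setup, which is outside this hunk). The new test can also be driven directly; a minimal sketch, assuming benchmarks/ is importable and a CUDA device is available:

# Hypothetical direct driver for the test added in this commit.
from backward_equivalence import test_triton_backward_equivalence

ok = test_triton_backward_equivalence(accuracy_threshold=0.95)
raise SystemExit(0 if ok else 1)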
