@@ -283,11 +283,9 @@ def dynamic_mask_attention_triton(
     attn_bias = repeat_kv(attn_bias_leaf, num_queries_per_kv)

     # Ensure correct data types and memory layout for Triton function
-    query_states = query_states.transpose(1, 2).contiguous()  # [batch, query_len, num_heads, head_dim]
-    key_states = key_states.transpose(1, 2).contiguous()  # [batch, key_len, num_heads, head_dim]
-    value_states = value_states.transpose(1, 2).contiguous()  # [batch, key_len, num_heads, head_dim]
-    attn_mask = attn_mask.contiguous()  # [batch, num_heads, seqlen_q, seqlen_k]
-    attn_bias = attn_bias.contiguous()  # [batch, num_heads, seqlen_q, seqlen_k]
+    query_states = query_states.transpose(1, 2)  # [batch, query_len, num_heads, head_dim]
+    key_states = key_states.transpose(1, 2)  # [batch, key_len, num_heads, head_dim]
+    value_states = value_states.transpose(1, 2)  # [batch, key_len, num_heads, head_dim]

     # Call the Triton implementation
     attn_outputs = triton_dmattn_func(
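
Note: the five removed lines forced fresh contiguous copies of the transposed q/k/v and of the mask and bias on every call. A minimal sketch of what changes, assuming triton_dmattn_func can read strided views directly (an assumption about that function, not something this diff shows):

import torch

x = torch.randn(1, 2, 128, 64)        # [batch, num_heads, seq_len, head_dim]
y = x.transpose(1, 2)                 # [batch, seq_len, num_heads, head_dim], a view
print(y.is_contiguous())              # False: same storage as x, just different strides
print(y.data_ptr() == x.data_ptr())   # True: no copy was made
z = y.contiguous()                    # what the old code did: allocates and copies
print(z.data_ptr() == x.data_ptr())   # False
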
@@ -729,6 +727,239 @@ def test_cuda_backward_equivalence(accuracy_threshold=0.95):
     return all_passed


+def test_triton_backward_equivalence(accuracy_threshold=0.95):
+    """Test backward pass equivalence between Python prototype and Triton implementation."""
+    print("\n" + "🚀" + "=" * 76 + "🚀")
+    print("🔬 Testing Backward Pass Equivalence: Python Prototype vs Triton Implementation")
+    print("🚀" + "=" * 76 + "🚀")
+
+    # Check if the Triton implementation is available
+    if triton_dmattn_func is None:
+        print("❌ Triton implementation not available, skipping test.")
+        return False
+
+    # Set random seed for reproducibility
+    torch.manual_seed(0)
+
+    # Test different parameter configurations
+    # If you encounter NaN issues when running multiple configurations, try running a single configuration
+    # (batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, is_causal)
+    test_configs = [
+        # Head dim 32
+        (1, 2, 1, 128, 128, 32, False),
+        (1, 2, 1, 128, 128, 32, True),
+        (1, 2, 1, 256, 256, 32, False),
+        (1, 2, 1, 256, 256, 32, True),
+        (1, 2, 1, 512, 512, 32, False),
+        (1, 2, 1, 512, 512, 32, True),
+        (1, 2, 1, 1024, 1024, 32, False),
+        (1, 2, 1, 1024, 1024, 32, True),
+        (1, 2, 1, 2048, 2048, 32, False),
+        (1, 2, 1, 2048, 2048, 32, True),
+        (1, 2, 1, 4096, 4096, 32, False),
+        (1, 2, 1, 4096, 4096, 32, True),
+
+        # Head dim 64
+        (1, 2, 1, 128, 128, 64, False),
+        (1, 2, 1, 128, 128, 64, True),
+        (1, 2, 1, 256, 256, 64, False),
+        (1, 2, 1, 256, 256, 64, True),
+        (1, 2, 1, 512, 512, 64, False),
+        (1, 2, 1, 512, 512, 64, True),
+        (1, 2, 1, 1024, 1024, 64, False),
+        (1, 2, 1, 1024, 1024, 64, True),
+        (1, 2, 1, 2048, 2048, 64, False),
+        (1, 2, 1, 2048, 2048, 64, True),
+        (1, 2, 1, 4096, 4096, 64, False),
+        (1, 2, 1, 4096, 4096, 64, True),
+
+        # Head dim 96
+        (1, 2, 1, 128, 128, 96, False),
+        (1, 2, 1, 128, 128, 96, True),
+        (1, 2, 1, 256, 256, 96, False),
+        (1, 2, 1, 256, 256, 96, True),
+        (1, 2, 1, 512, 512, 96, False),
+        (1, 2, 1, 512, 512, 96, True),
+        (1, 2, 1, 1024, 1024, 96, False),
+        (1, 2, 1, 1024, 1024, 96, True),
+        (1, 2, 1, 2048, 2048, 96, False),
+        (1, 2, 1, 2048, 2048, 96, True),
+        (1, 2, 1, 4096, 4096, 96, False),
+        (1, 2, 1, 4096, 4096, 96, True),
+
+        # Head dim 128
+        (1, 2, 1, 128, 128, 128, False),
+        (1, 2, 1, 128, 128, 128, True),
+        (1, 2, 1, 256, 256, 128, False),
+        (1, 2, 1, 256, 256, 128, True),
+        (1, 2, 1, 512, 512, 128, False),
+        (1, 2, 1, 512, 512, 128, True),
+        (1, 2, 1, 1024, 1024, 128, False),
+        (1, 2, 1, 1024, 1024, 128, True),
+        (1, 2, 1, 2048, 2048, 128, False),
+        (1, 2, 1, 2048, 2048, 128, True),
+        (1, 2, 1, 4096, 4096, 128, False),
+        (1, 2, 1, 4096, 4096, 128, True),
+
+        # Triton currently supports head dims up to 128
+    ]
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    dtype = torch.bfloat16
+    device_icon = "🔥" if device.type == "cuda" else "💻"
+    print(f"{device_icon} Using device: {device}")
+
+    all_passed = True
+
+    for i, config in enumerate(test_configs):
+        torch.cuda.empty_cache()
+        gc.collect()
+        torch.cuda.synchronize()
+
+        batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, is_causal = config
+
+        # Progress indicator
+        progress_filled = "█" * (i + 1)
+        progress_empty = "░" * (len(test_configs) - i - 1)
+        progress_bar = f"[{progress_filled}{progress_empty}]"
+
+        print(f"\n🧪 Test configuration {i + 1}/{len(test_configs)} {progress_bar}")
+        print(f" 📊 batch_size={batch_size}, num_heads={num_heads}, num_kv_heads={num_kv_heads}")
+        print(f" 📏 query_len={query_len}, key_len={key_len}, head_dim={head_dim}")
+        print(f" 🔒 is_causal={is_causal}")
+        print(f" 🎯 Accuracy threshold: {accuracy_threshold * 100:.1f}%")
+
+        # Create random input data
+        query_states = torch.randn(
+            batch_size, num_heads, query_len, head_dim,
+            device=device, dtype=dtype, requires_grad=True
+        )
+        key_states = torch.randn(
+            batch_size, num_kv_heads, key_len, head_dim,
+            device=device, dtype=dtype, requires_grad=True
+        )
+        value_states = torch.randn(
+            batch_size, num_kv_heads, key_len, head_dim,
+            device=device, dtype=dtype, requires_grad=True
+        )
+        attn_bias = torch.randn(
+            batch_size, num_kv_heads, query_len, key_len,
+            device=device, dtype=torch.bfloat16
+        )
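+        # Build a boolean mask of shape [batch, 1, query_len, key_len]: cache_position holds
+        # each query's absolute position, and a query may attend to every key position up to
+        # and including its own absolute position.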
+        cache_position = torch.arange(key_len - query_len, key_len, device=device)
+        causal_mask = torch.arange(key_len, device=device) <= cache_position.reshape(-1, 1)
+        causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+
+        # Set scaling factor and keep window size
+        scaling = head_dim ** -0.5
+        window_size = 10240
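+        # scaling is the usual 1/sqrt(head_dim) softmax temperature; window_size (10240)
+        # exceeds the largest key_len tested (4096), so the dynamic mask presumably never
+        # truncates positions in these configurations.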
+
+        # Clone inputs for Python implementation
+        query_python = query_states.clone().detach().requires_grad_(True)
+        key_python = key_states.clone().detach().requires_grad_(True)
+        value_python = value_states.clone().detach().requires_grad_(True)
+        attn_bias_python = attn_bias.clone().detach().requires_grad_(True)
+        causal_mask_python = causal_mask.clone().detach()
+
+        # Run Python implementation
+        start_time = time.time()
+        attn_outputs_python, dq_python, dk_python, dv_python, dbias_python = dynamic_mask_attention_python(
+            query_python, key_python, value_python,
+            attn_bias_python, causal_mask_python,
+            scaling, window_size, is_causal
+        )
+        torch.cuda.synchronize()
+        py_time = time.time() - start_time
+
+        # Clone inputs for Triton implementation
+        query_triton = query_states.clone().detach().requires_grad_(True)
+        key_triton = key_states.clone().detach().requires_grad_(True)
+        value_triton = value_states.clone().detach().requires_grad_(True)
+        attn_bias_triton = attn_bias.clone().detach().requires_grad_(True)
+        causal_mask_triton = causal_mask.clone().detach()
+
+        # Run Triton implementation
+        start_time = time.time()
+        attn_outputs_triton, dq_triton, dk_triton, dv_triton, dbias_triton = dynamic_mask_attention_triton(
+            query_triton, key_triton, value_triton,
+            attn_bias_triton, causal_mask_triton,
+            scaling, window_size, is_causal
+        )
+        torch.cuda.synchronize()
+        triton_time = time.time() - start_time
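+        # Wall-clock timing: on the first configuration this likely also includes
+        # Triton's one-time JIT compilation of the kernels.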
+
+        # Analyze outputs
+        print(f"\n🔍 Analyzing differences between Python and Triton outputs:")
+        is_attn_output_close, max_attn_output_diff, mean_attn_output_diff = analyze_differences(
+            attn_outputs_python, attn_outputs_triton, accuracy_threshold
+        )
+
+        # Analyze dQ gradients
+        print(f"\n🔍 Analyzing dQ gradients:")
+        is_dq_close, max_dq_diff, mean_dq_diff = analyze_differences(
+            dq_python, dq_triton, accuracy_threshold
+        )
+
+        # Analyze dK gradients
+        print(f"\n🔍 Analyzing dK gradients:")
+        is_dk_close, max_dk_diff, mean_dk_diff = analyze_differences(
+            dk_python, dk_triton, accuracy_threshold
+        )
+
+        # Analyze dV gradients
+        print(f"\n🔍 Analyzing dV gradients:")
+        is_dv_close, max_dv_diff, mean_dv_diff = analyze_differences(
+            dv_python, dv_triton, accuracy_threshold
+        )
+
+        # Analyze dBias gradients
+        print(f"\n🔍 Analyzing dBias gradients:")
+        is_dbias_close, max_dbias_diff, mean_dbias_diff = analyze_differences(
+            dbias_python, dbias_triton, accuracy_threshold
+        )
+
+        # Report performance difference
+        speedup = py_time / triton_time if triton_time > 0 else float('inf')
+        print(f"\n⚡ Performance comparison:")
+        print(f" 🐍 Python implementation: {py_time * 1000:.2f} ms")
+        print(f" 🚀 Triton implementation: {triton_time * 1000:.2f} ms")
+        print(f" 📈 Speedup: {speedup:.2f}x")
+
+        # Check if all gradients pass
+        is_close = (is_attn_output_close and is_dq_close and is_dk_close and is_dv_close and is_dbias_close)
+        test_result = "Passed" if is_close else "Failed"
+        result_icon = "✅" if is_close else "❌"
+        all_passed = all_passed and is_close
+        print(f"\n{result_icon} Test result: {test_result}")
+
+        # If the test fails with a large difference, stop running the remaining configurations
+        if not is_close and max(
+            max_attn_output_diff, max_dq_diff, max_dk_diff, max_dv_diff, max_dbias_diff
+        ) > 1e-2:
+            print(" ⚠️ Difference too large, stopping subsequent tests.")
+            break
+
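+        # Free per-configuration tensors before the next iteration to keep peak GPU memory down;
+        # the largest configurations allocate 4096x4096 bias and mask tensors.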
+        del query_states, key_states, value_states, attn_bias, causal_mask, cache_position
+        del dq_python, dk_python, dv_python, dbias_python, dq_triton, dk_triton, dv_triton, dbias_triton
+        torch.cuda.empty_cache()
+        gc.collect()
+        torch.cuda.synchronize()
+
+    print("\n" + "🏁" + "=" * 76 + "🏁")
+    summary_icon = "🎉" if all_passed else "😞"
+    print(f"{summary_icon} Backward Equivalence Test Summary: {'All Passed' if all_passed else 'Some Tests Failed'}")
+    print("🏁" + "=" * 76 + "🏁")
+
+    return all_passed
+
 def main():
     """
     Test backward pass equivalence between Python prototype and various implementations
@@ -782,9 +1013,9 @@ def main():
         print("\n" + "📍" + " Starting Python vs CUDA Backward Tests " + "📍")
         test_results['cuda'] = test_cuda_backward_equivalence(args.accuracy_threshold)

-    # if args.test_type in ['all', 'triton']:
-    #     print("\n" + "🔥" + " Starting Python vs Triton Backward Tests " + "🔥")
-    #     test_results['triton'] = test_triton_backward_equivalence(args.accuracy_threshold)
+    if args.test_type in ['all', 'triton']:
+        print("\n" + "🔥" + " Starting Python vs Triton Backward Tests " + "🔥")
+        test_results['triton'] = test_triton_backward_equivalence(args.accuracy_threshold)

     # if args.test_type in ['all', 'flex']:
     #     print("\n" + "🌟" + " Starting Python vs Flex Attention Backward Tests " + "🌟")