@@ -452,6 +452,7 @@ def __init__(
         self.rope_fusion = neox_args.rope_fusion
         self.attention_type = neox_args.attention_config[layer_number]
         self.use_flash_attention = self.attention_type == "flash"
+        self.use_ring_attention = self.attention_type == "ring"
         self.use_triton = (
             self.use_flash_attention
             and self.pos_emb == "alibi"
@@ -460,7 +461,7 @@ def __init__(
                 >= packaging.version.Version("2.4.0.post1")
             )
         )
-        self.sparse = self.attention_type not in ("global", "flash")
+        self.sparse = self.attention_type not in ("global", "flash", "ring")

         if self.gqa:
             assert not self.sparse
@@ -489,6 +490,12 @@ def __init__(
             self.flash_triton_fn = flash_attn_unpadded_unpacked_func_triton
             self.flash_qkv_fn = flash_attn_func
             self.flash_varlen_qkv_fn = flash_attn_varlen_func
+        elif self.use_ring_attention:
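+            # zigzag ring attention: each context-parallel rank holds a shard of the
+            # sequence and passes K/V blocks around the ring while computing
+            # blockwise flash attention locally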
+            from ring_flash_attn.zigzag_ring_flash_attn import (
+                zigzag_ring_flash_attn_func,
+            )
+
+            self.ring_attn_fn = zigzag_ring_flash_attn_func
         else:
             self.scale_mask_softmax = FusedScaleMaskSoftmax(
                 input_in_fp16=self.fp16,
@@ -736,6 +743,96 @@ def flash_attention(self, query_layer, key_layer, value_layer):

         return matmul_result

+    def ring_attention(self, query_layer, key_layer, value_layer):
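+        """Ring (context-parallel) attention using ring_flash_attn's zigzag kernel.
+
+        Inputs arrive as [sq, b, np, hn]; the result is returned as [b, np, sq, hn].
+        """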
+        # [b, np, sq, sk]
+        output_size = (
+            query_layer.size(1),
+            query_layer.size(2),
+            query_layer.size(0),
+            key_layer.size(0),
+        )
+
+        # [sk, b, np, hn] -> [b, sk, np, hn]
+        key_layer = key_layer.transpose(0, 1).reshape(
+            output_size[0], output_size[3], self.num_kv_heads_per_partition, -1
+        )
+        value_layer = value_layer.transpose(0, 1).reshape(
+            output_size[0], output_size[3], self.num_kv_heads_per_partition, -1
+        )
+
+        # [sq, b, np, hn] -> [b, sq, np, hn]
+        query_layer = query_layer.transpose(0, 1).reshape(
+            output_size[0], output_size[2], output_size[1], -1
+        )
+
+        # only pass in window_size or alibi_slopes kwarg
+        # if we use Sliding Window Attention / AliBi.
+        # Flash attn defaults to (-1,-1), or
+        # does not have this kwarg prior to v2.3.0
+        extra_kwargs = (
+            {"window_size": (self.sliding_window_width, -1)}
+            if self.sliding_window_width is not None
+            else {}
+        )
+        if self.pos_emb == "alibi":
+            extra_kwargs["alibi_slopes"] = self.alibi_embed.slopes.to(
+                query_layer.device
+            ).to(torch.float32)
+
+        if not self.training:
+            batch_size = output_size[0]
+            max_seqlen_q = output_size[2]
+            max_seqlen_k = output_size[3]
+
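+            # cu_seqlens_q / cu_seqlens_k are computed here but are not consumed by
+            # the fixed-length ring kernel call below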
+            cu_seqlens_q = torch.arange(
+                0,
+                (batch_size + 1) * max_seqlen_q,
+                step=max_seqlen_q,
+                dtype=torch.int32,
+                device=query_layer.device,
+            )
+
+            cu_seqlens_k = torch.arange(
+                0,
+                (batch_size + 1) * max_seqlen_k,
+                step=max_seqlen_k,
+                dtype=torch.int32,
+                device=key_layer.device,
+            )
+
+            q_shape = query_layer.shape
+            k_shape = key_layer.shape
+            v_shape = value_layer.shape
+            is_causal = max_seqlen_q == max_seqlen_k
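+            # only apply a causal mask when query and key lengths match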
+            output = self.ring_attn_fn(
+                query_layer,
+                key_layer,
+                value_layer,
+                0.0,
+                softmax_scale=None,
+                causal=is_causal,
+                group=mpu.get_context_parallel_group(),
+                **extra_kwargs,
+            )
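+            # reshape the kernel output back to the original [b, sq, np, hn] query shape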
+            output = output.reshape(q_shape)
+        else:
+            output = self.ring_attn_fn(
+                query_layer,
+                key_layer,
+                value_layer,
+                self.dropout_p if self.training else 0.0,
+                softmax_scale=None,
+                causal=True,
+                group=mpu.get_context_parallel_group(),
+                **extra_kwargs,
+            )
+
+        matmul_result = output
+        # [b, sq, np, hn] -> [b, np, sq, hn]
+        matmul_result = matmul_result.transpose(1, 2)
+
+        return matmul_result
+
     def sparse_attention(self, query_layer, key_layer, value_layer, attention_mask):
         # TODO: sparse attn dropout?
         # TODO: pad to block size
@@ -831,7 +928,7 @@ def gqa_project(self, hidden_states, attention_mask, layer_past=None):
         value_layer = value_layer.view(*new_kv_shape)

         # if not using Flash attention, we repeat K/V heads to match Q head counts
-        if not self.use_flash_attention:
+        if not (self.use_flash_attention or self.use_ring_attention):
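+            # (flash and ring kernels handle GQA natively, so K/V heads are only
+            # repeated for the remaining attention paths)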
             key_layer = torch.repeat_interleave(
                 key_layer,
                 repeats=int(
@@ -945,6 +1042,8 @@ def forward(self, hidden_states, attention_mask, layer_past=None):

         if self.use_flash_attention:
             context_layer = self.flash_attention(query_layer, key_layer, value_layer)
+        elif self.use_ring_attention:
+            context_layer = self.ring_attention(query_layer, key_layer, value_layer)
         elif not self.sparse:
             context_layer = self.attention(
                 query_layer, key_layer, value_layer, layer_past, attention_mask