@@ -564,7 +564,6 @@ def benchmark(func, *args, **kwargs):
 #
 # * Cross Attention
 # * Fully masked rows no longer cause NaNs
-# * Modifying attention score: ALiBi with FlexAttention and NJT
 # * Packed Projection
 
 ###############################################################################
@@ -668,66 +667,6 @@ def benchmark(func, *args, **kwargs):
 # appropriately makes it possible to properly express empty sequences.
 
 
-################################################################################
-# FlexAttention + NJT
-# ---------------------------------------------------------------------
-# NJT also composes with the ``FlexAttention`` module. This is a generalization
-# of the ``MultiheadAttention`` layer that allows for arbitrary modifications
-# to the attention score. The example below takes the ``alibi_mod``
-# that implements `ALiBi <https://arxiv.org/abs/2108.12409>`_ from
-# `attention gym <https://github.com/meta-pytorch/attention-gym>`_ and uses it
-# with nested input tensors.
-
-from torch.nn.attention.flex_attention import flex_attention
-
-
-def generate_alibi_bias(H: int):
-    """Returns an alibi bias score_mod given the number of heads H
-    Args:
-        H: number of heads
-    Returns:
-        alibi_bias: alibi bias score_mod
-    """
-
-    def alibi_mod(score, b, h, q_idx, kv_idx):
-        scale = torch.exp2(-((h + 1) * 8.0 / H))
-        bias = (q_idx - kv_idx) * scale
-        return score + bias
-
-    return alibi_mod
-
-
-query, key, value, _ = gen_batch(N, E_q, E_k, E_v, device)
-n_heads, D = 8, E_q // 8
-alibi_score_mod = generate_alibi_bias(n_heads)
-query = query.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
-key = key.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
-value = value.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
-out_flex2 = flex_attention(query, key, value, score_mod=alibi_score_mod)
-
-###############################################################################
-# In addition, one can also use the ``block_mask`` utility of ``FlexAttention``
-# with NJTs via the ``create_nested_block_mask`` function. This is useful for
-# taking advantage of the sparsity of the mask to speed up the attention computation.
-# In particular, the function creates a sparse block mask for a "stacked sequence" of all
-# the variable length sequences in the NJT combined into one, while properly masking out
-# inter-sequence attention. In the following example, we show how to create a
-# causal block mask using this utility.
-
-from torch.nn.attention.flex_attention import create_nested_block_mask
-
-
-def causal_mask(b, h, q_idx, kv_idx):
-    return q_idx >= kv_idx
-
-
-query, key, value, _ = gen_batch(N, E_q, E_k, E_v, device)
-block_mask = create_nested_block_mask(causal_mask, 1, 1, query, _compile=True)
-query = query.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
-key = key.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
-value = value.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
-out_flex = flex_attention(query, key, value, block_mask=block_mask)
-
 ###############################################################################
 # Packed Projection
 # -----------------