@@ -398,7 +398,7 @@ def forward(self, query_layer, key_layer, value_layer, attention_mask, alibi):
                                    key_layer.size(0))
 
         # [sq, b, np, hn] -> [b, np * sq, hn]
-        query_layer = query_layer.permute([1, 2, 0, 3]).reshape(bs, np * sq, -1)
+        query_layer = query_layer.transpose(0, 1).reshape(bs, np * sq, -1)
         # [sk, b, 1, hn] -> [b, hn, sk]
         key_layer = key_layer.squeeze(2).permute(1, 2, 0)
         # [sk, b, 1, hn] -> [sk, b * np, hn]
@@ -439,8 +439,8 @@ def forward(self, query_layer, key_layer, value_layer, attention_mask, alibi):
             key_layer,
             beta=beta, alpha=(1.0 / self.norm_factor))
 
-        # change view to [b, np, sq, sk]
-        attention_scores = matmul_result.view(bs, np, sq, sk)
+        attention_scores = matmul_result.view(bs, sq, np, sk)
+        attention_mask = attention_mask.transpose(1, 2)
 
         # ===========================
         # Attention probs and dropout
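The two hunks above change the flattening order of the query before the batched matmul: the old permute([1, 2, 0, 3]) produced a head-major [b, np * sq, hn] layout, while the new transpose(0, 1) produces a sequence-major [b, sq * np, hn] layout, which is why the raw scores are now viewed as [b, sq, np, sk] and the mask is transposed so its sequence and head dimensions line up with that layout. A minimal sketch of the relationship between the two layouts, with hypothetical sizes and num_heads standing in for np (not code from the commit):

import torch

sq, sk, b, num_heads, hn = 3, 3, 2, 4, 5
query = torch.randn(sq, b, num_heads, hn)   # [sq, b, np, hn]
key = torch.randn(b, hn, sk)                # [b, hn, sk] after squeeze/permute

old_q = query.permute(1, 2, 0, 3).reshape(b, num_heads * sq, hn)  # head-major rows
new_q = query.transpose(0, 1).reshape(b, num_heads * sq, hn)      # sequence-major rows

old_scores = torch.bmm(old_q, key).view(b, num_heads, sq, sk)     # [b, np, sq, sk]
new_scores = torch.bmm(new_q, key).view(b, sq, num_heads, sk)     # [b, sq, np, sk]

# Same scores, only the head and sequence dimensions are swapped.
assert torch.allclose(old_scores, new_scores.transpose(1, 2))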
@@ -482,15 +482,7 @@ def forward(self, query_layer, key_layer, value_layer, attention_mask, alibi):
         context_layer = torch.bmm(attention_probs, value_layer)
 
         # change view [b, np, sq, hn]
-        context_layer = context_layer.view(bs, np, sq, -1)
-
-        # [b, np, sq, hn] --> [sq, b, np, hn]
-        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
-
-        # [sq, b, np, hn] --> [sq, b, hp]
-        new_context_layer_shape = context_layer.size()[:-2] + \
-            (self.hidden_size_per_partition,)
-        context_layer = context_layer.view(*new_context_layer_shape)
+        context_layer = context_layer.view(bs, sq, -1).transpose(0, 1)
 
         return context_layer
 
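The last hunk collapses the old view -> permute -> contiguous -> view chain into a single view plus transpose. The sketch below (again with hypothetical sizes, not code from the commit) checks that the two recipes agree, under the assumption that the new bmm output is sequence-major (unflattening to [b, sq, np, hn]) where the old one was head-major (unflattening to [b, np, sq, hn]):

import torch

b, num_heads, sq, hn = 2, 4, 3, 5              # num_heads stands in for np
hp = num_heads * hn                            # hidden_size_per_partition

ctx = torch.randn(b, sq, num_heads, hn)        # per-head outputs, sequence-major

# Old chain, fed the head-major flattening of the same values:
old_flat = ctx.permute(0, 2, 1, 3).reshape(b, num_heads * sq, hn)
old_out = (old_flat.view(b, num_heads, sq, hn)
                   .permute(2, 0, 1, 3)
                   .contiguous()
                   .view(sq, b, hp))

# New chain, fed the sequence-major flattening:
new_flat = ctx.reshape(b, sq * num_heads, hn)
new_out = new_flat.view(b, sq, -1).transpose(0, 1)

assert torch.equal(old_out, new_out)           # both are [sq, b, hp]

One observable difference is that the new form returns a non-contiguous tensor (the old chain called .contiguous()), which only matters if the caller requires a contiguous layout.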