
Commit befde51

outputs not matching non-flash case in MQA

1 parent: 8c1889e

2 files changed, +5 −13 lines

megatron/model/transformer.py

Lines changed: 4 additions & 12 deletions
@@ -398,7 +398,7 @@ def forward(self, query_layer, key_layer, value_layer, attention_mask, alibi):
                                     key_layer.size(0))
 
         # [sq, b, np, hn] -> [b, np * sq, hn]
-        query_layer = query_layer.permute([1, 2, 0, 3]).reshape(bs, np * sq, -1)
+        query_layer = query_layer.transpose(0, 1).reshape(bs, np * sq, -1)
         # [sk, b, 1, hn] -> [b, hn, sk]
         key_layer = key_layer.squeeze(2).permute(1, 2, 0)
         # [sk, b, 1, hn] -> [sk, b * np, hn]
@@ -439,8 +439,8 @@ def forward(self, query_layer, key_layer, value_layer, attention_mask, alibi):
             key_layer,
             beta=beta, alpha=(1.0 / self.norm_factor))
 
-        # change view to [b, np, sq, sk]
-        attention_scores = matmul_result.view(bs, np, sq, sk)
+        attention_scores = matmul_result.view(bs, sq, np, sk)
+        attention_mask = attention_mask.transpose(1, 2)
 
         # ===========================
         # Attention probs and dropout
@@ -482,15 +482,7 @@ def forward(self, query_layer, key_layer, value_layer, attention_mask, alibi):
         context_layer = torch.bmm(attention_probs, value_layer)
 
         # change view [b, np, sq, hn]
-        context_layer = context_layer.view(bs, np, sq, -1)
-
-        # [b, np, sq, hn] --> [sq, b, np, hn]
-        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
-
-        # [sq, b, np, hn] --> [sq, b, hp]
-        new_context_layer_shape = context_layer.size()[:-2] + \
-            (self.hidden_size_per_partition,)
-        context_layer = context_layer.view(*new_context_layer_shape)
+        context_layer = context_layer.view(bs, sq, -1).transpose(0, 1)
 
         return context_layer
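The substance of the change is the order in which the (sq, np) pair is flattened. With transpose(0, 1) the flattened query rows are ordered (sq, np), so the raw scores can be viewed as [bs, sq, np, sk] (with the attention mask transposed to match) and the context can be brought back to the usual [sq, b, hp] layout with a single view(bs, sq, -1).transpose(0, 1), replacing the old permute-and-reshape chain. The sketch below is a minimal, self-contained illustration of that consistency rather than the Megatron code itself: the sizes are made up, and scaling, masking, dropout and alibi are omitted. It checks the patched flatten/view sequence against a straightforward per-head reference.

import torch

# Toy sizes (hypothetical, for illustration only).
sq, sk, b, np_, hn = 3, 4, 2, 5, 8

# Megatron-style layouts: query is [sq, b, np, hn]; MQA shares one K/V head.
query = torch.randn(sq, b, np_, hn)
key = torch.randn(sk, b, 1, hn)
value = torch.randn(sk, b, 1, hn)

# Reference: plain per-head attention, ending in [sq, b, np*hn].
q_ref = query.permute(1, 2, 0, 3)                       # [b, np, sq, hn]
k_ref = key.squeeze(2).permute(1, 2, 0).unsqueeze(1)    # [b, 1, hn, sk]
v_ref = value.squeeze(2).transpose(0, 1).unsqueeze(1)   # [b, 1, sk, hn]
probs_ref = (q_ref @ k_ref).softmax(dim=-1)             # [b, np, sq, sk]
ctx_ref = (probs_ref @ v_ref).permute(2, 0, 1, 3).reshape(sq, b, np_ * hn)

# Patched ordering: flatten (sq, np) in that order, as the commit does.
q = query.transpose(0, 1).reshape(b, np_ * sq, hn)      # rows ordered (sq, np)
scores = torch.bmm(q, key.squeeze(2).permute(1, 2, 0))  # [b, sq*np, sk]
probs = scores.view(b, sq, np_, sk).softmax(dim=-1).view(b, sq * np_, sk)
ctx = torch.bmm(probs, value.squeeze(2).transpose(0, 1))  # [b, sq*np, hn]
ctx = ctx.view(b, sq, -1).transpose(0, 1)               # [sq, b, np*hn]

print(torch.allclose(ctx, ctx_ref, atol=1e-5))          # True

The query reshape, the [bs, sq, np, sk] scores view and the final view(bs, sq, -1) all assume the same (sq, np) row order; the point of the patch is to make those three agree.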

tools/checkpoint_saver_megatron.py

Lines changed: 1 addition & 1 deletion
@@ -138,7 +138,7 @@ def check_message(msg):
     if hasattr (md, 'checkpoint_args'):
         # These are arguments that we are either changing, or cause problems for validation if they are set
         # Note that some of these deal with T5 so will need to be changed if we support T5.
-        args_to_keep = ['tensor_model_parallel_size', 'pipeline_model_parallel_size', 'params_dtype',
+        args_to_keep = ['tensor_model_parallel_size', 'pipeline_model_parallel_size', 'world_size', 'params_dtype',
                         'num_layers_per_virtual_pipeline_stage', 'virtual_pipeline_model_parallel_size',
                         'masked_softmax_fusion', 'bias_gelu_fusion', 'bias_dropout_fusion',
                         'sequence_parallel', 'async_tensor_model_parallel_allreduce',
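Adding 'world_size' to args_to_keep presumably stops the world size stored in the source checkpoint from clobbering the value computed for the target run, which generally differs when converting between tensor/pipeline-parallel layouts. The snippet below is a minimal, self-contained sketch of the kind of filtering such a list typically drives; the Namespace contents and the loop are illustrative assumptions, not the tool's actual code.

from argparse import Namespace

# Hypothetical values: args restored from the source checkpoint vs. args
# prepared for the converted checkpoint's (different) parallel layout.
checkpoint_args = Namespace(world_size=8, tensor_model_parallel_size=2, hidden_size=4096)
target_args = Namespace(world_size=1, tensor_model_parallel_size=1, hidden_size=4096)

# Layout-dependent args stay as computed for the target run.
args_to_keep = ['tensor_model_parallel_size', 'pipeline_model_parallel_size',
                'world_size', 'params_dtype']

# Everything else is copied over so the model hyperparameters survive conversion.
for arg, value in vars(checkpoint_args).items():
    if arg in args_to_keep:
        continue
    setattr(target_args, arg, value)

print(target_args.world_size)   # 1    -- kept for the new layout
print(target_args.hidden_size)  # 4096 -- taken from the checkpoint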
