Commit 21fd304

fix gather output
1 parent: 349c818

File tree

2 files changed (+6, -6 lines)


colossalai/shardformer/modeling/gpt2.py

Lines changed: 3 additions & 3 deletions

@@ -25,7 +25,7 @@
 from colossalai.shardformer.shard import ShardConfig

 from ..layer import cross_entropy_1d
-from ..layer._operation import _gather
+from ..layer._operation import gather_forward_split_backward


 class GPT2PipelineForwards:
@@ -339,7 +339,7 @@ def gpt2_lmhead_model_forward(
             loss = loss_fct(shift_logits, shift_labels)

         if not shard_config.parallel_output:
-            lm_logits = _gather(lm_logits, -1, shard_config.tensor_parallel_process_group)
+            lm_logits = gather_forward_split_backward(lm_logits, -1, shard_config.tensor_parallel_process_group)

         if not return_dict:
             output = (lm_logits,) + outputs[1:]
@@ -1089,7 +1089,7 @@ def forward(
             loss = loss_fct(shift_logits, shift_labels)

         if not shard_config.parallel_output:
-            lm_logits = _gather(lm_logits, -1, shard_config.tensor_parallel_process_group)
+            lm_logits = gather_forward_split_backward(lm_logits, -1, shard_config.tensor_parallel_process_group)

         if not return_dict:
             output = (lm_logits,) + transformer_outputs[1:]

colossalai/shardformer/modeling/llama.py

Lines changed: 3 additions & 3 deletions

@@ -16,7 +16,7 @@
 from colossalai.shardformer.shard import ShardConfig

 from ..layer import cross_entropy_1d
-from ..layer._operation import _gather
+from ..layer._operation import gather_forward_split_backward

 try:
     from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask
@@ -290,7 +290,7 @@ def llama_for_causal_lm_forward(
             loss = loss_fct(shift_logits, shift_labels)

         if not shard_config.parallel_output:
-            logits = _gather(logits, -1, shard_config.tensor_parallel_process_group)
+            logits = gather_forward_split_backward(logits, -1, shard_config.tensor_parallel_process_group)

         if not return_dict:
             output = (logits,) + outputs[1:]
@@ -594,7 +594,7 @@ def forward(
             loss = loss_fct(shift_logits, shift_labels)

         if not shard_config.parallel_output:
-            logits = _gather(logits, -1, shard_config.tensor_parallel_process_group)
+            logits = gather_forward_split_backward(logits, -1, shard_config.tensor_parallel_process_group)

         if not return_dict:
             output = (logits,) + outputs[1:]
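Judging by its name, gather_forward_split_backward follows the usual tensor-parallel pattern: all-gather the vocabulary-sharded logits along the given dimension in the forward pass and split the incoming gradient back to each rank's local shard in the backward pass, so gradients stay correct when parallel_output is disabled. Below is a minimal sketch of that pattern under those assumptions; it mirrors the call signature used in the diff (tensor, dim, process group) but is illustrative, not ColossalAI's actual implementation.

# Illustrative sketch only: assumes the gather-forward / split-backward
# semantics implied by the function name; not the library's real code.
import torch
import torch.distributed as dist


class _GatherForwardSplitBackward(torch.autograd.Function):
    """All-gather shards along `dim` in forward; split the gradient in backward."""

    @staticmethod
    def forward(ctx, input_, dim, process_group):
        ctx.dim = dim
        ctx.process_group = process_group
        world_size = dist.get_world_size(process_group)
        if world_size == 1:
            return input_
        # Collect every rank's shard and concatenate them along `dim`.
        shards = [torch.empty_like(input_) for _ in range(world_size)]
        dist.all_gather(shards, input_.contiguous(), group=process_group)
        return torch.cat(shards, dim=dim)

    @staticmethod
    def backward(ctx, grad_output):
        world_size = dist.get_world_size(ctx.process_group)
        if world_size == 1:
            return grad_output, None, None
        # Hand each rank back only the gradient slice for its local shard.
        rank = dist.get_rank(ctx.process_group)
        grad_shard = grad_output.chunk(world_size, dim=ctx.dim)[rank]
        return grad_shard.contiguous(), None, None


def gather_forward_split_backward(input_, dim, process_group):
    return _GatherForwardSplitBackward.apply(input_, dim, process_group)

Usage then matches the patched forward passes: logits = gather_forward_split_backward(logits, -1, shard_config.tensor_parallel_process_group).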

0 commit comments