2 files changed in colossalai/shardformer/modeling: +6 −6 lines

colossalai/shardformer/modeling/gpt2.py

@@ -25,7 +25,7 @@
 from colossalai.shardformer.shard import ShardConfig

 from ..layer import cross_entropy_1d
-from ..layer._operation import _gather
+from ..layer._operation import gather_forward_split_backward


 class GPT2PipelineForwards:
@@ -339,7 +339,7 @@ def gpt2_lmhead_model_forward(
             loss = loss_fct(shift_logits, shift_labels)

         if not shard_config.parallel_output:
-            lm_logits = _gather(lm_logits, -1, shard_config.tensor_parallel_process_group)
+            lm_logits = gather_forward_split_backward(lm_logits, -1, shard_config.tensor_parallel_process_group)

         if not return_dict:
             output = (lm_logits,) + outputs[1:]
@@ -1089,7 +1089,7 @@ def forward(
             loss = loss_fct(shift_logits, shift_labels)

         if not shard_config.parallel_output:
-            lm_logits = _gather(lm_logits, -1, shard_config.tensor_parallel_process_group)
+            lm_logits = gather_forward_split_backward(lm_logits, -1, shard_config.tensor_parallel_process_group)

         if not return_dict:
             output = (lm_logits,) + transformer_outputs[1:]
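Both call sites now go through the public wrapper instead of the private autograd primitive. Judging by the names, `gather_forward_split_backward` gathers the vocabulary-sharded logits along the last dimension in the forward pass and splits the incoming gradient back to each tensor-parallel rank in the backward pass, so the swap keeps gradient flow correct when `parallel_output` is disabled. The same substitution is applied to the LLaMA forwards below.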
colossalai/shardformer/modeling/llama.py

@@ -16,7 +16,7 @@
 from colossalai.shardformer.shard import ShardConfig

 from ..layer import cross_entropy_1d
-from ..layer._operation import _gather
+from ..layer._operation import gather_forward_split_backward

 try:
     from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask
@@ -290,7 +290,7 @@ def llama_for_causal_lm_forward(
             loss = loss_fct(shift_logits, shift_labels)

         if not shard_config.parallel_output:
-            logits = _gather(logits, -1, shard_config.tensor_parallel_process_group)
+            logits = gather_forward_split_backward(logits, -1, shard_config.tensor_parallel_process_group)

         if not return_dict:
             output = (logits,) + outputs[1:]
@@ -594,7 +594,7 @@ def forward(
             loss = loss_fct(shift_logits, shift_labels)

         if not shard_config.parallel_output:
-            logits = _gather(logits, -1, shard_config.tensor_parallel_process_group)
+            logits = gather_forward_split_backward(logits, -1, shard_config.tensor_parallel_process_group)

         if not return_dict:
             output = (logits,) + outputs[1:]
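For context, here is a minimal sketch of what a gather-forward/split-backward wrapper typically looks like, written with plain torch.distributed. It illustrates the pattern only; it is not ColossalAI's code from ..layer._operation, and the class name used here is just for the sketch.

# A minimal sketch (not ColossalAI's actual implementation) of the
# gather-forward / split-backward pattern the renamed helper is assumed to follow.
import torch
import torch.distributed as dist


class _GatherForwardSplitBackwardSketch(torch.autograd.Function):
    """Gather shards along `dim` in forward; return only this rank's gradient slice in backward."""

    @staticmethod
    def forward(ctx, input_, dim, process_group):
        ctx.dim = dim
        ctx.process_group = process_group
        world_size = dist.get_world_size(process_group)
        if world_size == 1:
            return input_
        # Collect every rank's shard and concatenate along the sharded dimension.
        shards = [torch.empty_like(input_) for _ in range(world_size)]
        dist.all_gather(shards, input_.contiguous(), group=process_group)
        return torch.cat(shards, dim=dim)

    @staticmethod
    def backward(ctx, grad_output):
        world_size = dist.get_world_size(ctx.process_group)
        if world_size == 1:
            return grad_output, None, None
        rank = dist.get_rank(ctx.process_group)
        # Split the full gradient and hand back only the chunk owned by this rank.
        local_grad = grad_output.chunk(world_size, dim=ctx.dim)[rank]
        return local_grad, None, None


def gather_forward_split_backward(input_, dim, process_group):
    return _GatherForwardSplitBackwardSketch.apply(input_, dim, process_group)

Compared with calling a raw gather primitive directly, routing through a wrapper like this keeps the backward pass symmetric: each tensor-parallel rank receives only the gradient slice for its own vocabulary shard rather than the full gathered gradient.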