 from colossalai.shardformer.shard import ShardConfig
 
 from ..layer import cross_entropy_1d
-from ..layer._operation import gather_forward_split_backward
 
 try:
     from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask
@@ -279,7 +278,7 @@ def llama_for_causal_lm_forward(
             shift_labels = shift_labels.view(-1)
             # Enable model parallelism
             shift_labels = shift_labels.to(shift_logits.device)
-            if shard_config.enable_tensor_parallelism:
+            if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
                 new_vocab_size = logits.shape[-1]
                 shift_logits = shift_logits.view(-1, new_vocab_size)
                 loss = cross_entropy_1d(
@@ -289,9 +288,6 @@ def llama_for_causal_lm_forward(
                 shift_logits = shift_logits.view(-1, self.config.vocab_size)
                 loss = loss_fct(shift_logits, shift_labels)
 
-        if not shard_config.parallel_output:
-            logits = gather_forward_split_backward(logits, -1, shard_config.tensor_parallel_process_group)
-
         if not return_dict:
             output = (logits,) + outputs[1:]
             return (loss,) + output if loss is not None else output
@@ -578,23 +574,15 @@ def forward(
             # Shift so that tokens < n predict n
             shift_logits = logits[..., :-1, :].contiguous()
             shift_labels = labels[..., 1:].contiguous()
-            # Flatten the tokens
-            loss_fct = CrossEntropyLoss()
             shift_labels = shift_labels.view(-1)
             # Enable model parallelism
             shift_labels = shift_labels.to(shift_logits.device)
-            if shard_config.enable_tensor_parallelism:
-                new_vocab_size = logits.shape[-1]
-                shift_logits = shift_logits.view(-1, new_vocab_size)
-                loss = cross_entropy_1d(
-                    shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group
-                )
-            else:
-                shift_logits = shift_logits.view(-1, self.config.vocab_size)
-                loss = loss_fct(shift_logits, shift_labels)
 
-        if not shard_config.parallel_output:
-            logits = gather_forward_split_backward(logits, -1, shard_config.tensor_parallel_process_group)
+            new_vocab_size = logits.shape[-1]
+            shift_logits = shift_logits.view(-1, new_vocab_size)
+            loss = cross_entropy_1d(
+                shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group
+            )
 
         if not return_dict:
             output = (logits,) + outputs[1:]
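The diff above makes the distributed cross-entropy path unconditional in the nested `forward` and gates the pipeline path on `shard_config.parallel_output`, so the loss is computed directly on the vocab-sharded logits via `cross_entropy_1d` instead of first gathering them with `gather_forward_split_backward`. As a rough, forward-only illustration of what a vocab-parallel cross entropy does, here is a minimal sketch; the function name `shard_cross_entropy`, the argument layout, and the omission of `ignore_index` handling and a custom backward are simplifications for illustration, not ColossalAI's actual `cross_entropy_1d` implementation.

```python
# Illustrative sketch only: a simplified vocab-parallel cross entropy in the
# spirit of cross_entropy_1d. `vocab_shard_logits` holds this rank's slice of
# the vocabulary dimension, shape (num_tokens, shard_size); `labels` has shape
# (num_tokens,). Assumes torch.distributed is initialized; forward pass only.
import torch
import torch.distributed as dist


def shard_cross_entropy(vocab_shard_logits: torch.Tensor, labels: torch.Tensor, process_group=None):
    rank = dist.get_rank(process_group)
    shard_size = vocab_shard_logits.shape[-1]
    vocab_start, vocab_end = rank * shard_size, (rank + 1) * shard_size

    # Stabilize with the global max logit (max over all vocab shards).
    logits_max = vocab_shard_logits.max(dim=-1, keepdim=True).values
    dist.all_reduce(logits_max, op=dist.ReduceOp.MAX, group=process_group)
    shifted = vocab_shard_logits - logits_max

    # log-sum-exp over the full vocabulary = log of the summed per-shard exp sums.
    sum_exp = shifted.exp().sum(dim=-1, keepdim=True)
    dist.all_reduce(sum_exp, group=process_group)

    # Pick out the target logit; ranks that do not own a label contribute 0,
    # so the sum across ranks recovers the owning rank's value.
    in_shard = (labels >= vocab_start) & (labels < vocab_end)
    local_labels = (labels - vocab_start).clamp(min=0, max=shard_size - 1)
    target_logit = shifted.gather(-1, local_labels.unsqueeze(-1)).squeeze(-1)
    target_logit = torch.where(in_shard, target_logit, torch.zeros_like(target_logit))
    dist.all_reduce(target_logit, group=process_group)

    # Cross entropy per token: log(sum_j exp(z_j)) - z_target (both max-shifted).
    loss = sum_exp.squeeze(-1).log() - target_logit
    return loss.mean()
```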