@@ -14,6 +14,7 @@
 from transformers.modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
     CausalLMOutputWithCrossAttentions,
+    CausalLMOutputWithPast,
     QuestionAnsweringModelOutput,
     SequenceClassifierOutputWithPast,
     TokenClassifierOutput,
@@ -31,6 +32,8 @@
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.shard import ShardConfig
 
+from ..layer import cross_entropy_1d
+
 
 def build_falcon_alibi_tensor_fn(process_group: ProcessGroup) -> torch.Tensor:
     def build_falcon_alibi_tensor(
@@ -437,14 +440,28 @@ def falcon_for_causal_lm_forward(
             loss = None
             if labels is not None:
                 # Shift so that tokens < n predict n
+                labels = labels.to(lm_logits.device)
                 shift_logits = lm_logits[..., :-1, :].contiguous()
                 shift_labels = labels[..., 1:].contiguous()
                 batch_size, seq_length, vocab_size = shift_logits.shape
                 # Flatten the tokens
                 loss_fct = CrossEntropyLoss()
-                loss = loss_fct(
-                    shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length)
-                )
+                if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
+                    new_vocab_size = shift_logits.shape[-1]
+                    shift_logits = shift_logits.view(-1, new_vocab_size)
+                    shift_labels = shift_labels.view(-1)
+                    loss = cross_entropy_1d(
+                        shift_logits,
+                        shift_labels,
+                        process_group=shard_config.tensor_parallel_process_group,
+                        vocab_size=self.lm_head.out_features,
+                        dtype=self.transformer.dtype,
+                    )
+                else:
+                    loss = loss_fct(
+                        shift_logits.view(batch_size * seq_length, vocab_size),
+                        shift_labels.view(batch_size * seq_length),
+                    )
 
             if not return_dict:
                 output = (lm_logits,) + transformer_outputs[1:]
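
For reference, the sketch below reproduces the non-parallel fallback branch above in isolation: the standard causal-LM loss with a one-token shift followed by flattening for `CrossEntropyLoss`. When `shard_config.enable_tensor_parallelism` and `shard_config.parallel_output` are both set, the same flattened tensors are instead passed to `cross_entropy_1d`, which computes the loss while the vocabulary dimension of the logits stays sharded across the tensor-parallel group. The shapes below are small placeholders chosen only for illustration.

```python
# Illustrative sketch of the plain (non tensor-parallel) loss path kept in the
# else-branch above; tensor sizes are arbitrary placeholders.
import torch
from torch.nn import CrossEntropyLoss

lm_logits = torch.randn(2, 8, 32)          # (batch, seq_len, vocab)
labels = torch.randint(0, 32, (2, 8))      # targets share the input shape

# Shift so that tokens < n predict n, then flatten for CrossEntropyLoss.
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
batch_size, seq_length, vocab_size = shift_logits.shape

loss = CrossEntropyLoss()(
    shift_logits.view(batch_size * seq_length, vocab_size),
    shift_labels.view(batch_size * seq_length),
)
```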
@@ -747,3 +764,79 @@ def falcon_for_question_answering_forward(
         else:
             hidden_states = outputs.get("hidden_states")
             return {"hidden_states": hidden_states}
+
+
+def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
+    from transformers import FalconForCausalLM
+
+    def forward(
+        self: FalconForCausalLM,
+        input_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, CausalLMOutputWithPast]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
+            `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to
+            `-100` are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.
+        """
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        transformer_outputs = self.transformer(
+            input_ids,
+            past_key_values=past_key_values,
+            attention_mask=attention_mask,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        past_key_values = None
+        hidden_states = transformer_outputs[0]
+        lm_logits = self.lm_head(hidden_states)
+        loss = None
+        if labels is not None:
+            # Shift so that tokens < n predict n
+            labels = labels.to(lm_logits.device)
+            shift_logits = lm_logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            batch_size, seq_length, vocab_size = shift_logits.shape
+            # Flatten the tokens
+            new_vocab_size = shift_logits.shape[-1]
+            shift_logits = shift_logits.view(-1, new_vocab_size)
+            shift_labels = shift_labels.view(-1)
+            loss = cross_entropy_1d(
+                shift_logits,
+                shift_labels,
+                process_group=shard_config.tensor_parallel_process_group,
+                vocab_size=self.lm_head.out_features,
+                dtype=self.transformer.dtype,
+            )
+
+        if not return_dict:
+            output = (lm_logits,) + transformer_outputs[1:]
+            return ((loss,) + output) if loss is not None else output
+
+        return CausalLMOutputWithPast(
+            loss=loss,
+            logits=lm_logits,
+            past_key_values=transformer_outputs.past_key_values,
+            hidden_states=transformer_outputs.hidden_states,
+            attentions=transformer_outputs.attentions,
+        )
+
+    return forward
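
A minimal sketch of how the returned forward might be attached to a model. In practice this substitution would be performed by the Falcon shardformer policy when `parallel_output` is enabled, so `model`, `shard_config`, and `input_ids` below are assumed to come from an existing tensor-parallel setup and are not constructed here.

```python
# Hypothetical usage sketch: `model` is an already-sharded FalconForCausalLM and
# `shard_config` a ShardConfig with tensor parallelism enabled (both assumed to
# exist; they are not created in this snippet).
import types

dist_forward = get_lm_forward_with_dist_cross_entropy(shard_config)
model.forward = types.MethodType(dist_forward, model)  # bind as an instance method

# The returned loss is now computed by cross_entropy_1d over vocab-sharded logits.
outputs = model(input_ids=input_ids, labels=input_ids)
loss = outputs.loss
```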