@@ -3823,6 +3823,123 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s

        return inputs

+    def _is_attention_mask_causal(self, attention_mask):
+        """
+        Check if an attention mask is causal (compatible with causal attention).
+
+        Context parallelism only supports causal attention patterns. This function
+        checks if the provided attention mask is compatible.
+
+        Args:
+            attention_mask (`torch.Tensor`): The attention mask to check.
+
+        Returns:
+            `bool`: True if the mask is causal or compatible with causal attention.
+        """
+        if attention_mask is None:
+            return True  # No mask is considered causal (the model uses its default causal masking)
+
+        # Handle different mask dimensions
+        if attention_mask.dim() == 2:
+            # (batch_size, seq_len) - standard padding mask, compatible with causal attention
+            return True
+        elif attention_mask.dim() in [3, 4]:
+            # (batch_size, seq_len, seq_len) or (batch_size, num_heads, seq_len, seq_len)
+            # Check if it is lower triangular (causal)
+            seq_len = attention_mask.shape[-1]
+            if seq_len <= 1:
+                return True  # A single token or empty mask is always causal
+
+            # Take the first batch and head (if 4D) for checking the pattern
+            if attention_mask.dim() == 4:
+                mask = attention_mask[0, 0]  # First batch, first head
+            else:
+                mask = attention_mask[0]  # First batch
+
+            # Check if the upper triangular part is masked (should be 0 or very negative for causal)
+            upper_triangular = torch.triu(mask, diagonal=1)
+
+            # For causal masks, the upper triangular part should be 0 or very negative (like -inf).
+            # Use a reasonable threshold to handle float precision issues.
+            is_causal = torch.all(upper_triangular <= 1e-6) or torch.all(upper_triangular < -1e4)
+            return is_causal.item() if isinstance(is_causal, torch.Tensor) else is_causal
+
+        # For unknown dimensions, be conservative and reject
+        return False
+
+    def _prepare_context_parallel_inputs(self, model, inputs: dict[str, Union[torch.Tensor, Any]]):
+        """
+        Prepare inputs for context parallelism by setting up buffers and validation.
+
+        Args:
+            model: The model being trained.
+            inputs: Input tensors to prepare.
+
+        Returns:
+            `tuple`: `(context_manager, prepared_inputs)` where `context_manager` is either
+                the context parallelism wrapper or a no-op context.
+        """
+        if (
+            getattr(self.accelerator, "parallelism_config", None) is not None
+            and self.accelerator.parallelism_config.cp_enabled
+        ):
+            if hasattr(model, "config"):
+                if model.config._attn_implementation != "sdpa":
+                    raise ValueError(
+                        f"Context parallelism is supported only with SDPA attention, you are using {model.config._attn_implementation}."
+                    )
+
+            if "position_ids" not in inputs:
+                logger.warning_once("Position IDs not found in the inputs, generating manually")
+                inputs["position_ids"] = torch.arange(
+                    inputs["input_ids"].size(1), device=inputs["input_ids"].device
+                ).expand(inputs["input_ids"].size(0), -1)
+            if "shift_labels" not in inputs:
+                logger.warning_once("Shift labels not found in the inputs, shifting manually")
+                if "labels" in inputs:
+                    _ignore_index = -100
+                    labels = nn.functional.pad(inputs["labels"], (0, 1), value=_ignore_index)
+                    inputs["shift_labels"] = labels[:, 1:].contiguous()
+
+            buffers = []
+            buffer_seq_dims = []
+
+            if "input_ids" in inputs:
+                buffers.append(inputs["input_ids"])
+                buffer_seq_dims.append(1)  # Sequence dimension
+            if "labels" in inputs:
+                buffers.append(inputs["labels"])
+                buffer_seq_dims.append(1)
+            if "shift_labels" in inputs:
+                buffers.append(inputs["shift_labels"])
+                buffer_seq_dims.append(1)
+            if "attention_mask" in inputs and not getattr(self, "_attn_mask_causal_checked", False):
+                # Context parallelism currently supports only causal masks.
+                # Accelerate applies hooks that replace the mask with the `is_causal` argument in SDPA.
+                # Check that the mask is really causal and raise an error if it is not.
+                # TODO: check this only once or always, with speed being the cost
+                attention_mask = inputs["attention_mask"]
+                if not self._is_attention_mask_causal(attention_mask):
+                    raise ValueError(
+                        "Context parallelism only supports causal attention masks. "
+                        "The provided attention_mask is not causal. "
+                        "Please ensure your data uses causal masking (lower triangular) "
+                        "or remove the attention_mask to use the model's default causal masking."
+                    )
+                self._attn_mask_causal_checked = True
+            # Include position_ids in the context parallelism splitting
+            if "position_ids" in inputs and inputs["position_ids"] is not None:
+                buffers.append(inputs["position_ids"])
+                buffer_seq_dims.append(1)
+
+            return partial(
+                self.accelerator.maybe_context_parallel,
+                buffers=buffers,
+                buffer_seq_dims=buffer_seq_dims,
+                no_restore_buffers=set(buffers),
+            ), inputs
+
+        return contextlib.nullcontext, inputs
+
    def compute_loss_context_manager(self):
        """
        A helper wrapper to group together context managers.
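For reference, here is a minimal standalone sketch of what the new `_is_attention_mask_causal` heuristic accepts and rejects. `_looks_causal` below is an illustrative re-statement of the same check for 3D/4D masks, not the Trainer method itself; 2D `(batch_size, seq_len)` padding masks are accepted unconditionally by the real helper.

```python
import torch

def _looks_causal(mask: torch.Tensor) -> bool:
    # Illustrative mirror of _is_attention_mask_causal for 3D/4D masks:
    # the strict upper triangle must be ~0 (boolean "don't attend") or very negative (additive -inf).
    sample = mask[0, 0] if mask.dim() == 4 else mask[0]
    upper = torch.triu(sample, diagonal=1)
    return bool(torch.all(upper <= 1e-6) or torch.all(upper < -1e4))

seq_len = 4
# Additive causal mask: 0 on/below the diagonal, a large negative value above it -> accepted
additive_causal = torch.triu(
    torch.full((1, 1, seq_len, seq_len), torch.finfo(torch.float32).min), diagonal=1
)
# Full-attention mask: 1 everywhere, i.e. tokens may attend to future positions -> rejected
bidirectional = torch.ones(1, seq_len, seq_len)

print(_looks_causal(additive_causal))  # True
print(_looks_causal(bidirectional))    # False -> _prepare_context_parallel_inputs raises ValueError
```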
@@ -3873,66 +3990,74 @@ def training_step(
        Return:
            `torch.Tensor`: The tensor with training loss on this batch.
        """
-        model.train()
-        if hasattr(self.optimizer, "train") and callable(self.optimizer.train):
-            self.optimizer.train()
+        # Prepare buffers for context parallelism

-        inputs = self._prepare_inputs(inputs)
-        if is_sagemaker_mp_enabled():
-            loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)
-            return loss_mb.reduce_mean().detach().to(self.args.device)
+        cp_context, inputs = self._prepare_context_parallel_inputs(model, inputs)

-        with self.compute_loss_context_manager():
-            loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
+        # Context manager is no-op if CP isn't enabled
+        with cp_context():
+            model.train()
+            if hasattr(self.optimizer, "train") and callable(self.optimizer.train):
+                self.optimizer.train()

-        del inputs
-        if (
-            self.args.torch_empty_cache_steps is not None
-            and self.state.global_step % self.args.torch_empty_cache_steps == 0
-        ):
-            if is_torch_xpu_available():
-                torch.xpu.empty_cache()
-            elif is_torch_mlu_available():
-                torch.mlu.empty_cache()
-            elif is_torch_musa_available():
-                torch.musa.empty_cache()
-            elif is_torch_npu_available():
-                torch.npu.empty_cache()
-            elif is_torch_mps_available():
-                torch.mps.empty_cache()
-            elif is_torch_hpu_available():
-                logger.warning(
-                    "`torch_empty_cache_steps` is set but HPU device/backend does not support empty_cache()."
-                )
-            else:
-                torch.cuda.empty_cache()
+            inputs = self._prepare_inputs(inputs)
+            if is_sagemaker_mp_enabled():
+                loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)
+                return loss_mb.reduce_mean().detach().to(self.args.device)

-        kwargs = {}
+            with self.compute_loss_context_manager():
+                loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)

-        # For LOMO optimizers you need to explicitly use the learning rate
-        if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
-            kwargs["learning_rate"] = self._get_learning_rate()
+            del inputs
+            if (
+                self.args.torch_empty_cache_steps is not None
+                and self.state.global_step % self.args.torch_empty_cache_steps == 0
+            ):
+                if is_torch_xpu_available():
+                    torch.xpu.empty_cache()
+                elif is_torch_mlu_available():
+                    torch.mlu.empty_cache()
+                elif is_torch_musa_available():
+                    torch.musa.empty_cache()
+                elif is_torch_npu_available():
+                    torch.npu.empty_cache()
+                elif is_torch_mps_available():
+                    torch.mps.empty_cache()
+                elif is_torch_hpu_available():
+                    logger.warning(
+                        "`torch_empty_cache_steps` is set but HPU device/backend does not support empty_cache()."
+                    )
+                else:
+                    torch.cuda.empty_cache()

-        if self.args.n_gpu > 1:
-            loss = loss.mean()  # mean() to average on multi-gpu parallel training
+            kwargs = {}

-        if self.use_apex:
-            from apex import amp
+            # For LOMO optimizers you need to explicitly use the learning rate
+            if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
+                kwargs["learning_rate"] = self._get_learning_rate()

-            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
-                scaled_loss.backward()
-        else:
-            # Finally we need to normalize the loss for reporting if GA loss bug is not fixed during compute loss
-            if (not self.model_accepts_loss_kwargs or num_items_in_batch is None) and self.compute_loss_func is None:
-                # If the model does not accept loss kwargs, we need to normalize the loss by the number of gradient accumulation steps
-                loss = loss / self.current_gradient_accumulation_steps
+            if self.args.n_gpu > 1:
+                loss = loss.mean()  # mean() to average on multi-gpu parallel training
+
+            if self.use_apex:
+                from apex import amp
+
+                with amp.scale_loss(loss, self.optimizer) as scaled_loss:
+                    scaled_loss.backward()
+            else:
+                # Finally we need to normalize the loss for reporting if GA loss bug is not fixed during compute loss
+                if (
+                    not self.model_accepts_loss_kwargs or num_items_in_batch is None
+                ) and self.compute_loss_func is None:
+                    # If the model does not accept loss kwargs, we need to normalize the loss by the number of gradient accumulation steps
+                    loss = loss / self.current_gradient_accumulation_steps

-        # Turning off loss scaling w.r.t. gradient accumulation when DeepSpeed is enabled
-        # https://github.com/huggingface/transformers/pull/35808
-        if self.accelerator.distributed_type == DistributedType.DEEPSPEED:
-            kwargs["scale_wrt_gas"] = False
+            # Turning off loss scaling w.r.t. gradient accumulation when DeepSpeed is enabled
+            # https://github.com/huggingface/transformers/pull/35808
+            if self.accelerator.distributed_type == DistributedType.DEEPSPEED:
+                kwargs["scale_wrt_gas"] = False

-        self.accelerator.backward(loss, **kwargs)
+            self.accelerator.backward(loss, **kwargs)

        return loss.detach()

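Inside `with cp_context():`, Accelerate shards every registered buffer along its sequence dimension so that each context-parallel rank processes only a slice of the sequence. Below is a rough, framework-free sketch of that splitting, where `cp_rank` and `cp_world_size` stand in for the real process-group bookkeeping; the actual sharding and attention exchange are handled by Accelerate/PyTorch context parallelism, which may split in a load-balanced rather than contiguous fashion.

```python
import torch

def shard_along_seq(buffers, buffer_seq_dims, cp_rank, cp_world_size):
    """Give each context-parallel rank one contiguous slice of every buffer (illustration only)."""
    sharded = []
    for buf, dim in zip(buffers, buffer_seq_dims):
        chunks = torch.chunk(buf, cp_world_size, dim=dim)
        sharded.append(chunks[cp_rank])
    return sharded

input_ids = torch.arange(16).reshape(1, 16)      # (batch_size, seq_len)
position_ids = torch.arange(16).reshape(1, 16)
local = shard_along_seq([input_ids, position_ids], [1, 1], cp_rank=0, cp_world_size=2)
print(local[0].shape)  # torch.Size([1, 8]): each of the 2 CP ranks sees half of the sequence
```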
@@ -5337,6 +5462,16 @@ def create_accelerator_and_postprocess(self):
        args = {
            "deepspeed_plugin": self.args.deepspeed_plugin,
        }
+
+        # We defer compatibility checks to accelerator
+        if self.args.parallelism_config is not None:
+            if not is_accelerate_available("1.10.1"):
+                raise ImportError(
+                    "ParallelismConfig requires accelerate v1.10.1 and above. Please upgrade accelerate to use this feature."
+                )
+
+            args["parallelism_config"] = self.args.parallelism_config
+
        if is_accelerate_available("0.28.0"):
            args["dataloader_config"] = dataloader_config
        else:
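End-to-end, the feature is driven from `TrainingArguments`: the diff reads `self.args.parallelism_config` and forwards it to the `Accelerator`. A hedged usage sketch follows, assuming accelerate >= 1.10.1 exports `ParallelismConfig` with a `cp_size` argument and that `TrainingArguments` accepts the new `parallelism_config` field; check the accelerate and transformers docs for the exact signatures.

```python
from accelerate import ParallelismConfig  # assumed import path, accelerate >= 1.10.1
from transformers import TrainingArguments, Trainer

# cp_size=2: pairs of ranks cooperate on the same (long) sequences.
pc = ParallelismConfig(cp_size=2)  # assumed argument name

args = TrainingArguments(
    output_dir="out",
    parallelism_config=pc,  # assumed field; read back as self.args.parallelism_config above
    per_device_train_batch_size=1,
)
# trainer = Trainer(model=model, args=args, train_dataset=dataset)
```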
@@ -5479,6 +5614,9 @@ def get_batch_samples(
            if self.args.n_gpu > 1 and num_items_in_batch.dim() == 0:
                # In the DataParallel case, convert the scalar tensor into a 1-dim tensor
                num_items_in_batch = num_items_in_batch.unsqueeze(0)
+            # Divide by the number of devices that process the same batch
+            if pc := self.accelerator.parallelism_config:
+                num_items_in_batch = num_items_in_batch // pc.non_data_parallel_size

        return batch_samples, num_items_in_batch

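The `non_data_parallel_size` division keeps the per-batch token count honest: processes that share the same batch (context- and tensor-parallel ranks) each report the same tokens, so the aggregated count is inflated by that factor. A small arithmetic sketch, assuming `non_data_parallel_size` is the product of the tensor- and context-parallel sizes:

```python
cp_size, tp_size = 2, 1                       # ranks that see the same batch
non_data_parallel_size = cp_size * tp_size    # assumed definition

tokens_in_global_batch = 4096
# Each CP rank reports the same 4096 tokens, so a naive sum over all ranks double-counts them:
aggregated = tokens_in_global_batch * non_data_parallel_size
num_items_in_batch = aggregated // non_data_parallel_size
assert num_items_in_batch == tokens_in_global_batch
```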