Commit 6d2bb1e

Authored by kashif, S1ro1, and SunMarc
[Trainer] accelerate context parallel support in trainer (huggingface#40205)

* initial context_parallel_size support in trainer
* For context parallelism, use AVG instead of SUM to avoid over-accounting tokens
* use parallelism_config.cp_enabled
* add parallelism_config to trainer state
* warn when auto-enabling FSDP
* fix some reviews
* WIP: somewhat matching loss
* Feat: add back nested_gather
* Feat: cleanup
* Fix: raise on non-sdpa attn
* remove context_parallel_size from TrainingArguments
* if we have parallelism_config, we defer to get_state_dict from accelerate
* fix from review
* Feat: add parallelism config support
* Chore: revert some unwanted formatting changes
* Fix: check None
* Check none 2
* Fix: remove duplicate import
* Update src/transformers/trainer.py (Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>)
* Update src/transformers/training_args.py (Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>)
* Fin
* require accelerate 1.10.1 and higher

---------

Co-authored-by: S1ro1 <matej.sirovatka@gmail.com>
Co-authored-by: Matej Sirovatka <54212263+S1ro1@users.noreply.github.com>
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
1 parent 63caaea commit 6d2bb1e

File tree

2 files changed: +200, -51 lines

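Before the diffs, a brief usage sketch of the feature this commit adds. The `parallelism_config` field on `TrainingArguments` and its hand-off to the `Accelerator` are part of this diff; the `ParallelismConfig(cp_size=...)` constructor argument comes from Accelerate >= 1.10.1 and is an assumption here rather than something this diff defines.

# Usage sketch only; ParallelismConfig and its cp_size argument are assumed from
# Accelerate >= 1.10.1. A real run would be launched via `accelerate launch`/torchrun
# with enough ranks for the requested context-parallel degree.
from accelerate.parallelism_config import ParallelismConfig
from transformers import TrainingArguments

pc = ParallelismConfig(cp_size=2)  # shard the sequence dimension across 2 ranks (assumed arg name)

args = TrainingArguments(
    output_dir="out",
    parallelism_config=pc,  # new field; Trainer forwards it to Accelerator(parallelism_config=...)
)

# Trainer(model=..., args=args, train_dataset=...).train() would then run each
# training_step inside accelerator.maybe_context_parallel (see the trainer.py diff below),
# which requires the model to use SDPA attention.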

src/transformers/trainer.py

Lines changed: 188 additions & 50 deletions
@@ -3823,6 +3823,123 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s

        return inputs

+    def _is_attention_mask_causal(self, attention_mask):
+        """
+        Check if an attention mask is causal (compatible with causal attention).
+        Context parallelism only supports causal attention patterns. This function
+        checks if the provided attention mask is compatible.
+
+        Args:
+            attention_mask (torch.Tensor): The attention mask to check
+
+        Returns:
+            bool: True if the mask is causal or compatible with causal attention
+        """
+        if attention_mask is None:
+            return True  # No mask is considered causal (model uses default causal masking)
+
+        # Handle different mask dimensions
+        if attention_mask.dim() == 2:
+            # (batch_size, seq_len) - standard padding mask, compatible with causal attention
+            return True
+        elif attention_mask.dim() in [3, 4]:
+            # (batch_size, seq_len, seq_len) or (batch_size, num_heads, seq_len, seq_len)
+            # Check if it's lower triangular (causal)
+            seq_len = attention_mask.shape[-1]
+            if seq_len <= 1:
+                return True  # Single token or empty is always causal
+
+            # Take first batch and head (if 4D) for checking pattern
+            if attention_mask.dim() == 4:
+                mask = attention_mask[0, 0]  # First batch, first head
+            else:
+                mask = attention_mask[0]  # First batch
+
+            # Check if upper triangular part is masked (should be 0 or very negative for causal)
+            upper_triangular = torch.triu(mask, diagonal=1)
+
+            # For causal masks, upper triangular should be 0 or very negative (like -inf)
+            # Use a reasonable threshold to handle float precision issues
+            is_causal = torch.all(upper_triangular <= 1e-6) or torch.all(upper_triangular < -1e4)
+            return is_causal.item() if isinstance(is_causal, torch.Tensor) else is_causal
+
+        # For unknown dimensions, be conservative and reject
+        return False
+
+    def _prepare_context_parallel_inputs(self, model, inputs: dict[str, Union[torch.Tensor, Any]]):
+        """
+        Prepare inputs for context parallelism by setting up buffers and validation.
+
+        Args:
+            model: The model being trained
+            inputs: Input tensors to prepare
+
+        Returns:
+            tuple: (context_manager, prepared_inputs) where context_manager is either
+                the context parallelism wrapper or a no-op context
+        """
+        if (
+            getattr(self.accelerator, "parallelism_config", None) is not None
+            and self.accelerator.parallelism_config.cp_enabled
+        ):
+            if hasattr(model, "config"):
+                if model.config._attn_implementation != "sdpa":
+                    raise ValueError(
+                        f"Context parallelism is supported only with SDPA attention, you are using {model.config._attn_implementation}."
+                    )
+
+            if "position_ids" not in inputs:
+                logger.warning_once("Position IDs not found in the inputs, generating manually")
+                inputs["position_ids"] = torch.arange(
+                    inputs["input_ids"].size(1), device=inputs["input_ids"].device
+                ).expand(inputs["input_ids"].size(0), -1)
+            if "shift_labels" not in inputs:
+                logger.warning_once("Shift labels not found in the inputs, shifting manually")
+                if "labels" in inputs:
+                    _ignore_index = -100
+                    labels = nn.functional.pad(inputs["labels"], (0, 1), value=_ignore_index)
+                    inputs["shift_labels"] = labels[:, 1:].contiguous()

+            buffers = []
+            buffer_seq_dims = []
+
+            if "input_ids" in inputs:
+                buffers.append(inputs["input_ids"])
+                buffer_seq_dims.append(1)  # Sequence dimension
+            if "labels" in inputs:
+                buffers.append(inputs["labels"])
+                buffer_seq_dims.append(1)
+            if "shift_labels" in inputs:
+                buffers.append(inputs["shift_labels"])
+                buffer_seq_dims.append(1)
+            if "attention_mask" in inputs and not getattr(self, "_attn_mask_causal_checked", False):
+                # Context parallel currently doesn't support other masks than causal
+                # Accelerate applies hooks to replace mask with is_causal arg in SDPA
+                # Check if the mask is really causal and if not throw an error
+                # TODO: check this only once or always, with speed being the cost
+                attention_mask = inputs["attention_mask"]
+                if not self._is_attention_mask_causal(attention_mask):
+                    raise ValueError(
+                        "Context parallelism only supports causal attention masks. "
+                        "The provided attention_mask is not causal. "
+                        "Please ensure your data uses causal masking (lower triangular) "
+                        "or remove the attention_mask to use the model's default causal masking."
+                    )
+                self._attn_mask_causal_checked = True
+            # Include position_ids in context parallelism splitting
+            if "position_ids" in inputs and inputs["position_ids"] is not None:
+                buffers.append(inputs["position_ids"])
+                buffer_seq_dims.append(1)
+
+            return partial(
+                self.accelerator.maybe_context_parallel,
+                buffers=buffers,
+                buffer_seq_dims=buffer_seq_dims,
+                no_restore_buffers=set(buffers),
+            ), inputs
+
+        return contextlib.nullcontext, inputs
+
    def compute_loss_context_manager(self):
        """
        A helper wrapper to group together context managers.
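For intuition, the sketch below mirrors the 2D/4D logic of `_is_attention_mask_causal` on toy tensors; it is a standalone illustration, not the Trainer method itself.

# Toy illustration of the causal-mask rule used above (standalone, not Trainer code).
import torch

def is_causal_like(mask: torch.Tensor) -> bool:
    if mask.dim() == 2:            # (batch, seq_len) padding mask: compatible with causal attention
        return True
    sample = mask[0, 0] if mask.dim() == 4 else mask[0]
    upper = torch.triu(sample, diagonal=1)
    # Causal if everything above the diagonal is ~0 or very negative (-inf style masks)
    return bool(torch.all(upper <= 1e-6) or torch.all(upper < -1e4))

seq = 4
causal_4d = torch.tril(torch.ones(1, 1, seq, seq))          # lower triangular: accepted
full_4d = torch.ones(1, 1, seq, seq)                         # bidirectional: rejected
print(is_causal_like(causal_4d), is_causal_like(full_4d))    # True False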
@@ -3873,66 +3990,74 @@ def training_step(
        Return:
            `torch.Tensor`: The tensor with training loss on this batch.
        """
-        model.train()
-        if hasattr(self.optimizer, "train") and callable(self.optimizer.train):
-            self.optimizer.train()
+        # Prepare buffers for context parallelism

-        inputs = self._prepare_inputs(inputs)
-        if is_sagemaker_mp_enabled():
-            loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)
-            return loss_mb.reduce_mean().detach().to(self.args.device)
+        cp_context, inputs = self._prepare_context_parallel_inputs(model, inputs)

-        with self.compute_loss_context_manager():
-            loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
+        # Context manager is no-op if CP isn't enabled
+        with cp_context():
+            model.train()
+            if hasattr(self.optimizer, "train") and callable(self.optimizer.train):
+                self.optimizer.train()

-        del inputs
-        if (
-            self.args.torch_empty_cache_steps is not None
-            and self.state.global_step % self.args.torch_empty_cache_steps == 0
-        ):
-            if is_torch_xpu_available():
-                torch.xpu.empty_cache()
-            elif is_torch_mlu_available():
-                torch.mlu.empty_cache()
-            elif is_torch_musa_available():
-                torch.musa.empty_cache()
-            elif is_torch_npu_available():
-                torch.npu.empty_cache()
-            elif is_torch_mps_available():
-                torch.mps.empty_cache()
-            elif is_torch_hpu_available():
-                logger.warning(
-                    "`torch_empty_cache_steps` is set but HPU device/backend does not support empty_cache()."
-                )
-            else:
-                torch.cuda.empty_cache()
+            inputs = self._prepare_inputs(inputs)
+            if is_sagemaker_mp_enabled():
+                loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)
+                return loss_mb.reduce_mean().detach().to(self.args.device)

-        kwargs = {}
+            with self.compute_loss_context_manager():
+                loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)

-        # For LOMO optimizers you need to explicitly use the learning rate
-        if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
-            kwargs["learning_rate"] = self._get_learning_rate()
+            del inputs
+            if (
+                self.args.torch_empty_cache_steps is not None
+                and self.state.global_step % self.args.torch_empty_cache_steps == 0
+            ):
+                if is_torch_xpu_available():
+                    torch.xpu.empty_cache()
+                elif is_torch_mlu_available():
+                    torch.mlu.empty_cache()
+                elif is_torch_musa_available():
+                    torch.musa.empty_cache()
+                elif is_torch_npu_available():
+                    torch.npu.empty_cache()
+                elif is_torch_mps_available():
+                    torch.mps.empty_cache()
+                elif is_torch_hpu_available():
+                    logger.warning(
+                        "`torch_empty_cache_steps` is set but HPU device/backend does not support empty_cache()."
+                    )
+                else:
+                    torch.cuda.empty_cache()

-        if self.args.n_gpu > 1:
-            loss = loss.mean()  # mean() to average on multi-gpu parallel training
+            kwargs = {}

-        if self.use_apex:
-            from apex import amp
+            # For LOMO optimizers you need to explicitly use the learning rate
+            if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
+                kwargs["learning_rate"] = self._get_learning_rate()

-            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
-                scaled_loss.backward()
-        else:
-            # Finally we need to normalize the loss for reporting if GA loss bug is not fixed during compute loss
-            if (not self.model_accepts_loss_kwargs or num_items_in_batch is None) and self.compute_loss_func is None:
-                # If the model does not accept loss kwargs, we need to normalize the loss by the number of gradient accumulation steps
-                loss = loss / self.current_gradient_accumulation_steps
+            if self.args.n_gpu > 1:
+                loss = loss.mean()  # mean() to average on multi-gpu parallel training
+
+            if self.use_apex:
+                from apex import amp
+
+                with amp.scale_loss(loss, self.optimizer) as scaled_loss:
+                    scaled_loss.backward()
+            else:
+                # Finally we need to normalize the loss for reporting if GA loss bug is not fixed during compute loss
+                if (
+                    not self.model_accepts_loss_kwargs or num_items_in_batch is None
+                ) and self.compute_loss_func is None:
+                    # If the model does not accept loss kwargs, we need to normalize the loss by the number of gradient accumulation steps
+                    loss = loss / self.current_gradient_accumulation_steps

-            # Turning off loss scaling w.r.t. gradient accumulation when DeepSpeed is enabled
-            # https://github.com/huggingface/transformers/pull/35808
-            if self.accelerator.distributed_type == DistributedType.DEEPSPEED:
-                kwargs["scale_wrt_gas"] = False
+                # Turning off loss scaling w.r.t. gradient accumulation when DeepSpeed is enabled
+                # https://github.com/huggingface/transformers/pull/35808
+                if self.accelerator.distributed_type == DistributedType.DEEPSPEED:
+                    kwargs["scale_wrt_gas"] = False

-        self.accelerator.backward(loss, **kwargs)
+            self.accelerator.backward(loss, **kwargs)

        return loss.detach()
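A note on the shape of this refactor: `_prepare_context_parallel_inputs` returns something callable, either `functools.partial(self.accelerator.maybe_context_parallel, ...)` or the `contextlib.nullcontext` class, so `training_step` can always write `with cp_context():` whether or not context parallelism is active. A small self-contained sketch of that pattern (toy names, not Trainer code):

# Toy illustration of the "callable that yields a context manager" pattern used above.
import contextlib
from functools import partial

@contextlib.contextmanager
def fake_context_parallel(buffers=None, buffer_seq_dims=None):
    # Stand-in for accelerator.maybe_context_parallel: shard on enter, restore on exit.
    print(f"sharding {len(buffers)} buffers along dims {buffer_seq_dims}")
    yield
    print("restoring buffers")

def prepare(cp_enabled, buffers):
    if cp_enabled:
        return partial(fake_context_parallel, buffers=buffers, buffer_seq_dims=[1] * len(buffers))
    return contextlib.nullcontext  # the class itself is callable -> nullcontext()

cp_context = prepare(cp_enabled=True, buffers=["input_ids", "labels"])
with cp_context():  # same call shape whether CP is enabled or not
    pass            # forward/backward would run here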

@@ -5337,6 +5462,16 @@ def create_accelerator_and_postprocess(self):
        args = {
            "deepspeed_plugin": self.args.deepspeed_plugin,
        }
+
+        # We defer compatibility checks to accelerator
+        if self.args.parallelism_config is not None:
+            if not is_accelerate_available("1.10.1"):
+                raise ImportError(
+                    "ParallelismConfig requires accelerate v1.10.1 and above. Please upgrade accelerate to use this feature."
+                )
+
+            args["parallelism_config"] = self.args.parallelism_config
+
        if is_accelerate_available("0.28.0"):
            args["dataloader_config"] = dataloader_config
        else:
@@ -5479,6 +5614,9 @@ def get_batch_samples(
            if self.args.n_gpu > 1 and num_items_in_batch.dim() == 0:
                # In the DataParallel case, convert the scalar tensor into a 1-dim tensor
                num_items_in_batch = num_items_in_batch.unsqueeze(0)
+            # Divide by number of devices with the same batch
+            if pc := self.accelerator.parallelism_config:
+                num_items_in_batch = num_items_in_batch // pc.non_data_parallel_size

        return batch_samples, num_items_in_batch
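Why the floor division by `pc.non_data_parallel_size`: with context (or tensor) parallelism, every rank in the same CP/TP group loads the identical batch, so summing token counts over all ranks over-counts by that factor (the commit message's "use AVG instead of SUM to avoid over-accounting tokens"). A toy arithmetic sketch, assuming `non_data_parallel_size` is the product of the CP and TP degrees (an assumption about the Accelerate attribute, not something this diff states):

# Toy numbers; illustrates the correction in get_batch_samples above.
world_size = 8
cp_size, tp_size = 2, 1
non_data_parallel_size = cp_size * tp_size        # assumed: number of ranks sharing one batch
dp_groups = world_size // non_data_parallel_size  # 4 distinct batches per optimizer step

tokens_per_rank = 1024                            # each rank reports its (full, unsharded) batch
summed = tokens_per_rank * world_size             # 8192: counts every CP-replicated batch twice
corrected = summed // non_data_parallel_size      # 4096 == tokens_per_rank * dp_groups
assert corrected == tokens_per_rank * dp_groups
print(summed, corrected)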

src/transformers/training_args.py

Lines changed: 12 additions & 1 deletion
@@ -77,6 +77,9 @@

    from .trainer_pt_utils import AcceleratorConfig

+if is_accelerate_available("1.10.1"):
+    from accelerate.parallelism_config import ParallelismConfig
+
if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

@@ -597,7 +600,8 @@ class TrainingArguments:
            Whether or not to use a pre-configured `AcceleratorState` or `PartialState` defined before calling `TrainingArguments`.
            If `True`, an `Accelerator` or `PartialState` must be initialized. Note that by doing so, this could lead to issues
            with hyperparameter tuning.
-
+        parallelism_config (`ParallelismConfig`, *optional*):
+            Parallelism configuration for the training run. Requires Accelerate `1.10.1`
        label_smoothing_factor (`float`, *optional*, defaults to 0.0):
            The label smoothing factor to use. Zero means no label smoothing, otherwise the underlying onehot-encoded
            labels are changed from 0s and 1s to `label_smoothing_factor/num_labels` and `1 - label_smoothing_factor +
@@ -1272,6 +1276,10 @@ class TrainingArguments:
            )
        },
    )
+    parallelism_config: Optional["ParallelismConfig"] = field(
+        default=None,
+        metadata={"help": ("Parallelism configuration for the training run. Requires Accelerate `1.10.1`")},
+    )
    deepspeed: Optional[Union[dict, str]] = field(
        default=None,
        metadata={
@@ -2561,6 +2569,9 @@ def to_dict(self):
                quantization_config = v.get("quantization_config")
                if quantization_config and not isinstance(quantization_config, dict):
                    d[k]["quantization_config"] = quantization_config.to_dict()
+            if k == "parallelism_config" and v is not None:
+                d[k] = v.to_json()
+
        self._dict_dtype_to_str(d)

        return d
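The new `to_dict` branch keeps the serialized arguments JSON-friendly (for logging and saving) by replacing the live config object with its `to_json()` form. A hedged sketch of that intent with a stand-in config class; only the `to_json()` call itself is taken from this diff, and the payload shown is illustrative:

# Hedged sketch of the intent behind the new to_dict() branch.
import json

class FakeParallelismConfig:
    def to_json(self):
        return {"cp_size": 2}  # illustrative payload only; real keys are an Accelerate detail

d = {"output_dir": "out", "parallelism_config": FakeParallelismConfig()}
for k, v in d.items():
    if k == "parallelism_config" and v is not None:
        d[k] = v.to_json()

print(json.dumps(d))  # serializes cleanly; the raw config object would not be JSON-serializable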
