Merged

33 commits
ca8e366
initial context_parallel_size support in trainer
kashif Aug 15, 2025
29b22cf
Merge branch 'main' into trainer-cp
kashif Aug 15, 2025
d70fef4
For context parallelism, use AVG instead of SUM to avoid over-account…
kashif Aug 15, 2025
dd58dd0
Merge branch 'trainer-cp' of https://github.com/huggingface/t…
kashif Aug 15, 2025
ecc2366
use parallelism_config.cp_enabled
kashif Aug 15, 2025
a629ff0
add parallelism_config to trainer state
kashif Aug 15, 2025
66d4273
warn when auto-enabling FSDP
kashif Aug 15, 2025
ffa4699
fix some reviews
kashif Aug 15, 2025
361f122
WIP: somewhat matching loss
S1ro1 Aug 18, 2025
3d426c1
Merge branch 'main' into trainer-cp
kashif Aug 18, 2025
3efe69b
Merge branch 'main' into trainer-cp
kashif Aug 19, 2025
eca52ac
Feat: add back nested_gather
S1ro1 Aug 19, 2025
951527b
Feat: cleanup
S1ro1 Aug 19, 2025
412c15e
Fix: raise on non-sdpa attn
S1ro1 Aug 19, 2025
be60c40
Merge branch 'main' into trainer-cp
S1ro1 Aug 19, 2025
2c357aa
remove context_parallel_size from TrainingArguments
kashif Aug 19, 2025
71e082f
if we have parallelism_config, we defer to get_state_dict from accele…
kashif Aug 20, 2025
37e6fdf
Merge branch 'main' into trainer-cp
kashif Aug 20, 2025
4f6fe15
Merge branch 'main' into trainer-cp
kashif Aug 21, 2025
485d7fa
fix form review
kashif Aug 22, 2025
3d16def
Feat: add parallelism config support
S1ro1 Aug 22, 2025
25a308e
Chore: revert some unwanted formatting changes
S1ro1 Aug 22, 2025
6d41365
Fix: check None
S1ro1 Aug 22, 2025
d82022c
Check none 2
S1ro1 Aug 22, 2025
ae9f878
Fix: remove duplicate import
S1ro1 Aug 22, 2025
531924e
Merge branch 'main' into trainer-cp
S1ro1 Aug 22, 2025
64d7336
Merge branch 'main' into trainer-cp
SunMarc Aug 22, 2025
52cb3bc
Update src/transformers/trainer.py
S1ro1 Aug 22, 2025
6e9fb30
Update src/transformers/training_args.py
S1ro1 Aug 22, 2025
bf187b2
Merge branch 'main' into trainer-cp
kashif Aug 25, 2025
33817f3
Fin
S1ro1 Aug 25, 2025
2294506
require accelerate 1.10.1 and higher
kashif Aug 25, 2025
f956348
Merge branch 'main' into trainer-cp
SunMarc Aug 26, 2025
238 changes: 188 additions & 50 deletions src/transformers/trainer.py
@@ -3823,6 +3823,123 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s

return inputs

def _is_attention_mask_causal(self, attention_mask):
"""
Check if an attention mask is causal (compatible with causal attention).
Context parallelism only supports causal attention patterns. This function
checks if the provided attention mask is compatible.

Args:
attention_mask (torch.Tensor): The attention mask to check

Returns:
bool: True if the mask is causal or compatible with causal attention
"""
if attention_mask is None:
return True # No mask is considered causal (model uses default causal masking)

# Handle different mask dimensions
if attention_mask.dim() == 2:
# (batch_size, seq_len) - standard padding mask, compatible with causal attention
return True
elif attention_mask.dim() in [3, 4]:
# (batch_size, seq_len, seq_len) or (batch_size, num_heads, seq_len, seq_len)
# Check if it's lower triangular (causal)
seq_len = attention_mask.shape[-1]
if seq_len <= 1:
return True # Single token or empty is always causal

# Take first batch and head (if 4D) for checking pattern
if attention_mask.dim() == 4:
mask = attention_mask[0, 0] # First batch, first head
else:
mask = attention_mask[0] # First batch

# Check if upper triangular part is masked (should be 0 or very negative for causal)
upper_triangular = torch.triu(mask, diagonal=1)

# For causal masks, upper triangular should be 0 or very negative (like -inf)
# Use a reasonable threshold to handle float precision issues
is_causal = torch.all(upper_triangular <= 1e-6) or torch.all(upper_triangular < -1e4)
return is_causal.item() if isinstance(is_causal, torch.Tensor) else is_causal

# For unknown dimensions, be conservative and reject
return False
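
As an aside for reviewers (not part of the diff): a minimal sketch of the masks the helper above accepts, assuming the common SDPA-style additive convention where masked positions carry a large negative value:

import torch

# 2D padding mask (batch_size, seq_len): always compatible with causal attention.
padding_mask = torch.tensor([[1, 1, 1, 0]])

# 4D additive causal mask (batch_size, num_heads, seq_len, seq_len):
# the strict upper triangle is -inf, so a position can only attend backwards.
seq_len = 4
causal_mask = torch.zeros(1, 1, seq_len, seq_len)
upper = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
causal_mask = causal_mask.masked_fill(upper, float("-inf"))

# Same test as the helper: the strict upper triangle must be ~0 or very negative.
upper_vals = torch.triu(causal_mask[0, 0], diagonal=1)
print(bool(torch.all(upper_vals <= 1e-6)))  # True -> accepted as causal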

def _prepare_context_parallel_inputs(self, model, inputs: dict[str, Union[torch.Tensor, Any]]):
"""
Prepare inputs for context parallelism by setting up buffers and validation.

Args:
model: The model being trained
inputs: Input tensors to prepare

Returns:
tuple: (context_manager, prepared_inputs) where context_manager is either
the context parallelism wrapper or a no-op context
"""
if (
getattr(self.accelerator, "parallelism_config", None) is not None
and self.accelerator.parallelism_config.cp_enabled
):
if hasattr(model, "config"):
if model.config._attn_implementation != "sdpa":
raise ValueError(
f"Context parallelism is supported only with SDPA attention, you are using {model.config._attn_implementation}."
)

if "position_ids" not in inputs:
logger.warning_once("Position IDs not found in the inputs, generating manually")
inputs["position_ids"] = torch.arange(
inputs["input_ids"].size(1), device=inputs["input_ids"].device
).expand(inputs["input_ids"].size(0), -1)
if "shift_labels" not in inputs:
logger.warning_once("Shift labels not found in the inputs, shifting manually")
if "labels" in inputs:
_ignore_index = -100
labels = nn.functional.pad(inputs["labels"], (0, 1), value=_ignore_index)
inputs["shift_labels"] = labels[:, 1:].contiguous()

buffers = []
buffer_seq_dims = []

if "input_ids" in inputs:
buffers.append(inputs["input_ids"])
buffer_seq_dims.append(1) # Sequence dimension
if "labels" in inputs:
buffers.append(inputs["labels"])
buffer_seq_dims.append(1)
if "shift_labels" in inputs:
buffers.append(inputs["shift_labels"])
buffer_seq_dims.append(1)
if "attention_mask" in inputs and not getattr(self, "_attn_mask_causal_checked", False):
# Context parallelism currently supports only causal attention masks.
# Accelerate applies hooks that replace the mask with SDPA's `is_causal` argument,
# so verify that the provided mask really is causal and raise if it is not.
# TODO: decide whether to run this check once or on every step (speed vs. safety trade-off).
attention_mask = inputs["attention_mask"]
if not self._is_attention_mask_causal(attention_mask):
raise ValueError(
"Context parallelism only supports causal attention masks. "
"The provided attention_mask is not causal. "
"Please ensure your data uses causal masking (lower triangular) "
"or remove the attention_mask to use the model's default causal masking."
)
self._attn_mask_causal_checked = True
# Include position_ids in context parallelism splitting
if "position_ids" in inputs and inputs["position_ids"] is not None:
buffers.append(inputs["position_ids"])
buffer_seq_dims.append(1)

return partial(
self.accelerator.maybe_context_parallel,
buffers=buffers,
buffer_seq_dims=buffer_seq_dims,
no_restore_buffers=set(buffers),
), inputs

return contextlib.nullcontext, inputs

def compute_loss_context_manager(self):
"""
A helper wrapper to group together context managers.
@@ -3873,66 +3990,74 @@ def training_step(
Return:
`torch.Tensor`: The tensor with training loss on this batch.
"""
model.train()
if hasattr(self.optimizer, "train") and callable(self.optimizer.train):
self.optimizer.train()
# Prepare buffers for context parallelism

inputs = self._prepare_inputs(inputs)
if is_sagemaker_mp_enabled():
loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)
return loss_mb.reduce_mean().detach().to(self.args.device)
cp_context, inputs = self._prepare_context_parallel_inputs(model, inputs)

with self.compute_loss_context_manager():
loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
# Context manager is a no-op if CP isn't enabled
with cp_context():
model.train()
if hasattr(self.optimizer, "train") and callable(self.optimizer.train):
self.optimizer.train()

del inputs
if (
self.args.torch_empty_cache_steps is not None
and self.state.global_step % self.args.torch_empty_cache_steps == 0
):
if is_torch_xpu_available():
torch.xpu.empty_cache()
elif is_torch_mlu_available():
torch.mlu.empty_cache()
elif is_torch_musa_available():
torch.musa.empty_cache()
elif is_torch_npu_available():
torch.npu.empty_cache()
elif is_torch_mps_available():
torch.mps.empty_cache()
elif is_torch_hpu_available():
logger.warning(
"`torch_empty_cache_steps` is set but HPU device/backend does not support empty_cache()."
)
else:
torch.cuda.empty_cache()
inputs = self._prepare_inputs(inputs)
if is_sagemaker_mp_enabled():
loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)
return loss_mb.reduce_mean().detach().to(self.args.device)

kwargs = {}
with self.compute_loss_context_manager():
loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)

# For LOMO optimizers you need to explicitly use the learning rate
if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
kwargs["learning_rate"] = self._get_learning_rate()
del inputs
if (
self.args.torch_empty_cache_steps is not None
and self.state.global_step % self.args.torch_empty_cache_steps == 0
):
if is_torch_xpu_available():
torch.xpu.empty_cache()
elif is_torch_mlu_available():
torch.mlu.empty_cache()
elif is_torch_musa_available():
torch.musa.empty_cache()
elif is_torch_npu_available():
torch.npu.empty_cache()
elif is_torch_mps_available():
torch.mps.empty_cache()
elif is_torch_hpu_available():
logger.warning(
"`torch_empty_cache_steps` is set but HPU device/backend does not support empty_cache()."
)
else:
torch.cuda.empty_cache()

if self.args.n_gpu > 1:
loss = loss.mean() # mean() to average on multi-gpu parallel training
kwargs = {}

if self.use_apex:
from apex import amp
# For LOMO optimizers you need to explicitly use the learning rate
if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
kwargs["learning_rate"] = self._get_learning_rate()

with amp.scale_loss(loss, self.optimizer) as scaled_loss:
scaled_loss.backward()
else:
# Finally we need to normalize the loss for reporting if GA loss bug is not fixed during compute loss
if (not self.model_accepts_loss_kwargs or num_items_in_batch is None) and self.compute_loss_func is None:
# If the model does not accept loss kwargs, we need to normalize the loss by the number of gradient accumulation steps
loss = loss / self.current_gradient_accumulation_steps
if self.args.n_gpu > 1:
loss = loss.mean() # mean() to average on multi-gpu parallel training

if self.use_apex:
from apex import amp

with amp.scale_loss(loss, self.optimizer) as scaled_loss:
scaled_loss.backward()
else:
# Finally we need to normalize the loss for reporting if GA loss bug is not fixed during compute loss
if (
not self.model_accepts_loss_kwargs or num_items_in_batch is None
) and self.compute_loss_func is None:
# If the model does not accept loss kwargs, we need to normalize the loss by the number of gradient accumulation steps
loss = loss / self.current_gradient_accumulation_steps

# Turning off loss scaling w.r.t. gradient accumulation when DeepSpeed is enabled
# https://github.com/huggingface/transformers/pull/35808
if self.accelerator.distributed_type == DistributedType.DEEPSPEED:
kwargs["scale_wrt_gas"] = False
# Turning off loss scaling w.r.t. gradient accumulation when DeepSpeed is enabled
# https://github.com/huggingface/transformers/pull/35808
if self.accelerator.distributed_type == DistributedType.DEEPSPEED:
kwargs["scale_wrt_gas"] = False

self.accelerator.backward(loss, **kwargs)
self.accelerator.backward(loss, **kwargs)

return loss.detach()

@@ -5337,6 +5462,16 @@ def create_accelerator_and_postprocess(self):
args = {
"deepspeed_plugin": self.args.deepspeed_plugin,
}

# We defer compatibility checks to accelerator
if self.args.parallelism_config is not None:
if not is_accelerate_available("1.10.1"):
raise ImportError(
"ParallelismConfig requires accelerate v1.10.1 and above. Please upgrade accelerate to use this feature."
)

args["parallelism_config"] = self.args.parallelism_config

if is_accelerate_available("0.28.0"):
args["dataloader_config"] = dataloader_config
else:
@@ -5479,6 +5614,9 @@ def get_batch_samples(
if self.args.n_gpu > 1 and num_items_in_batch.dim() == 0:
# In the DataParallel case, convert the scalar tensor into a 1-dim tensor
num_items_in_batch = num_items_in_batch.unsqueeze(0)
# Divide by the number of devices that receive the same batch (non-data-parallel ranks)
if pc := self.accelerator.parallelism_config:
num_items_in_batch = num_items_in_batch // pc.non_data_parallel_size

return batch_samples, num_items_in_batch

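A note on the `get_batch_samples` change above: with context or tensor parallelism, every rank in the same data-parallel group receives the same batch, so a token count summed over all ranks over-counts by that replication factor. A rough sketch of the arithmetic, assuming `non_data_parallel_size` is the product of the non-data-parallel mesh dimensions (e.g. `cp_size * tp_size`):

# Illustrative numbers only: 8 ranks arranged as dp=2, cp=2, tp=2.
dp_size, cp_size, tp_size = 2, 2, 2
non_data_parallel_size = cp_size * tp_size  # ranks that see the same batch
tokens_per_batch = 1024                     # labels != -100 in one distinct batch

# Summing per-rank counts over all 8 ranks counts each batch 4 times too often.
summed_over_ranks = tokens_per_batch * dp_size * non_data_parallel_size   # 8192
num_items_in_batch = summed_over_ranks // non_data_parallel_size          # 2048, one count per distinct batch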
13 changes: 12 additions & 1 deletion src/transformers/training_args.py
@@ -77,6 +77,9 @@

from .trainer_pt_utils import AcceleratorConfig

if is_accelerate_available("1.10.1"):
from accelerate.parallelism_config import ParallelismConfig

if is_torch_xla_available():
import torch_xla.core.xla_model as xm

@@ -597,7 +600,8 @@ class TrainingArguments:
Whether or not to use a pre-configured `AcceleratorState` or `PartialState` defined before calling `TrainingArguments`.
If `True`, an `Accelerator` or `PartialState` must be initialized. Note that by doing so, this could lead to issues
with hyperparameter tuning.

parallelism_config (`ParallelismConfig`, *optional*):
Parallelism configuration for the training run. Requires Accelerate `1.10.1` or higher.
label_smoothing_factor (`float`, *optional*, defaults to 0.0):
The label smoothing factor to use. Zero means no label smoothing, otherwise the underlying onehot-encoded
labels are changed from 0s and 1s to `label_smoothing_factor/num_labels` and `1 - label_smoothing_factor +
@@ -1272,6 +1276,10 @@ class TrainingArguments:
)
},
)
parallelism_config: Optional["ParallelismConfig"] = field(
default=None,
metadata={"help": ("Parallelism configuration for the training run. Requires Accelerate `1.10.1`")},
)
deepspeed: Optional[Union[dict, str]] = field(
default=None,
metadata={
@@ -2561,6 +2569,9 @@ def to_dict(self):
quantization_config = v.get("quantization_config")
if quantization_config and not isinstance(quantization_config, dict):
d[k]["quantization_config"] = quantization_config.to_dict()
if k == "parallelism_config" and v is not None:
d[k] = v.to_json()

self._dict_dtype_to_str(d)

return d
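
For reference, a minimal usage sketch of the new `parallelism_config` argument, assuming Accelerate >= 1.10.1 is installed; the `ParallelismConfig` keyword names (`dp_shard_size`, `cp_size`) are assumptions based on the Accelerate 1.10 API and should be verified against the installed version:

from accelerate.parallelism_config import ParallelismConfig
from transformers import TrainingArguments

# 8 GPUs as 4-way FSDP sharding x 2-way context parallelism
# (keyword names are assumptions; check the Accelerate docs for your version).
pc = ParallelismConfig(dp_shard_size=4, cp_size=2)

args = TrainingArguments(
    output_dir="out",
    parallelism_config=pc,  # new TrainingArguments field from this PR; needs accelerate >= 1.10.1
)

# `args` is then passed to Trainer(...) as usual; the Trainer forwards
# parallelism_config to the Accelerator it builds in create_accelerator_and_postprocess.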