
Commit 67fd36f

S1ro1 authored and SunMarc committed
PATCH: add back n-dim device-mesh + fix tp trainer saving (#39693)
* Feat: something
* Feat: initial changes
* tmp changes to unblock
* Refactor
* remove todo
* Feat: docstring
* Fix: saving of distributed model in trainer
* Fix: distributed saving with trainer
* Feat: add pure tp saving
* Only require tp dim if ndim > 1
* Fix: default to None
* Fix: better comments/errors
* Fix: properly check tp_size attribute
* Fix: properly check for None in tp_size

---------

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
1 parent 709c6fd commit 67fd36f

2 files changed: +16 -4 lines changed


src/transformers/modeling_utils.py

Lines changed: 9 additions & 4 deletions
@@ -4472,7 +4472,7 @@ def from_pretrained(
                 A torch tensor parallel degree. If not provided would default to world size.
             device_mesh (`torch.distributed.DeviceMesh`, *optional*):
                 A torch device mesh. If not provided would default to world size. Used only for tensor parallel for now.
-                If provided, it has to contain dimension named `"tp"` which will be used for tensor parallelism
+                If provided, it has to contain dimension named `"tp"` in case it's > 1 dimensional, this dimension will be used for tensor parallelism
             offload_folder (`str` or `os.PathLike`, *optional*):
                 If the `device_map` contains any value `"disk"`, the folder where we will offload weights.
             offload_state_dict (`bool`, *optional*):
@@ -4617,10 +4617,15 @@ def from_pretrained(
         if device_mesh is None:
             tp_plan, device_map, device_mesh, tp_size = initialize_tensor_parallelism(tp_plan, tp_size=tp_size)
         else:
-            # TODO: make device_mesh support multiple dimensions
             if device_mesh.ndim > 1:
-                raise ValueError("device_mesh must be 1 dimensional and will be used for TP")
-            device_map = torch.device(device_mesh.device_type, int(os.environ["LOCAL_RANK"]))
+                if "tp" not in device_mesh.mesh_dim_names:
+                    raise ValueError(
+                        "When using `tp_plan` and n-d `device_mesh`, it must contain a 'tp' dimension. "
+                        "Please provide a valid `device_mesh`."
+                    )
+                device_mesh = device_mesh["tp"]
+            tp_size = device_mesh.size()
+            device_map = torch.device(f"{device_mesh.device_type}:{int(os.environ['LOCAL_RANK'])}")

         if tp_size is None:
             tp_size = torch.distributed.get_world_size()
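
For context, the behaviour this hunk restores: `from_pretrained` again accepts a multi-dimensional `DeviceMesh`, as long as one of its named dimensions is `"tp"`; that sub-mesh is sliced out for tensor parallelism and `tp_size` is derived from its size. Below is a minimal usage sketch under assumptions not in the diff (a `torchrun` launch with 8 processes, a placeholder checkpoint, and a 2x4 `("dp", "tp")` mesh):

# Minimal sketch. Assumptions: launched with `torchrun --nproc-per-node 8`,
# placeholder model id, 2 x 4 mesh; only the device_mesh/tp_plan usage
# reflects what this commit changes.
import torch
from torch.distributed.device_mesh import init_device_mesh
from transformers import AutoModelForCausalLM

# Named 2-D mesh: 2 data-parallel replicas x 4 tensor-parallel ranks.
mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "tp"))

# With this patch, the n-dim mesh is accepted: from_pretrained slices out
# mesh["tp"] for tensor parallelism and sets tp_size from that sub-mesh.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",  # placeholder checkpoint
    tp_plan="auto",
    device_mesh=mesh,
    torch_dtype=torch.bfloat16,
)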

src/transformers/trainer.py

Lines changed: 7 additions & 0 deletions
@@ -3953,6 +3953,13 @@ def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = Fa
             if IS_SAGEMAKER_MP_POST_1_10:
                 # 'user_content.pt' indicates model state_dict saved with smp >= 1.10
                 Path(os.path.join(output_dir, "user_content.pt")).touch()
+        # We are in N-D parallelism if we have parallelism_config set, so we check accelerate if we're on a to_save rank
+        elif getattr(self.accelerator, "parallelism_config", None) is not None:
+            if self.accelerator.should_save_model:
+                self._save(output_dir)
+        # If we drop to here, we're in 1D parallelism, so all ranks need to go to `save_pretrained`
+        elif (tp_size := getattr(self.model, "_tp_size", 0)) is not None and tp_size > 1:
+            self._save(output_dir)
         elif self.is_fsdp_enabled:
             if ("FULL_STATE_DICT" in str(self.accelerator.state.fsdp_plugin.state_dict_type)) and (
                 version.parse(accelerate_version) > version.parse("0.24.1")
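
The trainer-side fix distinguishes two saving regimes: N-D parallelism driven by accelerate's `parallelism_config`, where only the ranks accelerate marks via `should_save_model` write to disk, and pure tensor parallelism, where every rank must enter `save_pretrained` so the weight shards can be gathered. The following is a simplified, commented mirror of just the new branches, written as a stand-alone helper for clarity; it is an illustrative sketch, not the full `save_model` method, and the surrounding FSDP/DeepSpeed/SageMaker handling is elided:

# Illustrative sketch of the decision added by this commit; `trainer` is
# assumed to be a transformers.Trainer whose model was loaded with a
# tp_plan/device_mesh as in the example above.
def tp_aware_save(trainer, output_dir: str) -> bool:
    """Return True if this rank was handled by one of the new branches."""
    # N-D parallelism: accelerate's parallelism_config is set, so defer to
    # accelerate to pick the ranks that should write the model to disk.
    if getattr(trainer.accelerator, "parallelism_config", None) is not None:
        if trainer.accelerator.should_save_model:
            trainer._save(output_dir)
        return True
    # Pure (1-D) tensor parallelism: each rank owns only a shard, so every
    # rank must call _save / save_pretrained so shards can be gathered.
    tp_size = getattr(trainer.model, "_tp_size", 0)
    if tp_size is not None and tp_size > 1:
        trainer._save(output_dir)
        return True
    # Anything else falls through to the pre-existing FSDP/DeepSpeed/default paths.
    return False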
