feat: support tensor parallel using Pytorch 2.0 & Data loader #3173
@@ -713,6 +713,7 @@ def __init__(
         _drop_last: bool = False,
         _non_blocking: bool = False,
         slice_fn=None,
+        torch_device_mesh=None,
         **kwargs,
     ):
         shuffle = False
@@ -732,6 +733,7 @@ def __init__(
         self._drop_last = _drop_last
         self._non_blocking = _non_blocking
         self.skip_batches = skip_batches
+        self.torch_device_mesh = torch_device_mesh

         self.slice_fn = slice_tensors if slice_fn is None else slice_fn
         self.iteration = 0
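For context, the new `torch_device_mesh` argument is expected to be a `torch.distributed.DeviceMesh` with named dimensions; later in this diff the dispatcher inspects `mesh_dim_names` and slices out the `"tp"` submesh. A minimal sketch of building such a mesh is below; the `"dp"`/`"tp"` names mirror this PR, while the sizes and the 8-rank launch are illustrative assumptions.

```python
# Sketch only: builds the kind of named DeviceMesh the dispatcher expects.
# Assumes torch >= 2.2 (named mesh dims, name-based slicing), CUDA devices,
# and a torchrun launch with 8 ranks (2 * 4); sizes are illustrative.
from torch.distributed.device_mesh import init_device_mesh

mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "tp"))

print(mesh.mesh_dim_names)  # ("dp", "tp")
submesh_tp = mesh["tp"]     # 1-D submesh holding this rank's TP group
submesh_dp = mesh["dp"]     # 1-D submesh holding this rank's DP group
```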
@@ -740,7 +742,29 @@ def _fetch_batches(self, iterator):
         batches, batch = None, None
         # On process 0, we gather the batch to dispatch.
         if self.state.process_index == 0:
+            # if a device mesh is provided, extract each dimension (tp and dp)
+            # the device mesh will be used only if there is tp involved,
+            # otherwise the default behaviour should be sufficient
+            submesh_tp = None
+            submesh_dp = None
+            if self.torch_device_mesh and "tp" in self.torch_device_mesh.mesh_dim_names:
+                # extract torch sub device mesh objects
+                submesh_tp = self.torch_device_mesh["tp"]
+                if "dp" in self.torch_device_mesh.mesh_dim_names:
+                    submesh_dp = self.torch_device_mesh["dp"]
+
+            if submesh_tp and submesh_dp:
+                raise ValueError("TP + DDP / TP + FSDP is not yet supported")
+
+            # The procedure to support TP only is simpler,
+            # since we want to dispatch the same batch of samples across all ranks.
+            # This removes the complexity of handling multiple tp rank groups when a
+            # TP + DP combination is involved.
+
             try:
+                # for the TP case, avoid using split_batches,
+                # since it would mean that the dataloader should be spilling out
+                # duplicates of batches.
                 if self.split_batches:
                     # One batch of the main iterator is dispatched and split.
                     self._update_state_dict()

Review comment (on the device mesh extraction block): Can we perform this check already during …
@@ -749,9 +773,15 @@ def _fetch_batches(self, iterator):
                     # num_processes batches of the main iterator are concatenated then dispatched and split.
                     # We add the batches one by one so we have the remainder available when drop_last=False.
                     batches = []
-                    for _ in range(self.state.num_processes):
-                        self._update_state_dict()
-                        batches.append(next(iterator))
+                    if submesh_tp:
+                        # when tp, extract single batch and then replicate
+                        self._update_state_dict()
+                        batch = next(iterator)
+                        batches = [batch] * self.state.num_processes
+                    else:
+                        for _ in range(self.state.num_processes):
+                            self._update_state_dict()
+                            batches.append(next(iterator))
                     try:
                         batch = concatenate(batches, dim=0)
                     except RuntimeError as e:
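To make the dispatch behaviour above easier to follow, here is a standalone sketch of the two strategies: on the TP-only path, process 0 draws a single batch and replicates it `num_processes` times so every rank sees identical samples, while the default path draws one distinct batch per process. The helper name `fetch_for_dispatch` is illustrative and not part of the PR.

```python
# Illustrative sketch of the two dispatch strategies; not the actual accelerate code.
from typing import Any, Iterator, List


def fetch_for_dispatch(iterator: Iterator[Any], num_processes: int, tp_enabled: bool) -> List[Any]:
    """Return the per-process batches that process 0 would dispatch."""
    if tp_enabled:
        # TP-only: read one batch and replicate it, so all ranks receive the
        # same samples and the TP shards of the model stay in sync.
        batch = next(iterator)
        return [batch] * num_processes
    # Default data-parallel behaviour: one distinct batch per process.
    return [next(iterator) for _ in range(num_processes)]


# Example: with TP enabled, all 4 "processes" get batch 0.
it = iter(range(100))
print(fetch_for_dispatch(it, num_processes=4, tp_enabled=True))   # [0, 0, 0, 0]
print(fetch_for_dispatch(it, num_processes=4, tp_enabled=False))  # [1, 2, 3, 4]
```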
@@ -942,6 +972,7 @@ def prepare_data_loader(
     data_seed: Optional[int] = None,
     non_blocking: bool = False,
     use_stateful_dataloader: bool = False,
+    torch_device_mesh=None
 ) -> DataLoader:
     """
     Wraps a PyTorch `DataLoader` to generate batches for one of the processes only.
@@ -1009,6 +1040,8 @@ def prepare_data_loader(
             "If set to true, the dataloader prepared by the Accelerator will be backed by "
             "[torchdata.StatefulDataLoader](https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader).
             This requires `torchdata` version 0.8.0 or higher that supports StatefulDataLoader to be installed."
+        torch_device_mesh (``, *optional*, defaults to `None`):
+            PyTorch device mesh.


     Returns:

Review comment (on the `torch_device_mesh` docstring entry): The type is missing.
@@ -1144,6 +1177,7 @@ def prepare_data_loader(
            _non_blocking=non_blocking,
            slice_fn=slice_fn_for_dispatch,
            use_stateful_dataloader=use_stateful_dataloader,
+            torch_device_mesh=torch_device_mesh,
            **kwargs,
        )
    elif sampler_is_batch_sampler:
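A hedged end-to-end sketch of how these dataloader changes might be exercised. The dataset and batch size are placeholders, the companion arguments (`device`, `put_on_device`, `dispatch_batches`) follow the existing `prepare_data_loader` contract, and whether users call `prepare_data_loader` directly or go through `Accelerator.prepare` is an integration detail of the PR.

```python
# Sketch only: passing a named device mesh into prepare_data_loader.
# Assumes a torchrun launch with 4 ranks, torch >= 2.2, and CUDA available.
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.utils.data import DataLoader, TensorDataset

from accelerate.data_loader import prepare_data_loader

mesh = init_device_mesh("cuda", (4,), mesh_dim_names=("tp",))

dataset = TensorDataset(torch.arange(64, dtype=torch.float32).unsqueeze(1))
dataloader = DataLoader(dataset, batch_size=8)

# dispatch_batches=True routes batches through the dispatching dataloader,
# where the "tp" mesh dimension triggers replication of the same batch to every rank.
prepared = prepare_data_loader(
    dataloader,
    device=torch.device("cuda"),
    put_on_device=True,
    dispatch_batches=True,
    torch_device_mesh=mesh,
)
```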
@@ -540,6 +540,7 @@ class DistributedType(str, enum.Enum):
     MULTI_XPU = "MULTI_XPU"
     DEEPSPEED = "DEEPSPEED"
     FSDP = "FSDP"
+    TP = "TP"
     XLA = "XLA"
     MEGATRON_LM = "MEGATRON_LM"
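A brief hedged example of how downstream code could branch on the new enum member; whether `AcceleratorState` actually reports `TP` for a TP-only run depends on the rest of the PR.

```python
# Sketch: branching on the new DistributedType member.
from accelerate.utils import DistributedType

distributed_type = DistributedType.TP  # would normally come from AcceleratorState

if distributed_type == DistributedType.TP:
    print("tensor parallel run: every rank receives an identical batch")
```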
@@ -1810,6 +1811,31 @@ def set_mixed_precision(self, mixed_precision, buffer_autocast=False, override=False):
             self.mixed_precision_policy = MixedPrecision(**self.mixed_precision_policy)


+@dataclass
+class TorchTensorParallelPlugin:
+    """
+    This plugin is used to enable tensor parallelism using PyTorch 2.0.
+    """
+
+    tp_size: int = field(
+        default=1,
+        metadata={
+            "help": "tensor parallel size will be used in the device mesh preparation"
+        },
+    )
+
+    # type has to be "torch.distributed.DeviceMesh"
+    torch_device_mesh: torch.distributed.DeviceMesh = field(
+        default=None
+    )
+
+    def __post_init__(self):
+        from torch.distributed.device_mesh import init_device_mesh
+
+        mesh_dim_name = "tp"
+        device = "cuda"  # support for other devices has to be investigated
+        self.torch_device_mesh = init_device_mesh(device, (self.tp_size,), mesh_dim_names=(mesh_dim_name,))
+
+
 @dataclass
 class MegatronLMPlugin:
     """

Review comment (on the plugin docstring): I think a reference to further information on TP should be added here. Also: …

Review comment (on the `init_device_mesh` import): Should we perform a check on the minimum PyTorch and transformers versions? Not sure if here is the best place or somewhere else, Zach?
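A short hedged sketch of what constructing this plugin would do, under the assumptions baked into `__post_init__` (CUDA only, a single `"tp"` mesh dimension, and an already initialized process group, e.g. via `torchrun`). The import path from `accelerate.utils` is an assumption, mirroring how the other plugins are exported.

```python
# Sketch only: building the plugin as proposed in this diff.
# Assumes torchrun launched tp_size processes, CUDA is available, and the class
# is exported from accelerate.utils like the other plugins (an assumption here).
from accelerate.utils import TorchTensorParallelPlugin

plugin = TorchTensorParallelPlugin(tp_size=4)

# __post_init__ builds a 1-D CUDA mesh with a single named "tp" dimension,
# equivalent to: init_device_mesh("cuda", (4,), mesh_dim_names=("tp",))
assert "tp" in plugin.torch_device_mesh.mesh_dim_names
```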
Review comment: `apply_tensor_parallel` will be implemented in huggingface/transformers#34194, but only for select model architectures, right? Should we check this and, if not present, raise an appropriate error?
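Neither guard exists in this diff; the following is a hedged sketch of the kind of check the reviewers are asking about. The minimum version numbers and the `_tp_plan` attribute name are illustrative assumptions, not confirmed APIs.

```python
# Sketch only: guards suggested by the review comments above.
# The version numbers and the `_tp_plan` attribute are illustrative assumptions.
import importlib.metadata

from packaging import version


def check_tp_requirements(model, min_torch="2.3.0", min_transformers="4.47.0"):
    torch_version = importlib.metadata.version("torch")
    if version.parse(torch_version) < version.parse(min_torch):
        raise ImportError(f"Tensor parallelism requires torch >= {min_torch}, found {torch_version}")

    transformers_version = importlib.metadata.version("transformers")
    if version.parse(transformers_version) < version.parse(min_transformers):
        raise ImportError(
            f"Tensor parallelism requires transformers >= {min_transformers}, found {transformers_version}"
        )

    # Hypothetical per-architecture check: only some models declare a TP plan.
    if getattr(model, "_tp_plan", None) is None:
        raise ValueError(f"{model.__class__.__name__} does not declare a tensor parallel plan")
```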