feat: support tensor parallel using Pytorch 2.0 & Data loader #3173

Open · wants to merge 1 commit into base: main
9 changes: 9 additions & 0 deletions src/accelerate/accelerator.py
@@ -68,6 +68,7 @@
ProjectConfiguration,
RNGType,
TorchDynamoPlugin,
TorchTensorParallelPlugin,
apply_fp8_autowrap,
check_os_kernel,
clean_state_dict_for_safetensors,
@@ -187,6 +188,9 @@ class Accelerator:
fsdp_plugin ([`~utils.FullyShardedDataParallelPlugin`], *optional*):
Tweak your FSDP related args using this argument. This argument is optional and can be configured directly
using *accelerate config*
torch_tp_plugin ([`~utils.TorchTensorParallelPlugin`], *optional*):
Tweak your torch tensor parallel related args using this argument. This argument is optional and can be configured directly
using *accelerate config*
megatron_lm_plugin ([`~utils.MegatronLMPlugin`], *optional*):
Tweak your MegatronLM related args using this argument. This argument is optional and can be configured
directly using *accelerate config*
@@ -253,6 +257,7 @@ def __init__(
dataloader_config: DataLoaderConfiguration | None = None,
deepspeed_plugin: DeepSpeedPlugin | dict[str, DeepSpeedPlugin] | None = None,
fsdp_plugin: FullyShardedDataParallelPlugin | None = None,
torch_tp_plugin: TorchTensorParallelPlugin | None = None,
megatron_lm_plugin: MegatronLMPlugin | None = None,
rng_types: list[str | RNGType] | None = None,
log_with: str | LoggerType | GeneralTracker | list[str | LoggerType | GeneralTracker] | None = None,
@@ -417,6 +422,7 @@ def __init__(
dynamo_plugin=dynamo_plugin,
deepspeed_plugin=deepspeed_plugins,
fsdp_plugin=fsdp_plugin,
torch_tp_plugin=torch_tp_plugin,
megatron_lm_plugin=megatron_lm_plugin,
_from_accelerator=True,
**kwargs,
@@ -1457,6 +1463,8 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
)
if self.ddp_handler is not None:
self.ddp_handler.register_comm_hook(model)
elif self.distributed_type == DistributedType.TP:
model.apply_tensor_parallel(self.state.torch_tp_plugin.torch_device_mesh["tp"])
Review comment (Member):
apply_tensor_parallel will be implemented in huggingface/transformers#34194 but only for select model architectures, right? Should we check this and if not present, raise an appropriate error?
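A minimal sketch of such a guard (the `hasattr` check is an assumption about the transformers API; `apply_tensor_parallel` will only exist for select architectures once huggingface/transformers#34194 lands):

```python
elif self.distributed_type == DistributedType.TP:
    # Hypothetical guard: apply_tensor_parallel is only implemented for
    # select model architectures in huggingface/transformers#34194.
    if not hasattr(model, "apply_tensor_parallel"):
        raise NotImplementedError(
            f"{model.__class__.__name__} does not support tensor parallelism: "
            "no `apply_tensor_parallel` method found."
        )
    model.apply_tensor_parallel(self.state.torch_tp_plugin.torch_device_mesh["tp"])
```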

elif self.distributed_type == DistributedType.FSDP:
# We need to fix the optimizer *before* sharding the model
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
@@ -2104,6 +2112,7 @@ def prepare_data_loader(
data_seed=self.dataloader_config.data_seed,
non_blocking=self.non_blocking,
use_stateful_dataloader=self.use_stateful_dataloader,
torch_device_mesh=self.state.torch_tp_plugin.torch_device_mesh if self.state.torch_tp_plugin else None,
)
self._dataloaders.append(prepared_data_loader)
return prepared_data_loader
38 changes: 36 additions & 2 deletions src/accelerate/data_loader.py
@@ -713,6 +713,7 @@ def __init__(
_drop_last: bool = False,
_non_blocking: bool = False,
slice_fn=None,
torch_device_mesh=None,
**kwargs,
):
shuffle = False
@@ -732,6 +733,7 @@
self._drop_last = _drop_last
self._non_blocking = _non_blocking
self.skip_batches = skip_batches
self.torch_device_mesh = torch_device_mesh

self.slice_fn = slice_tensors if slice_fn is None else slice_fn
self.iteration = 0
@@ -740,7 +742,29 @@ def _fetch_batches(self, iterator):
batches, batch = None, None
# On process 0, we gather the batch to dispatch.
if self.state.process_index == 0:
# if a device mesh is provided extract each dimension (tp and dp)
Review comment (Member):
Can we perform this check already during __init__? submesh_tp and submesh_dp could be stored as self attributes.
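A rough sketch of that refactor (hypothetical, not part of this PR), resolving the submeshes once in `DataLoaderDispatcher.__init__`:

```python
# In DataLoaderDispatcher.__init__ (sketch):
self.torch_device_mesh = torch_device_mesh
self.submesh_tp = None
self.submesh_dp = None
if torch_device_mesh and "tp" in torch_device_mesh.mesh_dim_names:
    self.submesh_tp = torch_device_mesh["tp"]
    if "dp" in torch_device_mesh.mesh_dim_names:
        self.submesh_dp = torch_device_mesh["dp"]
if self.submesh_tp and self.submesh_dp:
    raise ValueError("TP + DDP / TP + FSDP is not yet supported")
```

`_fetch_batches` could then read `self.submesh_tp` directly instead of re-deriving both submeshes on every call.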

# device mesh will be used only if there is tp involved
# otherwise the default behaviour should be sufficient
submesh_tp = None
submesh_dp = None
if self.torch_device_mesh and "tp" in self.torch_device_mesh.mesh_dim_names:
    # extract torch sub device mesh objects
    submesh_tp = self.torch_device_mesh["tp"]
    if "dp" in self.torch_device_mesh.mesh_dim_names:
        submesh_dp = self.torch_device_mesh["dp"]

if submesh_tp and submesh_dp:
    raise ValueError("TP + DDP / TP + FSDP is not yet supported")

# The procedure to support TP only is simpler,
# since we want to dispatch the same batch of samples across all ranks;
# this removes the complexity of handling multiple tp rank groups when a
# TP + DP combination is involved.

try:
    # For the TP case, avoid using split_batches, since that would
    # require the dataloader to yield duplicate batches.
    if self.split_batches:
        # One batch of the main iterator is dispatched and split.
        self._update_state_dict()
@@ -749,9 +773,15 @@ def _fetch_batches(self, iterator):
# num_processes batches of the main iterator are concatenated then dispatched and split.
# We add the batches one by one so we have the remainder available when drop_last=False.
batches = []
-for _ in range(self.state.num_processes):
+if submesh_tp:
+    # when tp, extract single batch and then replicate
     self._update_state_dict()
-    batches.append(next(iterator))
+    batch = next(iterator)
+    batches = [batch] * self.state.num_processes
+else:
+    for _ in range(self.state.num_processes):
+        self._update_state_dict()
+        batches.append(next(iterator))
try:
batch = concatenate(batches, dim=0)
except RuntimeError as e:
@@ -942,6 +972,7 @@ def prepare_data_loader(
data_seed: Optional[int] = None,
non_blocking: bool = False,
use_stateful_dataloader: bool = False,
torch_device_mesh=None
) -> DataLoader:
"""
Wraps a PyTorch `DataLoader` to generate batches for one of the processes only.
Expand Down Expand Up @@ -1009,6 +1040,8 @@ def prepare_data_loader(
"If set to true, the dataloader prepared by the Accelerator will be backed by "
"[torchdata.StatefulDataLoader](https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader).
This requires `torchdata` version 0.8.0 or higher that supports StatefulDataLoader to be installed."
torch_device_mesh (``, *optional*, defaults to `None`):
Review comment (Member):
The type is missing.

PyTorch device mesh.
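
A possible fix, inferred from the `torch.distributed.DeviceMesh` annotation this PR adds to `TorchTensorParallelPlugin`:

Suggested change:
- torch_device_mesh (``, *optional*, defaults to `None`):
+ torch_device_mesh (`torch.distributed.DeviceMesh`, *optional*, defaults to `None`):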


Returns:
@@ -1144,6 +1177,7 @@ def prepare_data_loader(
_non_blocking=non_blocking,
slice_fn=slice_fn_for_dispatch,
use_stateful_dataloader=use_stateful_dataloader,
torch_device_mesh=torch_device_mesh,
**kwargs,
)
elif sampler_is_batch_sampler:
4 changes: 4 additions & 0 deletions src/accelerate/state.py
@@ -850,6 +850,7 @@ def __init__(
dynamo_plugin=None,
deepspeed_plugin=None,
fsdp_plugin=None,
torch_tp_plugin=None,
megatron_lm_plugin=None,
_from_accelerator: bool = False,
**kwargs,
@@ -864,6 +865,7 @@
if not self.initialized:
self.deepspeed_plugins = None
self.use_ipex = None
self.torch_tp_plugin = torch_tp_plugin
mixed_precision = (
parse_choice_from_env("ACCELERATE_MIXED_PRECISION", "no")
if mixed_precision is None
@@ -921,6 +923,8 @@ def __init__(
self.distributed_type = DistributedType.MEGATRON_LM
megatron_lm_plugin.set_mixed_precision(self._mixed_precision)
self.megatron_lm_plugin = megatron_lm_plugin
if os.environ.get("ACCELERATE_USE_TP", "false") == "true" or self.torch_tp_plugin is not None:
self.distributed_type = DistributedType.TP
elif self.distributed_type in [DistributedType.MULTI_CPU, DistributedType.MULTI_XPU, DistributedType.NO]:
if is_ipex_available():
# check if user disables it explicitly
1 change: 1 addition & 0 deletions src/accelerate/utils/__init__.py
@@ -57,6 +57,7 @@
SageMakerDistributedType,
TensorInformation,
TorchDynamoPlugin,
TorchTensorParallelPlugin,
add_model_config_to_megatron_parser,
)
from .environment import (
2 changes: 1 addition & 1 deletion src/accelerate/utils/constants.py
@@ -76,7 +76,7 @@
"master_port",
]

-CUDA_DISTRIBUTED_TYPES = ["DEEPSPEED", "MULTI_GPU", "FSDP", "MEGATRON_LM"]
+CUDA_DISTRIBUTED_TYPES = ["DEEPSPEED", "MULTI_GPU", "FSDP", "MEGATRON_LM", "TP"]
TORCH_DISTRIBUTED_OPERATION_TYPES = CUDA_DISTRIBUTED_TYPES + [
"MULTI_NPU",
"MULTI_MLU",
26 changes: 26 additions & 0 deletions src/accelerate/utils/dataclasses.py
@@ -540,6 +540,7 @@ class DistributedType(str, enum.Enum):
MULTI_XPU = "MULTI_XPU"
DEEPSPEED = "DEEPSPEED"
FSDP = "FSDP"
TP = "TP"
XLA = "XLA"
MEGATRON_LM = "MEGATRON_LM"

@@ -1810,6 +1811,31 @@ def set_mixed_precision(self, mixed_precision, buffer_autocast=False, override=F
self.mixed_precision_policy = MixedPrecision(**self.mixed_precision_policy)


@dataclass
class TorchTensorParallelPlugin:
"""
This plugin is used to enable tensor parallelism using PyTorch 2.0.
Review comment (Member):
I think a reference to further information on TP should be added here. Also:

Suggested change:
- This plugin is used to enable tensor parallelism using PyTorch 2.0.
+ This plugin is used to enable tensor parallelism using PyTorch >= 2.0.

"""

tp_size: int = field(
    default=1,
    metadata={"help": "Tensor parallel size to be used in the device mesh preparation"},
)

# type has to be "torch.distributed.DeviceMesh"
torch_device_mesh: torch.distributed.DeviceMesh = field(default=None)

def __post_init__(self):
    from torch.distributed.device_mesh import init_device_mesh
Review comment (Member):
Should we perform a check on the minimum PyTorch and transformers versions? Not sure if here is the best place or somewhere else, Zach?
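
If `__post_init__` is the chosen spot, a sketch using accelerate's existing `is_torch_version` helper (the minimum versions are the open question here; `2.3.0` is a placeholder, not a confirmed requirement):

```python
from accelerate.utils.versions import is_torch_version

# Placeholder minimum version; the actual floor would need to be decided.
if not is_torch_version(">=", "2.3.0"):
    raise ValueError("TorchTensorParallelPlugin requires torch >= 2.3.0.")
```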

    mesh_dim_name = "tp"
    device = "cuda"  # support for other devices has to be investigated
    self.torch_device_mesh = init_device_mesh(device, (self.tp_size,), mesh_dim_names=(mesh_dim_name,))


@dataclass
class MegatronLMPlugin:
"""
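For context, a minimal end-to-end sketch of the API this PR adds (`model` and `dataloader` are placeholders, the model's architecture must implement `apply_tensor_parallel`, and `tp_size=4` assumes a 4-GPU launch):

```python
from accelerate import Accelerator
from accelerate.utils import TorchTensorParallelPlugin

# Builds a 1-D ["tp"] device mesh in __post_init__ via init_device_mesh;
# assumes the script is launched with 4 processes, e.g. torchrun --nproc-per-node 4.
tp_plugin = TorchTensorParallelPlugin(tp_size=4)
accelerator = Accelerator(torch_tp_plugin=tp_plugin)

# prepare() applies tensor parallelism to the model and makes the dispatching
# dataloader replicate the same batch to every TP rank.
model, dataloader = accelerator.prepare(model, dataloader)
```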