
[ADD] support Distributed Data Parallel #137

Merged
merged 32 commits, Mar 10, 2023
Changes from 12 commits

Commits (32)
f04471e
initial data_parallel code based on colossalai code but it needs to b…
Jan 29, 2023
b15e42c
take out the code from _coloddp, _coloddp will be removed soon
Feb 12, 2023
0aa7020
initialization for colossalAI integration
jinwonkim93 Feb 18, 2023
3eba4a1
working code
jinwonkim93 Feb 20, 2023
f89e22a
Change to oslo interface
jinwonkim93 Feb 22, 2023
d45e029
remove temp testcode
jinwonkim93 Feb 22, 2023
80f0970
change docstrings
jinwonkim93 Feb 22, 2023
315e5f8
Merge branch 'data_parallel' into distributed_data_parallel
nijkah Mar 1, 2023
e9f8b0b
reformat all files
jinwonkim93 Mar 2, 2023
b49092c
[Refactor] Refactor backward interface in DP (#141)
nijkah Mar 2, 2023
20c0d10
[Clean] Clean DDP Code (#142)
nijkah Mar 2, 2023
5079773
Merge branch 'main' into distributed_data_parallel
nijkah Mar 2, 2023
f9097f0
Update oslo/torch/nn/parallel/data_parallel/data_parallel.py
nijkah Mar 2, 2023
28e474a
[Clean] Remove unused code
jinwonkim93 Mar 2, 2023
5401fa1
[Clean] Merge _DistirbutedDataParallelWrapper to _DistributedDataPara…
jinwonkim93 Mar 2, 2023
fe9ff2f
[Fix] fix forward max recursion
jinwonkim93 Mar 2, 2023
1c581c8
[Clean] Remove parameters
jinwonkim93 Mar 2, 2023
cc0e6a8
Update oslo/torch/nn/parallel/data_parallel/data_parallel.py
nijkah Mar 2, 2023
7494c9c
[Refactor] Remove zero_grad from forward
jinwonkim93 Mar 2, 2023
a5e0c9c
[Fix] Support long tensor for DDP backward (#146)
KKIEEK Mar 3, 2023
894d66d
[Add] Add copyright
jinwonkim93 Mar 3, 2023
4b14253
[Refactor] move zero_grad to parallelize
jinwonkim93 Mar 4, 2023
8b72f4d
[Refactor] refactor backward
KKIEEK Mar 4, 2023
313d325
[Feat] support gloo backend
jinwonkim93 Mar 5, 2023
50814c3
[Fix] fix conflict
jinwonkim93 Mar 8, 2023
b2282de
[Comment] change backward comment
jinwonkim93 Mar 8, 2023
32dd163
[Fix] fix conflict
jinwonkim93 Mar 8, 2023
8a61b92
Merge branch 'main' into distributed_data_parallel
jinwonkim93 Mar 8, 2023
78f003a
[Style] change __all__ to __ALL__
jinwonkim93 Mar 8, 2023
520eb53
[Style] change __all__ to __ALL__
jinwonkim93 Mar 8, 2023
2408bf0
[Style] change zero __all__ to __ALL__
jinwonkim93 Mar 8, 2023
697653b
Update oslo/torch/nn/parallel/data_parallel/zero/sharded_optim/bookke…
jinwonkim93 Mar 9, 2023
5 changes: 4 additions & 1 deletion oslo/torch/nn/parallel/data_parallel/__init__.py
@@ -1,3 +1,6 @@
from oslo.torch.nn.parallel.data_parallel.data_parallel import (
DistributedDataParallel,
)
from oslo.torch.nn.parallel.data_parallel.zero import ZeroRedundancyOptimizer

__all__ = ["ZeroRedundancyOptimizer"]
__ALL__ = ["DistributedDataParallel", "ZeroRedundancyOptimizer"]
111 changes: 111 additions & 0 deletions oslo/torch/nn/parallel/data_parallel/_reducer.py
@@ -0,0 +1,111 @@
import functools
from typing import Callable, Dict, List, Optional, Tuple

import torch
import torch.distributed as dist
from torch import Tensor
from torch.distributed import ProcessGroup


class Bucket:
def __init__(
self, size: int, dtype: torch.dtype, device: torch.device, group: ProcessGroup
):
self.buffer = torch.zeros(size, dtype=dtype, device=device)
self.group = group
self.offset = 0
self.callbacks: List[Callable] = []

def flush(self) -> None:
"""Flush content of the bucket."""
if self.offset == 0:
assert len(self.callbacks) == 0
return
# all-reduce the filled portion of the bucket in place
dist.all_reduce(self.buffer[: self.offset], group=self.group)

# execute post-reduction callbacks
for callback_fn in self.callbacks:
callback_fn()
# reset the bucket for reuse with a fresh zeroed buffer
self.offset = 0
self.callbacks.clear()
self.buffer = torch.zeros_like(self.buffer)

def alloc(self) -> None:

if self.buffer.storage().size() == 0:
self.buffer.storage().resize_(self.buffer.numel())

def free(self) -> None:

assert self.offset == 0 and self.callbacks == [], "Incorrect call of teardown"
self.buffer.storage().resize_(0)

def append(self, tensor: Tensor, callback_fn: Optional[Callable]) -> None:
tensor_size = tensor.numel()
offset = self.offset
self.buffer[offset : offset + tensor_size].copy_(tensor.flatten())
self.offset += tensor_size

# callback will be given the reduced result
if callback_fn is not None:
result_view = self.buffer[offset : offset + tensor_size].view(tensor.shape)
self.callbacks.append(functools.partial(callback_fn, result_view))

@property
def avail_size(self) -> int:
return self.buffer.size(0) - self.offset


class Reducer:
def __init__(self, bucket_size_mb: int = 25):
self.bucket_size_mb = bucket_size_mb
self.buckets: Dict[Tuple[torch.dtype, torch.device, ProcessGroup], Bucket] = {}

@torch.no_grad()
def all_reduce_async(
self,
tensor: Tensor,
group: ProcessGroup,
callback_fn: Optional[Callable] = None,
) -> None:
bucket_size = self._get_bucket_size(tensor.element_size())

if tensor.numel() >= bucket_size:
dist.all_reduce(tensor, group=group)
if callback_fn is not None:
callback_fn(tensor)
return

bucket = self._get_bucket(tensor, group)
if tensor.numel() > bucket.avail_size:
# not enough space remaining in bucket, flush it now
bucket.flush()
bucket.append(tensor, callback_fn)

@torch.no_grad()
def flush(self) -> None:
for bucket in self.buckets.values():
bucket.flush()

@torch.no_grad()
def free(self) -> None:
for bucket in self.buckets.values():
bucket.free()

@functools.lru_cache()
def _get_bucket_size(self, element_size: int) -> int:
if self.bucket_size_mb <= 0: # Values <= 0 disable bucketing.
return 0
MB = 1024 * 1024
bucket_size = self.bucket_size_mb * MB / element_size
return int(bucket_size)

def _get_bucket(self, tensor: Tensor, group: ProcessGroup) -> Bucket:
key = (tensor.dtype, tensor.device, group)
if key not in self.buckets:
bucket_size = self._get_bucket_size(tensor.element_size())
self.buckets[key] = Bucket(bucket_size, tensor.dtype, tensor.device, group)
self.buckets[key].alloc()
return self.buckets[key]
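A minimal usage sketch of the bucketed Reducer above, to make the intended call sequence explicit. It assumes an already-initialized torch.distributed process group and a CUDA device; the tensor, group, and callback names are illustrative only, not part of this PR.

import torch
import torch.distributed as dist

from oslo.torch.nn.parallel.data_parallel._reducer import Reducer

# Assumes torch.distributed has already been initialized (e.g. via torchrun).
reducer = Reducer(bucket_size_mb=25)
group = dist.new_group()  # process group containing all ranks

def on_reduced(result: torch.Tensor) -> None:
    # Receives a view of the all-reduced values (a bucket slice or the tensor itself).
    print("reduced grad norm:", result.norm().item())

grad = torch.randn(1024, device="cuda")

# Tensors smaller than the bucket are copied into a per-(dtype, device, group)
# bucket; tensors at least as large as the bucket are all-reduced immediately.
reducer.all_reduce_async(grad, group=group, callback_fn=on_reduced)

# Flush partially filled buckets and optionally release their storage; this is
# what _DistributedDataParallel does in its backward step.
reducer.flush()
reducer.free()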
24 changes: 24 additions & 0 deletions oslo/torch/nn/parallel/data_parallel/_utils.py
@@ -0,0 +1,24 @@
from typing import Iterable

import torch


def is_ddp_ignored(p):
return getattr(p, "_ddp_to_ignore", False)


def set_params_to_ignore(params_to_ignore: Iterable[torch.Tensor]) -> None:
"""Sets parameters to be ignored by DDP.
This method must be called before initializing DistributedDataParallel.
Example:
>>> params_to_ignore = []
>>> for p in module.parameters():
>>> if should_ignore(p):
>>> params_to_ignore.append(p)
>>> set_params_to_ignore(params_to_ignore)
>>> module = DistributedDataParallel(module)
Args:
params_to_ignore (Iterable[torch.Tensor]): A list of parameters to be ignored.
"""
for p in params_to_ignore:
p._ddp_to_ignore = True
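A small, self-contained illustration of the ignore flag introduced here. The bias-skipping rule is hypothetical; it only demonstrates how set_params_to_ignore and is_ddp_ignored fit together before the module is wrapped.

import torch.nn as nn

from oslo.torch.nn.parallel.data_parallel._utils import (
    is_ddp_ignored,
    set_params_to_ignore,
)

module = nn.Linear(20, 1)

# Hypothetical rule: exclude bias parameters from gradient synchronization.
params_to_ignore = [p for name, p in module.named_parameters() if "bias" in name]
set_params_to_ignore(params_to_ignore)

# _DistributedDataParallel checks this flag and skips registering its gradient
# hook on ignored parameters.
for name, p in module.named_parameters():
    print(name, "ignored:", is_ddp_ignored(p))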
232 changes: 232 additions & 0 deletions oslo/torch/nn/parallel/data_parallel/data_parallel.py
@@ -0,0 +1,232 @@
import copy
from collections import OrderedDict
from functools import partial
from typing import Dict, List, Optional, Set

import torch
import torch.nn as nn

try:
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX
except ImportError:
_EXTRA_STATE_KEY_SUFFIX = "_extra_state"

from oslo.torch.distributed.parallel_context import ParallelContext
from oslo.torch.distributed.parallel_mode import ParallelMode
from oslo.torch.nn.parallel.utils import (
get_parallel_context,
add_wrapper,
OsloParallelWrapper,
)
from ._reducer import Reducer
from ._utils import is_ddp_ignored


def free_storage(data: torch.Tensor) -> None:
"""Free underlying storage of a Tensor."""
if data.storage().size() > 0:
# Since we're modifying the Tensor's Storage directly, make sure the Tensor
# is the sole occupant of the Storage.
assert data.storage_offset() == 0
data.storage().resize_(0)


def _cast_float(args, dtype: torch.dtype):
if isinstance(args, torch.Tensor) and torch.is_floating_point(args):
args = args.to(dtype)
elif isinstance(args, (list, tuple)):
args = type(args)(_cast_float(t, dtype) for t in args)
elif isinstance(args, dict):
args = {k: _cast_float(v, dtype) for k, v in args.items()}
return args


class BackwardFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, module, *args):
if not isinstance(module, _DistributedDataParallel):
raise ValueError
ctx.module = module
ctx.mark_dirty(*args)
return args

@staticmethod
def backward(ctx, *grad_outputs):
ctx.module._backward()
return (None,) + grad_outputs


def DistributedDataParallel(
module: nn.Module,
parallel_context: ParallelContext,
bucket_cap_mb: int = 25,
rebuild_bucket: bool = True,
):
ddp = _DistirbutedDataParallelWrapper(
module=module,
parallel_context=parallel_context,
bucket_cap_mb=bucket_cap_mb,
rebuild_bucket=rebuild_bucket,
)

add_wrapper(
module,
mode=ParallelMode.DATA,
wrapper=ddp,
parallel_context=parallel_context,
)
return module


class _DistirbutedDataParallelWrapper(OsloParallelWrapper):
"""Distributed data parallel wrapper for Oslo.
Example:
>>> from oslo.torch.nn.parallel import DistributedDataParallel as DDP
>>> model = torch.nn.Linear(20, 1)
>>>
>>> model = DDP(model, parallel_context)
>>> oslo.ready(model, parallel_context)
>>> logits = model(x)
>>> loss = criterion(logits, labels)
>>> model.backward(loss)
Args:
module (nn.Module): PyTorch module object
parallel_context (ParallelContext): process group object
"""

def __init__(
self,
module: torch.nn.Module,
parallel_context: ParallelContext = None,
bucket_cap_mb: int = 25,
rebuild_bucket: bool = True,
) -> None:
super().__init__(parallelism_priority=99)
self.module = module
self.parallel_context = get_parallel_context(module, parallel_context)
self.bucket_cap_mb = bucket_cap_mb
self.rebuild_bucket = rebuild_bucket

def forward(self, *args, **kwargs):
return self.module_forward(*args, **kwargs)

def deparallelize(self):
self.module.deparallelize()

def parallelize(self):
self.module = _DistributedDataParallel(
self.module, self.parallel_context, self.bucket_cap_mb, self.rebuild_bucket
)
self.module_forward = copy.copy(self.module.forward)


class _DistributedDataParallel(nn.Module):
"""Distributed data parallel for Oslo.

Args:
module (nn.Module): PyTorch module object
parallel_context (ParallelContext): process group object
"""

def __init__(
self,
module: torch.nn.Module,
parallel_context: ParallelContext = None,
bucket_cap_mb: int = 25,
rebuild_bucket: bool = True,
) -> None:
super().__init__()
self.module = module
self.module.zero_grad = self.zero_grad
self.module_forward = module.forward

self.comm_stream: torch.cuda.Stream = torch.cuda.Stream()
assert parallel_context
self.parallel_context = get_parallel_context(module, parallel_context)
self.dp_world_size = self.parallel_context.get_world_size(ParallelMode.DATA)

self.reducer = Reducer(bucket_cap_mb)
self.rebuild_bucket = rebuild_bucket
for p in module.parameters():
if is_ddp_ignored(p):
continue
if p.requires_grad:
p.register_hook(partial(self.grad_handle, p))

def parameters(self, recurse: bool = True):
return self.module.parameters(recurse)

def forward(self, *args, **kwargs):
self.module.zero_grad(set_to_none=True)
args = (arg.requires_grad_().clone() for arg in args)
args = BackwardFunction.apply(self, *args)
return self.module_forward(*args, **kwargs)

def _backward(self):
with torch.cuda.stream(self.comm_stream):
self.reducer.flush()
torch.cuda.current_stream().wait_stream(self.comm_stream)
if self.rebuild_bucket:
self.reducer.free()
for p in self.module.parameters():
if is_ddp_ignored(p):
continue
if p.grad.device.type != "cpu":
p.grad = p._saved_grad

def grad_handle(self, p, grad):
if grad.device.type != "cpu":
empty_grad = torch.empty_like(grad)
free_storage(empty_grad)
if self.dp_world_size > 1:
grad = grad / self.dp_world_size
self.comm_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(self.comm_stream):
self.reducer.all_reduce_async(
grad,
group=self.parallel_context.get_group(ParallelMode.DATA),
callback_fn=partial(self._save_grad, p),
)
grad.record_stream(self.comm_stream)
else:
_DistributedDataParallel._save_grad(p, grad)
return empty_grad

else:
# TODO(jiaruifang) fixme
# self.process_group.set_cpu_groups() # TODO
# dist.all_reduce(
# grad, group=self.process_group.cpu_dp_process_group()
# ) # TODO
# return grad
raise NotImplementedError

@staticmethod
def _save_grad(p, grad):
if hasattr(p, "_saved_grad"):
p._saved_grad.add_(grad)
else:
p._saved_grad = grad

def zero_grad(self, set_to_none: bool = False) -> None:
super().zero_grad(set_to_none=True)
for p in self.module.parameters():
if getattr(p, "_saved_grad", None) is not None:
if set_to_none:
p._saved_grad = None
else:
if p._saved_grad.grad_fn is not None:
p._saved_grad.detach_()
else:
p._saved_grad.requires_grad_(False)
p._saved_grad.zero_()

def state_dict(self, destination=None, prefix="", keep_vars=False):
return self.module.state_dict(
destination=destination, prefix=prefix, keep_vars=keep_vars
)

def load_state_dict(
self, state_dict: "OrderedDict[str, torch.Tensor]", strict: bool = True
):
return self.module.load_state_dict(state_dict, strict)
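An end-to-end sketch of the training-step flow this wrapper enables, adapted from the class docstring above. ParallelContext.from_torch and its data_parallel_size argument are assumptions about how the process group is built; the script must be launched with a distributed launcher (e.g. torchrun) and is not runnable standalone.

import torch
import torch.nn.functional as F

import oslo
from oslo.torch.distributed.parallel_context import ParallelContext
from oslo.torch.nn.parallel import DistributedDataParallel as DDP

# Assumed constructor; adjust to however the ParallelContext is created in your setup.
parallel_context = ParallelContext.from_torch(data_parallel_size=2)

model = torch.nn.Linear(20, 1).cuda()
model = DDP(model, parallel_context)
oslo.ready(model, parallel_context)  # installs the wrapper's forward/backward (see docstring)

optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

x = torch.randn(16, 20, device="cuda")
labels = torch.randn(16, 1, device="cuda")

logits = model(x)
loss = F.mse_loss(logits, labels)
model.backward(loss)  # flushes the Reducer and restores the synchronized grads
optimizer.step()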
2 changes: 2 additions & 0 deletions oslo/torch/nn/parallel/utils.py
@@ -69,6 +69,8 @@ def parallelize(model: nn.Module, parallel_context: ParallelContext):
if hasattr(wrapper, "parallelize"):
wrapper.parallelize()
setattr(model, "forward", wrapper.forward)
if hasattr(wrapper, "backward"):
setattr(model, "backward", wrapper.backward)

for parameter in model.parameters():
if hasattr(parameter, "oslo_parallel"):