[DRAFT] Tentative implementation of MiCS (deepspeedai#2964)
* include mics config and optimizer

* change private vars to public vars

so the child class can initialize these vars

* Port the init function from stage3

* adding a model test file for mics

* adapt to get_accelerator api and fp16 group defrag

* WIP: porting mics modification to ms master

* WIP: included gradient all-reduce among replication groups

* WIP: ported hierarchical all gather part

did basic loss test on a simple MLP model

* [Bug fix] using the comm group attached on the param

* torch2.0 support

* remove print

* delegate wait op

* [Bug] fix naming

* adding doc string

* resolving recursive import

* fix formatting, typos and license

* fix license and unit test error

---------

Co-authored-by: Ubuntu <ubuntu@ip-172-31-14-191.us-west-2.compute.internal>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-7-70.us-west-2.compute.internal>
Co-authored-by: Zhen Zhang <zhzhn@amazon.com>
Co-authored-by: zhzhn <zhzhn@ip-10-2-57-114.us-west-2.compute.internal>
5 people authored Apr 26, 2023
1 parent d925395 commit 2e99f6e
Showing 13 changed files with 1,132 additions and 116 deletions.
36 changes: 36 additions & 0 deletions deepspeed/comm/comm.py
@@ -440,6 +440,30 @@ def reduce_scatter(output,
return cdb.reduce_scatter(output=output, input_list=input_list, op=op, group=group, async_op=async_op)


def has_all_reduce_coalesced():
"""Return True if the current communication backend supports all_reduce_coalesced."""
global cdb
assert cdb is not None and cdb.is_initialized(
), 'DeepSpeed backend not set, please initialize it using init_process_group()'
assert cdb.has_all_reduce_coalesced is not None, 'has_all_reduce_coalesced is not yet defined'
return cdb.has_all_reduce_coalesced


def has_coalescing_manager():
global cdb
assert cdb is not None and cdb.is_initialized(
), 'DeepSpeed backend not set, please initialize it using init_process_group()'
assert cdb.has_coalescing_manager is not None, 'has_coalescing_manager is not yet defined'
return cdb.has_coalescing_manager


def all_gather_coalesced(output_tensors, input_tensors, group=None, async_op=False):
global cdb
assert cdb is not None and cdb.is_initialized(
), 'DeepSpeed backend not set, please initialize it using init_process_group()'
return cdb.all_gather_coalesced(output_tensors, input_tensors, group=group, async_op=async_op)


@timed_op
def all_reduce(tensor,
op=ReduceOp.SUM,
@@ -457,6 +481,18 @@ def all_reduce(tensor,
return cdb.all_reduce(tensor, op, group, async_op)


@timed_op
def all_reduce_coalesced(tensors,
op=ReduceOp.SUM,
group=None,
async_op=False,
prof=False,
log_name='all_reduce_coalesced',
debug=get_caller_func()):
global cdb
return cdb.all_reduce_coalesced(tensors, op, group, async_op)


def get_world_group():
global cdb
assert cdb is not None and cdb.is_initialized(
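The new wrappers follow the existing `deepspeed.comm` collectives. Below is a minimal usage sketch, assuming the backend has already been initialized (e.g. via `deepspeed.init_distributed()`); the gradient list and tensor shapes are illustrative only, not part of this commit:

```python
import torch
import deepspeed
import deepspeed.comm as dist

# Assumes the communication backend is already initialized, e.g.:
# deepspeed.init_distributed()

# Illustrative list of gradient tensors to reduce in one shot.
grads = [torch.randn(1024, device='cuda') for _ in range(4)]

if dist.has_all_reduce_coalesced():
    # One coalesced call covers all tensors.
    dist.all_reduce_coalesced(grads, op=dist.ReduceOp.SUM)
else:
    # Fall back to one all-reduce per tensor.
    for g in grads:
        dist.all_reduce(g, op=dist.ReduceOp.SUM)
```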
63 changes: 62 additions & 1 deletion deepspeed/comm/torch.py
@@ -9,6 +9,32 @@
from .backend import *
from .comm import *

import torch


def is_torch_two():
TORCH_MAJOR = int(torch.__version__.split('.')[0])
if TORCH_MAJOR >= 2:
return True
else:
return False


def has_coalescing_manager():
has_c10d = hasattr(torch.distributed, 'distributed_c10d')
return has_c10d and hasattr(torch.distributed.distributed_c10d, '_coalescing_manager')


def has_all_reduce_coalesced():
return hasattr(torch.distributed, "all_reduce_coalesced")


def get_coalescing_manager(group, device, reqs):
if is_torch_two():
return torch.distributed.distributed_c10d._coalescing_manager(group, device=device, reqs=reqs)
else:
return torch.distributed.distributed_c10d._coalescing_manager(group, reqs)


class TorchBackend(Backend):
"""
@@ -21,6 +47,8 @@ class TorchBackend(Backend):

def __init__(self, backend, timeout, init_method, rank=-1, world_size=-1, name='torch'):
super(TorchBackend, self).__init__()
self.has_all_reduce_coalesced = has_all_reduce_coalesced()
self.has_coalescing_manager = has_coalescing_manager()
self.all_gather_function = self.get_all_gather_function()
self.reduce_scatter_function = self.get_reduce_scatter_function()
self.initialized = True
@@ -66,6 +94,16 @@ def all_reduce(self, tensor, op=torch.distributed.ReduceOp.SUM, group=None, asyn
op = self._reduce_op(op)
return torch.distributed.all_reduce(tensor=tensor, op=op, group=group, async_op=async_op)

def all_reduce_coalesced(self, tensors, op=torch.distributed.ReduceOp.SUM, group=None, async_op=False):
""" proxy func to torch.distributed.all_reduce_coalesced,
which is included in PyTorch 1.13 and above
"""
if not self.has_all_reduce_coalesced:
raise RuntimeError(f"Current torch version does not have all_reduce_coalesced "
f"api (torch.__version__: {torch.__version__})")
op = self._reduce_op(op)
return torch.distributed.all_reduce_coalesced(tensors=tensors, op=op, group=group, async_op=async_op)

def reduce(self, tensor, dst, op=ReduceOp.SUM, group=None, async_op=False):
return torch.distributed.reduce(tensor=tensor, dst=dst, op=self._reduce_op(op), group=group, async_op=async_op)

@@ -89,11 +127,34 @@ def all_gather_into_tensor(self, output_tensor, input_tensor, group=None, async_
group=group,
async_op=async_op)
else:
utils.logger.warning("unable to find torch.distributed.all_gather_into_tensor. will fall back to "
utils.logger.warning("unable to find torch.distributed._all_gather_base. will fall back to "
"torch.distributed.all_gather which will result in suboptimal performance. "
"please consider upgrading your pytorch installation.")
pass

def all_gather_coalesced(self, output_tensors, input_tensors, group=None, async_op=False):
"""All-gather a list of tensors, coalescing the collectives when the backend supports it."""
assert len(output_tensors) == len(input_tensors), "output_tensors and input_tensors must have the same length"
if hasattr(torch.distributed.distributed_c10d, '_all_gather_base_coalesced'):
# customized PyTorch
return torch.distributed.distributed_c10d._all_gather_base_coalesced(output_tensors,
input_tensors,
group=group,
async_op=async_op)
elif has_coalescing_manager():
reqs = []
with get_coalescing_manager(group, input_tensors[0].device, reqs):
for output, input in zip(output_tensors, input_tensors):
handle = torch.distributed.distributed_c10d.all_gather_into_tensor(output,
input,
group=group,
async_op=True)
reqs.append(handle)
if async_op:
return reqs[-1]
else:
reqs[-1].wait()

def reduce_scatter_tensor(self, output_tensor, input_tensor, op=ReduceOp.SUM, group=None, async_op=False):
if self.has_reduce_scatter_tensor():
return self.reduce_scatter_function(output_tensor,
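`all_gather_coalesced` follows the `all_gather_into_tensor`/`_all_gather_base` convention: each flat output buffer holds `world_size` copies of the corresponding input shard. A hedged sketch of the calling pattern (shard sizes are illustrative, and the backend is assumed to be initialized):

```python
import torch
import deepspeed.comm as dist

# Illustrative: gather two parameter shards with a single coalesced call.
world_size = dist.get_world_size()
shards = [torch.randn(256, device='cuda'), torch.randn(512, device='cuda')]

# Each output buffer is world_size times its shard, matching all_gather_into_tensor semantics.
outputs = [torch.empty(world_size * s.numel(), device=s.device, dtype=s.dtype) for s in shards]

# Blocking form; with async_op=True a wait handle for the coalesced group is returned instead.
dist.all_gather_coalesced(outputs, shards, async_op=False)
```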
2 changes: 2 additions & 0 deletions deepspeed/runtime/config.py
@@ -779,6 +779,8 @@ def _initialize_params(self, param_dict):
self.sparse_gradients_enabled = get_sparse_gradients_enabled(param_dict)

self.zero_config = get_zero_config(param_dict)
self.mics_shard_size = self.zero_config.mics_shard_size
self.mics_hierarchical_params_gather = self.zero_config.mics_hierarchical_params_gather
self.zero_optimization_stage = self.zero_config.stage
self.zero_enabled = self.zero_optimization_stage > 0

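Both new engine attributes are read from the `zero_optimization` section of the DeepSpeed config. A sketch of the corresponding config fragment; the shard size, batch, precision, and optimizer settings are placeholders, not values from this commit:

```python
# Illustrative DeepSpeed config fragment enabling MiCS.
ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "fp16": {"enabled": True},
    "optimizer": {"type": "Adam", "params": {"lr": 1e-4}},
    "zero_optimization": {
        "stage": 3,
        "mics_shard_size": 4,                      # ranks per MiCS shard group; <= 0 disables MiCS
        "mics_hierarchical_params_gather": False,  # intra-node then inter-node gather when True
    },
}
```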
44 changes: 44 additions & 0 deletions deepspeed/runtime/engine.py
@@ -691,6 +691,9 @@ def zero_sub_group_size(self):
def zero_optimization_stage(self):
return self._config.zero_optimization_stage

def mics_shard_size(self):
return self._config.mics_shard_size

def zero_reduce_bucket_size(self):
return self._config.zero_config.reduce_bucket_size

@@ -1368,6 +1371,8 @@ def _configure_bf16_optimizer(self, optimizer):

def _configure_zero_optimizer(self, optimizer):
zero_stage = self.zero_optimization_stage()
mics_shard_size = self.mics_shard_size()

model_dtype, grad_accum_dtype = self.get_data_types()
timers = self.timers if self.wall_clock_breakdown() else None

@@ -1443,6 +1448,14 @@ def _configure_zero_optimizer(self, optimizer):
offload_param_config=self.zero_offload_param(),
mpu=self.mpu)
else:
log_dist(
f'Creating fp16 ZeRO stage {zero_stage} optimizer,'
f' MiCS is enabled {mics_shard_size>0},'
f' Hierarchical params gather {self._config.mics_hierarchical_params_gather}',
ranks=[0])
if mics_shard_size > 0:
return self._return_mics_optimizer(optimizer, timers)

log_dist(f'Creating {model_dtype} ZeRO stage {zero_stage} optimizer', ranks=[0])
from deepspeed.runtime.zero.stage3 import DeepSpeedZeroOptimizer_Stage3
optimizer = DeepSpeedZeroOptimizer_Stage3(
@@ -1479,6 +1492,37 @@ def _configure_zero_optimizer(self, optimizer):

return optimizer

def _return_mics_optimizer(self, basic_optimizer, timers):
from deepspeed.runtime.zero.mics import MiCS_Optimizer
optimizer = MiCS_Optimizer(self.module,
basic_optimizer,
timers=timers,
ds_config=self.config,
static_loss_scale=self.loss_scale(),
dynamic_loss_scale=self.dynamic_loss_scale(),
dynamic_loss_args=self.dynamic_loss_scale_args(),
clip_grad=self.gradient_clipping(),
contiguous_gradients=self.zero_contiguous_gradients(),
reduce_bucket_size=self.zero_reduce_bucket_size(),
prefetch_bucket_size=self.zero_prefetch_bucket_size(),
max_reuse_distance=self.zero_max_reuse_distance(),
max_live_parameters=self.zero_max_live_parameters(),
param_persistence_threshold=self.zero_param_persistence_threshold(),
model_persistence_threshold=self.zero_model_persistence_threshold(),
dp_process_group=self.data_parallel_group,
reduce_scatter=self.zero_reduce_scatter(),
overlap_comm=self.zero_overlap_comm(),
offload_optimizer_config=self.zero_offload_optimizer(),
offload_param_config=self.zero_offload_param(),
sub_group_size=self.zero_sub_group_size(),
mpu=self.mpu,
postscale_gradients=self.postscale_gradients(),
gradient_predivide_factor=self.gradient_predivide_factor(),
gradient_accumulation_steps=self.gradient_accumulation_steps(),
aio_config=self.aio_config(),
communication_data_type=self.communication_data_type)
return optimizer

def _configure_eigenvalue(self):
eigenvalue = Eigenvalue(
verbose=self.eigenvalue_verbose(),
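From the user's perspective nothing changes besides the config: when `mics_shard_size > 0`, `_configure_zero_optimizer` routes to `MiCS_Optimizer` instead of `DeepSpeedZeroOptimizer_Stage3`. A sketch using the illustrative `ds_config` above; the model is a placeholder:

```python
import torch
import deepspeed

model = torch.nn.Linear(1024, 1024)  # placeholder model

# With mics_shard_size > 0 in ds_config, the engine is expected to build a MiCS_Optimizer
# inside _configure_zero_optimizer() rather than the plain ZeRO stage 3 optimizer.
engine, optimizer, _, _ = deepspeed.initialize(model=model,
                                               model_parameters=model.parameters(),
                                               config=ds_config)
```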
2 changes: 2 additions & 0 deletions deepspeed/runtime/zero/__init__.py
@@ -11,3 +11,5 @@

from .tiling import TiledLinear
from .tiling import TiledLinearReturnBias

from .mics import MiCS_Init
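`MiCS_Init` is now exported alongside the other ZeRO entry points. The commit message notes its init path is ported from stage 3, so the sketch below assumes it mirrors the `zero.Init` context-manager interface, including the `config_dict_or_path` argument; that interface is an assumption, not something shown in this diff:

```python
import torch
from deepspeed.runtime.zero import MiCS_Init

# Assumption: MiCS_Init behaves like zero.Init and partitions parameters across the
# MiCS shard group as modules are constructed inside the context.
with MiCS_Init(config_dict_or_path=ds_config):
    model = torch.nn.Sequential(torch.nn.Linear(1024, 4096),
                                torch.nn.Linear(4096, 1024))

# The partitioned model is then handed to deepspeed.initialize() as usual.
```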
3 changes: 3 additions & 0 deletions deepspeed/runtime/zero/config.py
@@ -249,6 +249,9 @@ class DeepSpeedZeroConfig(DeepSpeedConfigModel):
between optimizer steps) or GPU count (increased parallelism).
"""

mics_shard_size: int = Field(-1, new_param="mics_shard_size")

mics_hierarchical_params_gather: bool = False
memory_efficient_linear: bool = True
"""
Use memory efficient linear implementation, for Stage 3.
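A quick check of the defaults added above; this assumes `DeepSpeedZeroConfig` can be instantiated with no arguments, as other `DeepSpeedConfigModel` subclasses can:

```python
from deepspeed.runtime.zero.config import DeepSpeedZeroConfig

cfg = DeepSpeedZeroConfig()
assert cfg.mics_shard_size == -1                      # MiCS is off unless a positive shard size is set
assert cfg.mics_hierarchical_params_gather is False   # flat (single-level) all-gather by default
```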