Skip to content

Commit

Permalink
Merge branch 'master' into olruwase/no_grad_sync_ctxt
Browse files Browse the repository at this point in the history
  • Loading branch information
tjruwase authored Nov 8, 2024
2 parents e48130e + 057d25b commit 825b3d6
Show file tree
Hide file tree
Showing 23 changed files with 237 additions and 69 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/nv-ds-chat.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ on:
type: string
pull_request:
paths:
- ".github/workflows/nv-ds-chat.yml"
- "deepspeed/runtime/zero/stage_1_and_2.py"
- "deepspeed/runtime/zero/stage3.py"
- "deepspeed/runtime/hybrid_engine.py"
Expand Down Expand Up @@ -42,6 +43,7 @@ jobs:
- name: Install deepspeed
run: |
pip install transformers==4.45.2
pip install .[dev]
ds_report
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/python.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@ jobs:
unit-tests:
strategy:
matrix:
pyVersion: ["3.7", "3.8", "3.9", "3.10"]
pyVersion: ["3.8", "3.9", "3.10", "3.11", "3.12"]
fail-fast: false

runs-on: ubuntu-22.04
runs-on: ubuntu-24.04
container:
image: deepspeed/gh-builder:py${{ matrix.pyVersion }}

Expand Down
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ repos:
- id: trailing-whitespace

- repo: https://github.com/google/yapf
rev: v0.32.0
rev: v0.40.0
hooks:
- id: yapf

Expand Down Expand Up @@ -65,7 +65,7 @@ repos:
]

- repo: https://github.com/pycqa/flake8
rev: 4.0.1
rev: 5.0.4
hooks:
- id: flake8
args: ['--config=.flake8']
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ DeepSpeed has been integrated with several different popular open-source DL fram
| PyTorch Nightly | [![nv-torch-nightly-v100](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-torch-nightly-v100.yml/badge.svg?branch=master)](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-torch-nightly-v100.yml) |
| Integrations | [![nv-transformers-v100](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-transformers-v100.yml/badge.svg?branch=master)](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-transformers-v100.yml) [![nv-lightning-v100](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-lightning-v100.yml/badge.svg?branch=master)](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-lightning-v100.yml) [![nv-accelerate-v100](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-accelerate-v100.yml/badge.svg?branch=master)](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-accelerate-v100.yml) [![nv-mii](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-mii.yml/badge.svg?branch=master)](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-mii.yml) [![nv-ds-chat](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-ds-chat.yml/badge.svg?branch=master)](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-ds-chat.yml) [![nv-sd](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-sd.yml/badge.svg)](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-sd.yml) |
| Misc | [![Formatting](https://github.com/microsoft/DeepSpeed/actions/workflows/formatting.yml/badge.svg?branch=master)](https://github.com/microsoft/DeepSpeed/actions/workflows/formatting.yml) [![pages-build-deployment](https://github.com/microsoft/DeepSpeed/actions/workflows/pages/pages-build-deployment/badge.svg)](https://github.com/microsoft/DeepSpeed/actions/workflows/pages/pages-build-deployment) [![Documentation Status](https://readthedocs.org/projects/deepspeed/badge/?version=latest)](https://deepspeed.readthedocs.io/en/latest/?badge=latest)[![python](https://github.com/microsoft/DeepSpeed/actions/workflows/python.yml/badge.svg?branch=master)](https://github.com/microsoft/DeepSpeed/actions/workflows/python.yml) |
| Huawei Ascend NPU | [![Huawei Ascend NPU](https://github.com/cosdt/DeepSpeed/actions/workflows/huawei-ascend-npu.yml/badge.svg?branch=master)](https://github.com/cosdt/DeepSpeed/actions/workflows/huawei-ascend-npu.yml) |
| Huawei Ascend NPU | [![Huawei Ascend NPU](https://github.com/Ascend/Ascend-CI/actions/workflows/deepspeed.yaml/badge.svg?branch=main)](https://github.com/Ascend/Ascend-CI/actions/workflows/deepspeed.yaml) |

# Installation

Expand Down
10 changes: 9 additions & 1 deletion csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,15 @@ __global__ void apply_rotary_pos_half(T* mixed_query,

#if defined(__HIP_PLATFORM_AMD__) and ROCM_WAVEFRONT_SIZE == 64
#define LAUNCH_FOR_ALIGNMENT(ALIGNMENT) \
if (threads_per_head == 64) { \
if (threads_per_head == 4) { \
LAUNCH_ROT_POS_EMB_HALF(4, ALIGNMENT); \
} else if (threads_per_head == 8) { \
LAUNCH_ROT_POS_EMB_HALF(8, ALIGNMENT); \
} else if (threads_per_head == 16) { \
LAUNCH_ROT_POS_EMB_HALF(16, ALIGNMENT); \
} else if (threads_per_head == 32) { \
LAUNCH_ROT_POS_EMB_HALF(32, ALIGNMENT); \
} else if (threads_per_head == 64) { \
LAUNCH_ROT_POS_EMB_HALF(64, ALIGNMENT); \
} else { \
assert(false); \
Expand Down
4 changes: 2 additions & 2 deletions deepspeed/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,8 +165,8 @@ def initialize(args=None,
if hasattr(args, "deepscale_config") and args.deepscale_config is not None:
logger.warning("************ --deepscale_config is deprecated, please use --deepspeed_config ************")
if hasattr(args, "deepspeed_config"):
assert (args.deepspeed_config is
None), "Not sure how to proceed, we were given both a deepscale_config and deepspeed_config"
assert (args.deepspeed_config
is None), "Not sure how to proceed, we were given both a deepscale_config and deepspeed_config"
args.deepspeed_config = args.deepscale_config
args.deepscale_config = None

Expand Down
12 changes: 6 additions & 6 deletions deepspeed/autotuning/autotuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,8 +248,8 @@ def mp_size(self):
return self.autotuning_config.mp_size

def max_train_micro_batch_size_per_gpu(self):
if self.max_train_batch_size(
) and self.max_train_batch_size() > 0: # if the user specifies a max_train_batch_size
if self.max_train_batch_size() and self.max_train_batch_size(
) > 0: # if the user specifies a max_train_batch_size
max_train_micro_batch_size = self.max_train_batch_size() * self.mp_size() // (
self.exp_num_gpus * self.exp_num_nodes) # gradient accumulation steps >=1
return min(self.autotuning_config.max_train_micro_batch_size_per_gpu, max_train_micro_batch_size)
Expand Down Expand Up @@ -964,8 +964,8 @@ def get_min_max_micro_batch_size(self, stage, min_micro_batch_size, calculated_m
low = mid + 1
self.update_records(tuning_space_name, exp, metric_val, 1)
used_micro_batch_sizes.append(mid)
if prev_metric_val and (
(metric_val - prev_metric_val) / prev_metric_val) < METRIC_PERCENT_DIFF_CONST:
if prev_metric_val and ((metric_val - prev_metric_val) /
prev_metric_val) < METRIC_PERCENT_DIFF_CONST:
logger.info(f"performance plateaus at mbs = {low}")
break
prev_metric_val = metric_val
Expand Down Expand Up @@ -1026,8 +1026,8 @@ def get_tuning_micro_batch_size_list(self, min_micro_batch_size, max_micro_batch
# NUM_GPUS=$(( ${NUM_WORKERS} * ${NUM_GPUS_PER_WORKER} ))
# DP_SIZE=$(( ${NUM_GPUS} / (${PP_SIZE} * ${MP_SIZE}) ))
# GRAD_ACC_STEPS=$(( ${TARGET_GLOBAL_BATCH_SIZE} / (${BATCH_SIZE} * ${DP_SIZE}) ))
if self.max_train_batch_size(
) and self.max_train_batch_size() > 0: # if the user specifies a max_train_batch_size
if self.max_train_batch_size() and self.max_train_batch_size(
) > 0: # if the user specifies a max_train_batch_size
max_train_batch_size_per_gpu = self.max_train_batch_size() * self.mp_size() // (self.exp_num_gpus *
self.exp_num_nodes)
else:
Expand Down
4 changes: 2 additions & 2 deletions deepspeed/elasticity/elastic_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,8 +160,8 @@ def _invoke_run(self, role: str = "default") -> RunResult:
f" Waiting {self._exit_barrier_timeout} seconds for other agents to finish.")
self._exit_barrier()
return run_result
elif state in {WorkerState.UNHEALTHY, WorkerState.FAILED
} or len(participants) > len(rdzv_handler._state_holder.state.participants):
elif state in {WorkerState.UNHEALTHY, WorkerState.FAILED} or len(participants) > len(
rdzv_handler._state_holder.state.participants):
if self._remaining_restarts > 0:
log.info(f"[{role}] Worker group {state.name}. "
f"{self._remaining_restarts}/{spec.max_restarts} attempts left;"
Expand Down
7 changes: 4 additions & 3 deletions deepspeed/module_inject/replace_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,9 +496,10 @@ def conv2d_parallel_shard_weights(model, rank, world_size):
if not dist.is_initialized() or dist.get_rank() == 0:
print("Saving tp-sharded checkpoints")
torch.save(
OrderedDict({k: v
for k, v in dict(replaced_module.state_dict()).items()
if transformer_name not in k}), f'{config.save_mp_checkpoint_path}/{non_tp_ckpt_name}')
OrderedDict({
k: v
for k, v in dict(replaced_module.state_dict()).items() if transformer_name not in k
}), f'{config.save_mp_checkpoint_path}/{non_tp_ckpt_name}')

dtype_reprs = {
torch.float32: 'float32',
Expand Down
4 changes: 2 additions & 2 deletions deepspeed/runtime/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1012,8 +1012,8 @@ def _do_error_check(self):
self.gradient_accumulation_steps), "DeepSpeedConfig: {} is not defined".format(GRADIENT_ACCUMULATION_STEPS)

if self.zero_enabled:
assert (self.zero_optimization_stage <=
ZeroStageEnum.max_stage), "DeepSpeedConfig: Maximum supported ZeRO stage is {}".format(
assert (self.zero_optimization_stage
<= ZeroStageEnum.max_stage), "DeepSpeedConfig: Maximum supported ZeRO stage is {}".format(
ZeroStageEnum.max_stage)

if self.fp16_master_weights_and_gradients:
Expand Down
4 changes: 2 additions & 2 deletions deepspeed/runtime/eigenvalue.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,8 @@ def compute_eigenvalue(self, module, device=None, scale=1.0):
eigenvalue_current, eigenvalue_previous = 1., 0.

while (i < self.max_iter) and abs(eigenvalue_current) > 0 and (abs(
(eigenvalue_current - eigenvalue_previous) / eigenvalue_current) >=
self.tol): # test convergence criteria
(eigenvalue_current - eigenvalue_previous) / eigenvalue_current)
>= self.tol): # test convergence criteria
eigenvalue_previous = eigenvalue_current

Hv = torch.autograd.grad(grads, params, grad_outputs=v, only_inputs=True, retain_graph=True)
Expand Down
7 changes: 4 additions & 3 deletions deepspeed/runtime/pipe/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -640,9 +640,10 @@ def _aggregate_total_loss(self):
self.dp_group_loss = losses[0].clone().detach()
agg_loss = losses[1].clone().detach()
if additional_losses is not None:
self.agg_additional_losses = OrderedDict(
{name: losses[2 + i].clone().detach()
for i, name in enumerate(additional_losses.keys())})
self.agg_additional_losses = OrderedDict({
name: losses[2 + i].clone().detach()
for i, name in enumerate(additional_losses.keys())
})
return agg_loss

def set_dataloader(self, loader):
Expand Down
4 changes: 2 additions & 2 deletions deepspeed/runtime/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,8 +257,8 @@ def has_overflow(self, params, has_moe_params=None):
elif self.mpu is not None:
if self.deepspeed is not None:
using_pipeline = hasattr(self.deepspeed, 'pipeline_enable_backward_allreduce')
if (using_pipeline and self.deepspeed.pipeline_enable_backward_allreduce is False) or (
not using_pipeline and self.deepspeed.enable_backward_allreduce is False):
if (using_pipeline and self.deepspeed.pipeline_enable_backward_allreduce
is False) or (not using_pipeline and self.deepspeed.enable_backward_allreduce is False):
dist.all_reduce(overflow_gpu, op=dist.ReduceOp.MAX, group=self.mpu.get_data_parallel_group())
dist.all_reduce(overflow_gpu, op=dist.ReduceOp.MAX, group=self.mpu.get_model_parallel_group())
elif self.deepspeed is not None and self.deepspeed.enable_backward_allreduce is False:
Expand Down
57 changes: 26 additions & 31 deletions deepspeed/runtime/zero/parameter_offload.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,20 +133,28 @@ def __init__(
self.persistent_parameters = self.mark_persistent_parameters(self.param_numel_persistence_threshold,
self.model_persistence_threshold)

self.param_coordinators = {}
self._prefetch_bucket_sz = int(prefetch_bucket_size)
self._max_reuse_distance_in_numel = int(max_reuse_distance)
self._max_available_parameters_in_numel = int(max_live_parameters)
self.__allgather_stream = None if get_accelerator().is_synchronized_device() else get_accelerator().Stream(
) if overlap_comm else get_accelerator().default_stream()

if not hasattr(module, "ds_inflight_param_registry"):
module.ds_inflight_param_registry = dict()
# we need two registries, one for training and one for eval. They will be used when creating PartitionedParameterCoordinator
module.ds_inflight_param_registry[True] = InflightParamRegistry()
module.ds_inflight_param_registry[False] = InflightParamRegistry()
module.ds_inflight_param_registry = InflightParamRegistry()
self.__inflight_param_registry = module.ds_inflight_param_registry

self.param_coordinator = PartitionedParameterCoordinator(
prefetch_bucket_sz=self._prefetch_bucket_sz,
max_reuse_distance_in_numel=self._max_reuse_distance_in_numel,
max_available_parameters_in_numel=self._max_available_parameters_in_numel,
allgather_stream=self.__allgather_stream,
inflight_param_registry=self.__inflight_param_registry,
prefetch_nvme=self.offload_device == OffloadDeviceEnum.nvme,
timers=self.timers,
zero_quantized_weights=self.zero_quantized_weights,
zero_quantized_nontrainable_weights=self.zero_quantized_nontrainable_weights,
)

self.forward_hooks = []
self.backward_hooks = []
self.setup_zero_stage3_hooks()
Expand All @@ -161,26 +169,13 @@ def partition_all_parameters(self):
"""Partitioning Parameters that were not partitioned usually if parameters
of modules whose input parameters do not require grad computation do not
trigger post call and will therefore will remain unpartitioned"""
self.get_param_coordinator(training=self.module.training).release_and_reset_all(self.module)
self.get_param_coordinator().release_and_reset_all(self.module)
for param in iter_params(self.module, recurse=True):
if param.ds_status != ZeroParamStatus.NOT_AVAILABLE:
raise RuntimeError(f"{param.ds_summary()} expected to be released")

def get_param_coordinator(self, training):
if not training in self.param_coordinators:
self.param_coordinators[training] = PartitionedParameterCoordinator(
prefetch_bucket_sz=self._prefetch_bucket_sz,
max_reuse_distance_in_numel=self._max_reuse_distance_in_numel,
max_available_parameters_in_numel=self._max_available_parameters_in_numel,
allgather_stream=self.__allgather_stream,
inflight_param_registry=self.__inflight_param_registry[training],
prefetch_nvme=self.offload_device == OffloadDeviceEnum.nvme,
timers=self.timers,
zero_quantized_weights=self.zero_quantized_weights,
zero_quantized_nontrainable_weights=self.zero_quantized_nontrainable_weights,
)

return self.param_coordinators[training]
def get_param_coordinator(self):
return self.param_coordinator

def empty_partition_cache(self):
self.partition_all_parameters()
Expand Down Expand Up @@ -228,14 +223,14 @@ def setup_zero_stage3_hooks(self):

#reset step if in inference mode
@instrument_w_nvtx
def _end_of_forward_hook(module, *args):
def _start_of_forward_hook(module, *args):

self.get_param_coordinator().reset_step()

if not torch._C.is_grad_enabled():
self.get_param_coordinator(training=False).reset_step()
self.module.register_forward_pre_hook(_start_of_forward_hook)

#likely one of them should be enough but just to be safe
self._register_hooks_recursively(self.module)
self.module.register_forward_hook(_end_of_forward_hook)

# Add top module to stack trace
global FWD_MODULE_STACK
Expand Down Expand Up @@ -447,7 +442,7 @@ def pre_sub_module_forward_function(self, sub_module):
global FWD_MODULE_STACK
FWD_MODULE_STACK.append(sub_module)

param_coordinator = self.get_param_coordinator(training=sub_module.training)
param_coordinator = self.get_param_coordinator()
param_coordinator.trace_prologue(sub_module)
if param_coordinator.is_record_trace():
param_coordinator.record_module(sub_module)
Expand All @@ -460,29 +455,29 @@ def post_sub_module_forward_function(self, sub_module):
see_memory_usage(f"After sub module function {sub_module.__class__.__name__} {sub_module.id} before release",
force=False)

param_coordinator = self.get_param_coordinator(training=sub_module.training)
param_coordinator = self.get_param_coordinator()
param_coordinator.release_sub_module(sub_module)

see_memory_usage(f"After sub module function {sub_module.__class__.__name__} {sub_module.id} after release",
force=False)

@torch.no_grad()
def pre_sub_module_backward_function(self, sub_module):
assert sub_module.training, "backward pass is invalid for module in evaluation mode"
param_coordinator = self.get_param_coordinator(training=True)
# assert sub_module.training, "backward pass is invalid for module in evaluation mode"
param_coordinator = self.get_param_coordinator()
param_coordinator.trace_prologue(sub_module)
if param_coordinator.is_record_trace():
param_coordinator.record_module(sub_module)
param_coordinator.fetch_sub_module(sub_module, forward=False)

@torch.no_grad()
def post_sub_module_backward_function(self, sub_module):
assert sub_module.training, "backward pass is invalid for module in evaluation mode"
# assert sub_module.training, "backward pass is invalid for module in evaluation mode"
see_memory_usage(
f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} before release",
force=False)

self.get_param_coordinator(training=True).release_sub_module(sub_module)
self.get_param_coordinator().release_sub_module(sub_module)

see_memory_usage(
f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} after release",
Expand Down
Loading

0 comments on commit 825b3d6

Please sign in to comment.