
Commit

Merge branch 'master' into gma/uneven_heads
molly-smith authored Sep 1, 2023
2 parents 6c3c841 + 844eb68 commit 58e8b24
Showing 20 changed files with 430 additions and 12 deletions.
1 change: 1 addition & 0 deletions .github/workflows/amd-mi200.yml
@@ -3,6 +3,7 @@ name: amd-mi200
on:
  schedule:
    - cron: "0 0 * * *"
  workflow_dispatch:

concurrency:
  group: ${{ github.workflow }}-${{ github.ref }}
1 change: 1 addition & 0 deletions .github/workflows/nv-h100.yml
@@ -3,6 +3,7 @@ name: nv-h100
on:
  schedule:
    - cron: "0 0 * * *"
  workflow_dispatch:

concurrency:
  group: ${{ github.workflow }}-${{ github.ref }}
2 changes: 1 addition & 1 deletion README.md
@@ -15,11 +15,11 @@
## Latest News
<b> <span style="color:orange" > DeepSpeed empowers ChatGPT-like model training with a single click, offering 15x speedup over SOTA RLHF systems with unprecedented cost reduction at all scales; [learn how](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat)</span>.</b>

* [2023/08] [DeepSpeed-Chat: Llama/Llama-2 system support, efficiency boost, and training stability improvements](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat/ds-chat-release-8-31/README.md)
* [2023/08] [DeepSpeed Ulysses: System Optimizations for Enabling Training of Extreme Long Sequence Transformer Models](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-ulysses)
* [2023/06] [ZeRO++: A leap in speed for LLM and chat model training with 4X less communication](https://www.microsoft.com/en-us/research/blog/deepspeed-zero-a-leap-in-speed-for-llm-and-chat-model-training-with-4x-less-communication/)[[English](https://www.microsoft.com/en-us/research/blog/deepspeed-zero-a-leap-in-speed-for-llm-and-chat-model-training-with-4x-less-communication/)] [[中文](https://github.com/microsoft/DeepSpeed/blob/master/blogs/zeropp/chinese/README.md)] [[日本語](https://github.com/microsoft/DeepSpeed/blob/master/blogs/zeropp/japanese/README.md)]
* [2023/04] 🚀 [DeepSpeed Chat: Easy, Fast and Affordable RLHF Training of ChatGPT-like Models at All Scales](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat) [[English](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat/README.md)] [[中文](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat/chinese/README.md)] [[日本語](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat/japanese/README.md)]🚀
* [2023/03] [Scaling Large-Scale Generative Mixture-of-Expert Multimodal Model With VL-MoE](https://www.deepspeed.ai/2023/03/30/multi-modal.html)
* [2023/02] [Automatic Tensor Parallelism: Enables tensor parallelism by default without an injection policy](https://www.deepspeed.ai/tutorials/automatic-tensor-parallelism/)

---

352 changes: 352 additions & 0 deletions blogs/deepspeed-chat/ds-chat-release-8-31/README.md

Large diffs are not rendered by default.

15 changes: 11 additions & 4 deletions deepspeed/module_inject/containers/megatron_gpt.py
@@ -51,14 +51,21 @@ def __init__(self, client_module, inference=True):
                try:
                    from megatron.model.transformer import ParallelTransformerLayer
                    MegatronLayerPolicy._orig_layer_class = ParallelTransformerLayer
                    MegatronLayerPolicy.version = 1
                except ImportError:
                    MegatronLayerPolicy._orig_layer_class = None

    def get_hidden_heads(self):
        return self.client_module.attention.query_key_value.weight.shape[1], \
               self.client_module.attention.num_attention_heads, \
               self.client_module.input_layernorm.eps, \
               DEFAULT_INTERMEDIATE_SIZE
        if MegatronLayerPolicy.version == 0:
            return self.client_module.attention.query_key_value.weight.shape[1], \
                   self.client_module.attention.num_attention_heads, \
                   self.client_module.input_layernorm.eps, \
                   DEFAULT_INTERMEDIATE_SIZE
        else:
            return self.client_module.self_attention.query_key_value.weight.shape[1], \
                   self.client_module.self_attention.num_attention_heads, \
                   self.client_module.input_layernorm.eps, \
                   DEFAULT_INTERMEDIATE_SIZE

    def attention(self, enable_training=False):
        if self.inference:
9 changes: 5 additions & 4 deletions deepspeed/runtime/zero/parameter_offload.py
@@ -490,19 +490,20 @@ def _run_after_backward_function(sub_module):
        # post backward hook
        self.backward_hooks.append(module.register_forward_pre_hook(_post_backward_module_hook))

    @torch.no_grad()
    def pre_sub_module_forward_function(self, sub_module):
        see_memory_usage(f"Before sub module function {sub_module.__class__.__name__}", force=False)

        prev_grad_state = torch.is_grad_enabled(
        )  # we don't want to enable grad for sub modules fetching, yet the subfunction need to know if grad is enabled
        torch.set_grad_enabled(False)
        global FWD_MODULE_STACK
        FWD_MODULE_STACK.append(sub_module)

        param_coordinator = self.get_param_coordinator(training=sub_module.training)
        param_coordinator.trace_prologue(sub_module)
        if param_coordinator.is_record_trace():
            param_coordinator.record_module(sub_module)
        param_coordinator.fetch_sub_module(sub_module, forward=True)

        param_coordinator.fetch_sub_module(sub_module, forward=prev_grad_state)
        torch.set_grad_enabled(prev_grad_state)
        see_memory_usage(f"Before sub module function {sub_module.__class__.__name__} after fetch", force=False)

    @torch.no_grad()
55 changes: 55 additions & 0 deletions docs/_tutorials/mixed_precision_zeropp.md
@@ -0,0 +1,55 @@
---
title: "Mixed Precision ZeRO++"
tags: training ZeRO communication-efficiency large-model
---

Mixed Precision ZeRO++ (MixZ++) is a set of optimization strategies based on [ZeRO](/tutorials/zero/) and [ZeRO++](/tutorials/zeropp/) that improves efficiency and reduces memory usage for large model training and inference when using [Low-Rank Adaptation (LoRA)](https://arxiv.org/abs/2106.09685) training. MixZ++ partitions model parameters across GPUs to reduce footprint and gathers them with quantized communication only when needed, similar to its ZeRO and ZeRO++ siblings. Our evaluation indicates MixZ++ increases training throughput by up to [3.3x](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat/ds-chat-release-8-31) for the Llama-2-70B model running on 128 V100 GPUs. Read our [DeepSpeed Chat Blog](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat/ds-chat-release-8-31), [ZeRO++ blog](https://www.microsoft.com/en-us/research/blog/deepspeed-zero-a-leap-in-speed-for-llm-and-chat-model-training-with-4x-less-communication/) and [paper](https://arxiv.org/pdf/2306.10209.pdf) to learn more!

We recommend that you read the tutorials on [Getting Started](/getting-started/), [ZeRO](/tutorials/zero/) and [Megatron-DeepSpeed](/tutorials/megatron/) before stepping through this tutorial.

## Key Designs
Mixed Precision ZeRO++ (MixZ++) inherits key designs from [ZeRO++](/tutorials/zeropp/), namely quantized weights (*qwZ*) and hierarchical partitioning ZeRO (*hpZ*), but with different applicability:
- *qwZ* applies block-based quantization to frozen weights to reduce memory usage and all-gather communication volume. Compared with ZeRO++, *qwZ* in Mixed Precision ZeRO++ keeps the frozen weights quantized, so there is no quantization overhead at runtime and memory usage is reduced (see the sketch after this list).
- *hpZ* eliminates inter-node parameter all-gather communication through data remapping and recomputation. Compared with ZeRO++, *hpZ* in Mixed Precision ZeRO++ applies to both the backward and generation passes.

Collectively, these optimizations bring better scalability and efficiency to LoRA training. Each component can be enabled independently or together as a group.
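
To make the block-based quantization idea behind *qwZ* concrete, here is a minimal, self-contained PyTorch sketch of symmetric (absmax) block-wise int8 quantization. It only illustrates the concept, not DeepSpeed's implementation (which uses fused CUDA kernels); the helper names and the block size are illustrative assumptions.

```python
import torch

def blockwise_quantize(weight: torch.Tensor, block_size: int = 2048):
    """Quantize a weight tensor to int8 in fixed-size blocks, keeping one
    fp16 scale per block (symmetric absmax quantization)."""
    flat = weight.detach().flatten().float()
    pad = (-flat.numel()) % block_size  # pad so the tensor splits evenly into blocks
    if pad:
        flat = torch.nn.functional.pad(flat, (0, pad))
    blocks = flat.view(-1, block_size)
    scales = blocks.abs().amax(dim=1, keepdim=True).clamp_min(1e-8) / 127.0
    q = torch.clamp((blocks / scales).round(), -127, 127).to(torch.int8)
    return q, scales.half()

def blockwise_dequantize(q: torch.Tensor, scales: torch.Tensor, shape: torch.Size):
    """Recover an fp16 approximation of the original tensor."""
    flat = (q.float() * scales.float()).flatten()
    return flat[:shape.numel()].view(shape).half()

# Example: quantize a frozen base weight and check the reconstruction error.
w = torch.randn(4096, 4096, dtype=torch.float16)
q, s = blockwise_quantize(w)
w_hat = blockwise_dequantize(q, s, w.shape)
print((w - w_hat).abs().max())
```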

## Enabling Mixed Precision ZeRO++ (MixZ++)

A ready-to-go MixZ++ example is available at [MixZ++ example script](https://github.com/microsoft/DeepSpeedExamples/blob/master/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/training_scripts/llama2/run_llama2_7b_mixz.sh). If you prefer to enable MixZ++ manually in your own pipeline, please refer to the instructions below.

### DeepSpeed Configuration Changes
An example snippet of a DeepSpeed configuration with all MixZ++ optimizations enabled is shown below:
```json
{
    "zero_optimization": {
        "stage": 3,
        "..."
        "zero_quantized_nontrainable_weights": true,
        "zero_hpz_partition_size": 16,
        "..."
    }
}
```
Note that for multi-node training, the `"zero_hpz_partition_size"` should be set to the number of GPUs per node. For example, if you have 8 GPUs per node, then `"zero_hpz_partition_size"` should be set to 8. For single-node training, the `"zero_hpz_partition_size"` should not be set.
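
If the per-node GPU count is not known ahead of time, the value can be computed at runtime when assembling the DeepSpeed config. Below is a minimal sketch; the `build_zero_config` helper and the `multi_node` flag are illustrative, not part of DeepSpeed.

```python
import torch

def build_zero_config(multi_node: bool) -> dict:
    """Illustrative helper: assemble the zero_optimization section, setting
    zero_hpz_partition_size to the local GPU count only for multi-node runs."""
    zero = {
        "stage": 3,
        "zero_quantized_nontrainable_weights": True,
    }
    if multi_node:
        # hpZ's secondary partition spans one node, i.e. the GPUs on this machine.
        zero["zero_hpz_partition_size"] = torch.cuda.device_count()
    return {"zero_optimization": zero}

ds_config = build_zero_config(multi_node=True)
```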

### Training Script Changes
The DeepSpeed engine identifies the LoRA frozen parameters if the LoRA model is passed at DeepSpeed initialization. However, a popular approach is to initialize a base model and convert it to a LoRA model later. In such cases, users need to explicitly notify the DeepSpeed engine after the LoRA conversion. This is only a one-line effort. An example snippet of the training script is shown below:

```python
model, optimizer, _, lr_scheduler = deepspeed.initialize(
    model=model,
    optimizer=optimizer,
    args=args,
    config=ds_config,
    lr_scheduler=lr_scheduler,
    dist_init_required=True)
# ...
# (the custom code to convert base model to LoRA model)
# ...
# call DeepSpeed engine again to identify LoRA frozen parameters
model.optimizer.quantize_nontrainable_params()
# ...
```

Congratulations! You have completed the Mixed Precision ZeRO++ tutorial.
3 changes: 1 addition & 2 deletions docs/index.md
@@ -7,12 +7,11 @@ title: "Latest News"
---
<b> <span style="color:orange" > DeepSpeed empowers ChatGPT-like model training with a single click, offering 15x speedup over SOTA RLHF systems with unprecedented cost reduction at all scales; [learn how](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat)</span>.</b>

* [2023/08] [DeepSpeed-Chat: Llama/Llama-2 system support, efficiency boost, and training stability improvements](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat/ds-chat-release-8-31/README.md)
* [2023/08] [DeepSpeed Ulysses: System Optimizations for Enabling Training of Extreme Long Sequence Transformer Models](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-ulysses)
* [2023/06] [ZeRO++: A leap in speed for LLM and chat model training with 4X less communication](https://www.microsoft.com/en-us/research/blog/deepspeed-zero-a-leap-in-speed-for-llm-and-chat-model-training-with-4x-less-communication/)[[English](https://www.microsoft.com/en-us/research/blog/deepspeed-zero-a-leap-in-speed-for-llm-and-chat-model-training-with-4x-less-communication/)] [[中文](https://github.com/microsoft/DeepSpeed/blob/master/blogs/zeropp/chinese/README.md)] [[日本語](https://github.com/microsoft/DeepSpeed/blob/master/blogs/zeropp/japanese/README.md)]
* [2023/04] 🚀 [DeepSpeed Chat: Easy, Fast and Affordable RLHF Training of ChatGPT-like Models at All Scales](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat) [[English](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat/README.md)] [[中文](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat/chinese/README.md)] [[日本語](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat/japanese/README.md)]🚀
* [2023/03] [Scaling Large-Scale Generative Mixture-of-Expert Multimodal Model With VL-MoE](https://www.deepspeed.ai/2023/03/30/multi-modal.html)
* [2023/02] [Automatic Tensor Parallelism: Enables tensor parallelism by default without an injection policy](https://www.deepspeed.ai/tutorials/automatic-tensor-parallelism/)


# Extreme Speed and Scale for DL Training and Inference

2 changes: 2 additions & 0 deletions op_builder/builder.py
@@ -702,6 +702,8 @@ def cxx_args(self):
        if not self.build_for_cpu:
            if not self.is_rocm_pytorch():
                CUDA_LIB64 = os.path.join(torch.utils.cpp_extension.CUDA_HOME, "lib64")
                if not os.path.exists(CUDA_LIB64):
                    CUDA_LIB64 = os.path.join(torch.utils.cpp_extension.CUDA_HOME, "lib")
            else:
                CUDA_LIB64 = os.path.join(torch.utils.cpp_extension.ROCM_HOME, "lib")

2 changes: 1 addition & 1 deletion version.txt
@@ -1 +1 @@
0.10.2
0.10.3
