
[shardformer, pipeline]: add gradient_checkpointing_ratio and heterogeneous shard policy #5509

@cwher

Description


Gradient checkpointing is known to be memory efficient, yet it can slow TGS (tokens per second per GPU) down to roughly 75% of the baseline. (Estimated assuming the backward pass takes 2x the forward pass: a layer normally costs F + 2F = 3F, and recomputing the forward during backward adds another F, giving 4F; 3F / 4F = 75%.) Thus, it is a trade-off between memory limitation and throughput.

However, users sometimes want to control this trade-off more precisely. For example, suppose gradient_checkpointing=False results in 40GB of memory consumption while gradient_checkpointing=True yields 5GB. A user with 30GB of memory headroom may want to enable gradient checkpointing only partially, accelerating training while still avoiding OOM.

This motivates the design of gradient_checkpointing_ratio, which lets users control gradient checkpointing at a finer granularity than on/off. (FEATURE 1)
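One possible interpretation of the ratio is sketched below; `select_ckpt_layers` is a hypothetical helper (not part of the proposed API) that checkpoints the first `ceil(ratio * num_layers)` layers, so ratio=0.0 reduces to checkpointing off and ratio=1.0 to checkpointing fully on:

```python
import math


def select_ckpt_layers(num_layers: int, ratio: float) -> list[int]:
    """Hypothetical policy: return indices of layers to run with
    gradient checkpointing, taking the first ceil(ratio * num_layers)."""
    num_ckpt = math.ceil(ratio * num_layers)
    return list(range(num_ckpt))


print(select_ckpt_layers(8, 0.0))    # no layers checkpointed
print(select_ckpt_layers(8, 0.625))  # 5 of 8 layers checkpointed
print(select_ckpt_layers(8, 1.0))    # all 8 layers checkpointed
```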

Furthermore, there is more potential when gradient_checkpointing_ratio is combined with Pipeline Parallelism (PP).

[Figure: 1F1B pipeline schedule across 4 devices, from http://arxiv.org/abs/2104.04473]

As illustrated in the above figure (copied from http://arxiv.org/abs/2104.04473), different PP stages store activations for different numbers of micro-batches; e.g., in 1F1B, device 1 holds 4 micro-batches in flight while device 4 holds only 1. This naturally leads to highly unbalanced memory consumption across devices.
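This counting rule can be written down directly (a small sketch, assuming the standard 1F1B steady state in which stage i starts i fewer warm-up micro-batches than stage 0):

```python
def in_flight_microbatches(num_stages: int) -> list[int]:
    """In the 1F1B steady state, stage i (0-indexed) holds activations
    for (num_stages - i) in-flight micro-batches."""
    return [num_stages - i for i in range(num_stages)]


print(in_flight_microbatches(4))  # [4, 3, 2, 1]
```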

Under a common partition strategy, a 32-layer model partitioned across 4 devices gives each device 8 layers. Assume the activation memory of 1 micro-batch passing through 8 layers is 10GB, reduced to 1GB with gradient checkpointing. On a 20GB accelerator, gradient checkpointing must be enabled, since device 1 stores 4 micro-batches and thus requires 4 * 10GB, exceeding the 20GB limit. However, device 4 consumes only 10GB even with gradient_checkpointing=False. The detailed memory consumption is shown in the following tables.

  • 100% execution time, yet OOM.

    | Device | # Layers | # Ckpt Layers | Memory (GB) |
    |--------|----------|---------------|-------------|
    | 1      | 8        | 0             | 40          |
    | 2      | 8        | 0             | 30          |
    | 3      | 8        | 0             | 20          |
    | 4      | 8        | 0             | 10          |

  • 121% execution time, no OOM.

    | Device | # Layers | # Ckpt Layers | Memory (GB) |
    |--------|----------|---------------|-------------|
    | 1      | 8        | 5             | 17.5        |
    | 2      | 8        | 5             | 13.1        |
    | 3      | 8        | 5             | 8.8         |
    | 4      | 8        | 5             | 4.4         |
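The per-device numbers above can be reproduced from a linear memory model (a sketch using the same per-device costs as the MIP script in this issue: device i's 8 layers share a full-activation budget of activation_mem[i] GB and a checkpointed budget of ckpt_mem[i] GB):

```python
import numpy as np

num_devices = 4
std_layers = 8  # layers per device
# Per-device activation memory (GB) for 8 layers, without / with checkpointing.
activation_mem = np.linspace(40, 0, num_devices + 1)[:-1]  # [40, 30, 20, 10]
ckpt_mem = np.linspace(4, 0, num_devices + 1)[:-1]         # [4, 3, 2, 1]


def device_memory(i: int, num_layers: int, num_ckpt: int) -> float:
    """Non-checkpointed layers cost the full per-layer activation share,
    checkpointed layers cost the reduced share."""
    return (activation_mem[i] / std_layers * (num_layers - num_ckpt)
            + ckpt_mem[i] / std_layers * num_ckpt)


for i in range(num_devices):
    print(f"Device {i + 1}: {device_memory(i, 8, 5):.1f} GB")
```

With 5 of 8 layers checkpointed per device, this yields 17.5, 13.1, 8.8, and 4.4 GB, matching the second table.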

The key insight is to assign a different gradient_checkpointing_ratio to each PP device. (FEATURE 2) Devices with high memory consumption receive a higher gradient_checkpointing_ratio. Since gradient checkpointing incurs recomputation overhead, those devices should also be assigned fewer layers so that the execution time of each pipeline stage stays balanced; otherwise, the stage with the highest gradient_checkpointing_ratio becomes the bottleneck. The following table illustrates an example of this strategy.

  • 113% execution time, no OOM.

    | Device | # Layers | # Ckpt Layers | Memory (GB) |
    |--------|----------|---------------|-------------|
    | 1      | 7        | 4             | 17          |
    | 2      | 8        | 3             | 19.9        |
    | 3      | 8        | 0             | 20          |
    | 4      | 9        | 0             | 11.3        |
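The quoted slowdowns follow from the cost model used in the MIP script in this issue: forward costs 1 unit per layer, backward costs 2, and each checkpointed layer adds 1 recomputation unit, so a stage's time is 3 * num_layers + num_ckpt (a sketch of that arithmetic):

```python
def stage_time(num_layers: int, num_ckpt: int) -> int:
    # forward = 1 unit/layer, backward = 2 units/layer,
    # plus 1 recomputation unit per checkpointed layer
    return 3 * num_layers + num_ckpt


std_time = stage_time(8, 0)  # 24: balanced partition, no checkpointing

# Uniform checkpointing: every device runs 8 layers with 5 checkpointed.
uniform = stage_time(8, 5)  # 29
# Heterogeneous plan: the pipeline runs at the pace of its slowest stage.
hetero = max(stage_time(l, c) for l, c in [(7, 4), (8, 3), (8, 0), (9, 0)])  # 27

print(f"uniform slowdown: {uniform / std_time:.1%}")        # ~121%
print(f"heterogeneous slowdown: {hetero / std_time:.1%}")   # ~113%
```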

The solution above was found by modeling the problem as a mixed-integer program, using the python-mip package:

```python
import mip
import numpy as np


if __name__ == "__main__":
    num_devices = 4
    memory_bound = 20  # per-device memory limit (GB)
    weight = 0         # per-device weight memory (GB), ignored here
    grad = 0           # per-device gradient memory (GB), ignored here
    total_layers = 32

    # Per-device activation memory for std_layers layers, without / with checkpointing.
    activation_mem = np.linspace(40, 0, num_devices + 1).tolist()[:-1]  # [40, 30, 20, 10]
    ckpt_mem = np.linspace(4, 0, num_devices + 1).tolist()[:-1]         # [4, 3, 2, 1]
    std_layers = total_layers // num_devices

    model = mip.Model()
    num_layers = [model.add_var(var_type=mip.INTEGER) for _ in range(num_devices)]
    num_ckpt = [model.add_var(var_type=mip.INTEGER) for _ in range(num_devices)]
    # Upper bound on every stage's forward + backward time; alternatively,
    # forward and backward could be modeled as two separate variables.
    forward_backward_time = model.add_var(var_type=mip.CONTINUOUS)

    # Constraints
    model += mip.xsum(num_layers) == total_layers
    for i in range(num_devices):
        # A device cannot checkpoint more layers than it owns.
        model += num_ckpt[i] <= num_layers[i]
    for i in range(num_devices):
        # Stage time: forward = 1 unit/layer, backward = 2 units/layer,
        # plus 1 recomputation unit per checkpointed layer.
        model += forward_backward_time >= num_layers[i] * 3 + num_ckpt[i]
    for i in range(num_devices):
        # Linear memory model: full activation share for plain layers,
        # reduced share for checkpointed layers, plus weights and gradients.
        model += (
            activation_mem[i] / std_layers * (num_layers[i] - num_ckpt[i])
            + ckpt_mem[i] / std_layers * num_ckpt[i]
            + (weight + grad) / std_layers * num_layers[i]
            <= memory_bound
        )

    # Objective: minimize the bottleneck stage's time, i.e.
    # 1. Forward phase: max(num_layers)
    # 2. Backward phase: max(2 * num_layers + num_ckpt)
    model.objective = mip.minimize(forward_backward_time)

    model.optimize()

    print("Total time:", forward_backward_time.x)

    std_time = std_layers * 3  # balanced partition without checkpointing
    print("Slow down:", forward_backward_time.x / std_time)

    for i in range(num_devices):
        weight_and_grad = (weight + grad) / std_layers * num_layers[i].x
        activation = (
            activation_mem[i] / std_layers * (num_layers[i].x - num_ckpt[i].x)
            + ckpt_mem[i] / std_layers * num_ckpt[i].x
        )
        print(
            f"Device {i+1}: {num_layers[i].x:2.0f} layers, {num_ckpt[i].x:2.0f} checkpoints, "
            f"Weight + Grad: {weight_and_grad:2.2f} GB, "
            f"Activation: {activation:2.2f} GB, Total: {weight_and_grad + activation:2.2f} GB"
        )
```

In summary, the two features are proposed to provide users with more flexibility to control the trade-off between memory consumption and throughput. The first feature allows users to control the overall gradient_checkpointing_ratio, while the second feature allows users to assign different gradient_checkpointing_ratio to different PP devices.

Methods

#5508 is linked to this issue.
