Chapter 3: Training at Scale

A number of techniques are helpful for training large-scale models efficiently.

This chapter may be helpful if you're going to be working with large-scale models, but it's also reasonable to skip it and come back if and when it becomes relevant.

Recommended reading

Optional reading

  • Megatron-LM 2 - A detailed analysis of different model parallelism techniques.
  • ZeRO - A way to make data parallelism more memory-efficient (at the cost of additional communication), by partitioning optimizer states, gradients and parameters across data-parallel processes.
  • Mixed Precision Training - Using reduced floating-point precision is important for training efficiency, but comes with a risk of errors and instabilities. This paper describes a typical mixed-precision setup, including a technique called loss scaling that helps minimize underflow (see the example after this list).
  • MPI tutorial (can skip the "Introduction and MPI installation" section) - MPI is an interface for passing messages that is often used to implement different kinds of parallelism.
  • Triton tutorial - Most of the numeric operations used to train models are performed on GPUs, which can perform parallel computation efficiently. Triton provides a Python-based programming environment for writing code that can be compiled to run on a GPU.
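To make the loss-scaling idea concrete, here is a minimal sketch of a mixed-precision training step using PyTorch's torch.cuda.amp module (which postdates the paper); the model, data, and hyperparameters are placeholders, and a CUDA GPU is assumed. GradScaler implements dynamic loss scaling, a variant of the technique described in the paper.

```python
import torch
import torch.nn as nn

device = "cuda"  # mixed precision as shown here requires a CUDA GPU
model = nn.Sequential(nn.Linear(784, 256), nn.ReLU(), nn.Linear(256, 10)).to(device)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.CrossEntropyLoss()
scaler = torch.cuda.amp.GradScaler()  # maintains the loss scale and adjusts it dynamically

# Placeholder batch standing in for real data.
x = torch.randn(64, 784, device=device)
y = torch.randint(0, 10, (64,), device=device)

for step in range(100):
    opt.zero_grad()
    with torch.cuda.amp.autocast():  # forward pass runs in reduced precision where safe
        loss = loss_fn(model(x), y)
    scaler.scale(loss).backward()    # scale the loss so small gradients don't underflow in fp16
    scaler.step(opt)                 # unscales gradients and skips the step if any overflowed
    scaler.update()                  # grows or shrinks the loss scale based on overflow history
```

Note that with autocast the parameters themselves stay in fp32, which plays the role of the fp32 master weights described in the paper.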

Suggested exercise

Make a parallelized version of your MNIST script or transformer implementation:

  • Install mpi4py and play around with it to check you understand how it works.
  • Implement data parallelism (a rough sketch follows this list). Some things to remember:
    • The dataset needs to be sharded
    • The random initialization needs to be broadcast
    • Gradients need to be allreduced
  • Run your code on 2 GPUs in parallel and check that you get the same learning curve as with 1 GPU (with the same global batch size).
  • (Optional) Implement pipeline parallelism and/or op sharding - these are harder.
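Here is a rough sketch of what the data-parallel version could look like with mpi4py, using a synthetic dataset as a stand-in for MNIST and running on CPU for clarity (on GPUs you would either stage tensors through CPU buffers or use a CUDA-aware MPI build). The model, hyperparameters, and the file name train_dp.py are illustrative.

```python
from mpi4py import MPI
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Subset, TensorDataset

comm = MPI.COMM_WORLD
rank, size = comm.Get_rank(), comm.Get_size()

# Synthetic stand-in for MNIST (identical on every rank because the seed is fixed).
torch.manual_seed(0)
full_dataset = TensorDataset(torch.randn(4096, 784), torch.randint(0, 10, (4096,)))

# 1. Shard the dataset: each rank trains on a disjoint slice.
shard = Subset(full_dataset, list(range(rank, len(full_dataset), size)))
global_batch_size = 64
loader = DataLoader(shard, batch_size=global_batch_size // size, shuffle=True)

# Deliberately different seeds per rank, so the broadcast below is doing real work.
torch.manual_seed(rank)
model = nn.Sequential(nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 10))

# 2. Broadcast rank 0's random initialization so all workers start from the same weights.
for p in model.parameters():
    comm.Bcast(p.detach().numpy(), root=0)  # numpy() is a view, so Bcast writes in place

opt = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.CrossEntropyLoss()

for step, (x, y) in enumerate(loader):
    opt.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()
    # 3. Allreduce: sum gradients across workers, then average them.
    for p in model.parameters():
        comm.Allreduce(MPI.IN_PLACE, p.grad.numpy(), op=MPI.SUM)
        p.grad /= size
    opt.step()
    if rank == 0 and step % 10 == 0:
        print(f"step {step}: local loss {loss.item():.4f}")
```

Launch it with one process per worker, e.g. mpirun -np 2 python train_dp.py. To compare learning curves with the single-worker run exactly, also average the reported loss across ranks (another Allreduce), since each rank only sees its own shard of the batch.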

OR for a fun but less useful exercise:

Write an implementation of Adam using Triton, which is "fused" in the sense that the entire update is performed by a single call to low-level GPU code (a rough sketch follows the steps below).

  • First write the code for the update rule so that it can be compiled using the @triton.jit decorator.
  • Then write the PyTorch function that calls the compiled update rule.
  • Benchmark your implementation and compare it to torch.optim.Adam.
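A rough sketch of what this could look like, assuming a recent Triton release and CUDA tensors; the names adam_kernel and fused_adam_step are illustrative, weight decay and parameter groups are omitted, and the tensors are assumed to be flat and contiguous.

```python
import torch
import triton
import triton.language as tl

@triton.jit
def adam_kernel(p_ptr, g_ptr, m_ptr, v_ptr,      # parameters, gradients, first/second moments
                lr, beta1, beta2, eps, bc1, bc2,  # hyperparameters and bias corrections
                n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offs < n_elements

    p = tl.load(p_ptr + offs, mask=mask)
    g = tl.load(g_ptr + offs, mask=mask)
    m = tl.load(m_ptr + offs, mask=mask)
    v = tl.load(v_ptr + offs, mask=mask)

    # Standard Adam update, performed entirely inside one kernel.
    m = beta1 * m + (1.0 - beta1) * g
    v = beta2 * v + (1.0 - beta2) * g * g
    p = p - lr * (m / bc1) / (tl.sqrt(v / bc2) + eps)

    tl.store(m_ptr + offs, m, mask=mask)
    tl.store(v_ptr + offs, v, mask=mask)
    tl.store(p_ptr + offs, p, mask=mask)

def fused_adam_step(param, grad, exp_avg, exp_avg_sq, step,
                    lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, block_size=1024):
    """One Adam step on flat, contiguous CUDA tensors, as a single kernel launch."""
    n = param.numel()
    bc1 = 1.0 - beta1 ** step  # bias corrections computed on the host
    bc2 = 1.0 - beta2 ** step
    grid = (triton.cdiv(n, block_size),)
    adam_kernel[grid](param, grad, exp_avg, exp_avg_sq,
                      lr, beta1, beta2, eps, bc1, bc2,
                      n, BLOCK_SIZE=block_size)
```

For the benchmark, one option is triton.testing.do_bench, timing repeated calls to fused_adam_step against a torch.optim.Adam step on a parameter tensor of the same size; on a single large tensor, most of the potential gain comes from doing the whole update in one pass over memory instead of several separate elementwise kernels.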