This repository was archived by the owner on Aug 7, 2024. It is now read-only.

support float8 weight caching for gradient accumulation/PP #164

Closed · wants to merge 1 commit
22 changes: 22 additions & 0 deletions README.md
@@ -64,6 +64,28 @@ model.foo.bar.fc2.sequence_parallel = True
# the rest of the flow is the same as the single GPU flow
```

## weight caching (very experimental)

```python
import float8_experimental.config as config

m = Model(...)
# before converting to `Float8Linear`, turn on weight cache buffer allocation
config.allocate_float8_weight_cache_buffers = True

# in the training loop, manually control the global weight caching setting
for idx in range(N_ITER):
    ...
    if idx % n_microbatch == 0:
        # first microbatch of a new set: recompute the fp8 cast and
        # repopulate the cache
        config.weight_cache_enabled = False
    elif idx % n_microbatch == 1:
        # second microbatch: reuse the cached fp8 weight; this persists
        # until `idx % n_microbatch == 0` again
        config.weight_cache_enabled = True
    ...
```
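
The review discussion below (see the comment on `config.py`) sketches how these flags interact with gradient accumulation. Here is a minimal version of such a loop; it is a sketch, not part of this PR, and `m`, `x`, `optimizer`, `N_ITER`, and `n_microbatch` are assumed placeholders:

```python
import float8_experimental.config as config
from float8_experimental.float8_linear_utils import sync_float8_amax_and_scale_history

for idx in range(N_ITER):
    out = m(x)
    out.sum().backward()
    sync_float8_amax_and_scale_history(m)
    # True for every microbatch except the last one of the current step
    accumulate_grad = (idx % n_microbatch) != (n_microbatch - 1)
    if accumulate_grad:
        # weights are unchanged, so later microbatches can reuse the cached cast
        config.weight_cache_enabled = True
    else:
        # weights are about to change: step the optimizer and disable the cache
        # so the next forward recomputes the fp8 cast and repopulates the buffer
        optimizer.step()
        optimizer.zero_grad()
        config.weight_cache_enabled = False
```

The invariant is that `weight_cache_enabled` is False for the first forward after any optimizer step and True otherwise.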

# high level technical design

## UX
18 changes: 18 additions & 0 deletions float8_experimental/config.py
@@ -0,0 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

#
# Weight caching.
#

# If True, allocates buffers for float8 weight cache
allocate_float8_weight_cache_buffers = False
Review comment from @drisspg (Contributor), Dec 20, 2023:

I think some more comments are needed on how users are expected to use this in code, e.g.:

sync_float8_amax()
if accumulate_grad:
    weight_cache_enabled = True

if not accumulate_grad:
    optimizer.step()
    weight_cache_enabled = False

Reply from the author (Contributor): added to readme


# A global flag for controlling the weight cache, off by default. Intended
# usage is for users to modify this from their training loop directly
# according to their microbatching/pipeline parallel setup.
# Note: this is currently a global flag for simplicity and dynamo performance.
weight_cache_enabled = False
43 changes: 41 additions & 2 deletions float8_experimental/float8_linear.py
@@ -16,9 +16,14 @@

from typing import Optional

import float8_experimental.config as config

import torch

from float8_experimental.float8_tensor import Float8Tensor
from float8_experimental.float8_tensor import (
calculate_amax_and_cast_to_float8,
Float8Tensor,
)

from float8_experimental.float8_utils import (
amax_history_to_scale,
@@ -172,6 +177,15 @@ def __init__(self, *args, **kwargs):
# will access the scale when it has ensured that it is on GPU.
self._float8_tensor_ctor = lambda *args, **kwargs: Float8Tensor(*args, **kwargs)

if config.allocate_float8_weight_cache_buffers:
# this is a buffer to get `to(dtype)` for free
# TODO(future): hide this from serialization
# TODO(future): force this to stay in float8_e4m3fn
self.register_buffer(
Reply from the author to an inline review comment: yep, we can fix that in a future PR!

"cached_fp8_weight",
torch.empty(self.weight.shape, dtype=torch.float8_e4m3fn),
)

def register_always_float32_buffer(
self, name: str, tensor: Optional[torch.Tensor], persistent: bool = True
) -> None:
@@ -228,8 +242,33 @@ def cast_w_to_float8(
torch.float8_e4m3fn,
is_amax_initialized,
)

if config.weight_cache_enabled:
assert config.allocate_float8_weight_cache_buffers, (
"float8 weight cache buffer must be allocated using "
+ "`allocate_float8_weight_cache_buffers` to use the weight cache"
)
w_bits_fp8 = self.cached_fp8_weight
else:
# manual calculation of fp8 bits:
# 1. calculate the bits without Float8Tensor, without grad
# 2. store the bits here
# 3. create Float8Tensor from the bits calculated in 2
# motivation: this will take care of saving the bits without
# interacting with tensor subclasses, as w_fp8._data is not
# currently traceable by dynamo
w_bits_fp8 = calculate_amax_and_cast_to_float8(
self.weight, self.fp8_scale_w, torch.float8_e4m3fn, self.fp8_amax_w
)
if config.allocate_float8_weight_cache_buffers:
self.cached_fp8_weight.copy_(w_bits_fp8)
w_fp8 = Float8Tensor.to_float8(
-    w, self.fp8_scale_w, torch.float8_e4m3fn, self.fp8_amax_w, self.emulate
+    w,
+    self.fp8_scale_w,
+    torch.float8_e4m3fn,
+    self.fp8_amax_w,
+    self.emulate,
+    cached_casted_weight=w_bits_fp8,
)
return w_fp8

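Stripped of the `Float8Linear` plumbing above, the caching logic is a recompute-or-reuse pattern around a preallocated float8 buffer. The following standalone sketch is illustrative only; it uses a plain dtype cast instead of the library's saturated cast and omits the amax bookkeeping:

```python
import torch


class CachedFp8Cast:
    """Illustrative stand-in for the `cached_fp8_weight` logic in Float8Linear."""

    def __init__(self, weight_shape, device="cuda"):
        # analogous to the buffer registered in __init__ when
        # config.allocate_float8_weight_cache_buffers is True
        self.cached_fp8_weight = torch.empty(
            weight_shape, dtype=torch.float8_e4m3fn, device=device
        )

    def cast(self, weight, scale, use_cache):
        if use_cache:
            # reuse the bits computed the last time the cache was refreshed
            return self.cached_fp8_weight
        # recompute the fp8 bits and refresh the cache
        bits = (weight * scale).to(torch.float8_e4m3fn)
        self.cached_fp8_weight.copy_(bits)
        return bits
```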
3 changes: 3 additions & 0 deletions float8_experimental/float8_linear_utils.py
@@ -156,6 +156,9 @@ def sync_float8_amax_and_scale_history(

for idx in range(len(fp8_layers)):
child = fp8_layers[idx]
# TODO(future): enable skipping weight related syncing if weight cache
# is on

#
# 1. in distributed contexts, syncs amax values across workers
#
49 changes: 35 additions & 14 deletions float8_experimental/float8_tensor.py
@@ -12,6 +12,16 @@
aten = torch.ops.aten


@torch.no_grad()
def calculate_amax_and_cast_to_float8(tensor, scale, float8_dtype, amax_buffer):
if amax_buffer is not None:
amax_buffer.fill_(tensor_to_amax(tensor))

tensor_scaled = tensor * scale
bits_fp8 = to_fp8_saturated(tensor_scaled, float8_dtype)
return bits_fp8


class ToFloat8ConstrFunc(torch.autograd.Function):
"""
A differentiable conversion to fp8
@@ -25,24 +35,23 @@ def forward(
float8_dtype=torch.float8_e4m3fn,
amax_buffer=None,
emulate: bool = False,
cached_casted_weight=None,
):
# In TransformerEngine, the casts to float8 are fused with calculating
# the new amax value. In this codebase, the eager mode code for those
# two things is colocated in this function. We expect PT2.0 to fuse it
# for us.
-    if amax_buffer is not None:
-        amax_buffer.fill_(tensor_to_amax(tensor))
-
-    tensor_scaled = tensor * scale
-    bits_fp8 = to_fp8_saturated(tensor_scaled, float8_dtype)
+    if cached_casted_weight is not None:
+        return Float8Tensor(
+            cached_casted_weight, scale, tensor.dtype, emulate=emulate
+        )
+    bits_fp8 = calculate_amax_and_cast_to_float8(
+        tensor, scale, float8_dtype, amax_buffer
+    )
return Float8Tensor(bits_fp8, scale, tensor.dtype, emulate=emulate)

@staticmethod
def backward(ctx, g):
    if isinstance(g, Float8Tensor):
-        return g.to_original_precision(), None, None, None, None
+        return g.to_original_precision(), None, None, None, None, None
    else:
-        return g, None, None, None, None
+        return g, None, None, None, None, None


class FromFloat8ConstrFunc(torch.autograd.Function):
@@ -122,7 +131,7 @@ def __tensor_flatten__(self):
return ["_data", "_scale"], ctx

@staticmethod
-    def __tensor_unflatten__(inner_tensors: Dict, metadata):
+    def __tensor_unflatten__(inner_tensors: Dict, metadata, outer_size, outer_stride):
Note from the author: not related to this PR, but fixing it to adhere to the changes in https://github.com/pytorch/pytorch/pull/114311/files

assert len(inner_tensors) == 2
return Float8Tensor(
inner_tensors["_data"],
@@ -136,7 +145,14 @@ def to_original_precision(self):

@staticmethod
@torch._dynamo.allow_in_graph
-    def to_float8(tensor, scale, float8_dtype, amax_buffer=None, emulate: bool = False):
+    def to_float8(
+        tensor,
+        scale,
+        float8_dtype,
+        amax_buffer=None,
+        emulate: bool = False,
+        cached_casted_weight=None,
+    ):
"""Converts a higher precision tensor to float8 in a differentiable way.

Args:
@@ -149,7 +165,12 @@ def to_float8(tensor, scale, float8_dtype, amax_buffer=None, emulate: bool = False):
Float8Tensor: a float8 tensor
"""
return ToFloat8ConstrFunc.apply(
-    tensor, scale, float8_dtype, amax_buffer, emulate
+    tensor,
+    scale,
+    float8_dtype,
+    amax_buffer,
+    emulate,
+    cached_casted_weight,
)

@classmethod
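To make the new `cached_casted_weight` path concrete, here is a small end-to-end sketch of the extended `to_float8` API. It is illustrative only: it uses a placeholder scale of 1.0 rather than a scale derived from amax history, and assumes a CUDA device:

```python
import torch

from float8_experimental.float8_tensor import (
    calculate_amax_and_cast_to_float8,
    Float8Tensor,
)

x = torch.randn(16, 32, device="cuda", dtype=torch.bfloat16)
scale = torch.tensor(1.0, device="cuda")  # placeholder scale
amax_buf = torch.empty(1, device="cuda")  # filled with amax(x) during the cast

# normal path: compute amax and the fp8 bits, then wrap them in a Float8Tensor
x_fp8 = Float8Tensor.to_float8(x, scale, torch.float8_e4m3fn, amax_buf)

# cached path: compute the bits once outside the subclass, then hand them in so
# the cast is skipped (this is what Float8Linear does with cached_fp8_weight)
bits = calculate_amax_and_cast_to_float8(x, scale, torch.float8_e4m3fn, amax_buf)
x_fp8_cached = Float8Tensor.to_float8(
    x, scale, torch.float8_e4m3fn, amax_buf, cached_casted_weight=bits
)
```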
33 changes: 33 additions & 0 deletions test/test_base.py
@@ -10,6 +10,9 @@
import warnings
from enum import Enum

import float8_experimental.config as config
import float8_experimental.float8_linear as float8_linear

import pytest

import torch
@@ -231,6 +234,36 @@ def test_type_cast(self, linear_type: LinearType, linear_dtype: torch.dtype):
y.dtype == torch.bfloat16
), f"y.dtype is {y.dtype}, expected {torch.bfloat16}"

@pytest.mark.parametrize("use_compile", [False, True])
def test_weight_caching(self, use_compile):
M, K, N = 16, 32, 64
dtype = torch.bfloat16
config.allocate_float8_weight_cache_buffers = True

x = torch.randn(M, K, device="cuda", dtype=dtype)
m_ref = nn.Linear(K, N, bias=True, device="cuda", dtype=dtype)
m = Float8Linear.from_float(copy.deepcopy(m_ref), emulate=False)

if use_compile:
m = torch.compile(m)

config.weight_cache_enabled = False

y1 = m(x)
y1.sum().backward()
grad1 = m.weight.grad.clone().detach()

config.weight_cache_enabled = True
sync_float8_amax_and_scale_history(m)

y2 = m(x)
y2.sum().backward()
grad2 = m.weight.grad.clone().detach()

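# gradients accumulate across the two backward calls (no zero_grad in between),
# so if the cached fp8 weight matches the freshly computed cast, grad2 == 2 * grad1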
torch.testing.assert_close(grad2, grad1 * 2)

config.allocate_float8_weight_cache_buffers = False


class TestScaledMM:
@unittest.skipIf(