remove weight caching #181

Closed · wants to merge 1 commit

This repository was archived by the owner on Aug 7, 2024. It is now read-only.
README.md (22 changes: 0 additions & 22 deletions)

@@ -64,28 +64,6 @@ model.foo.bar.fc2.sequence_parallel = True
# the rest of the flow is the same as the single GPU flow
```

## weight caching (very experimental)

```python
import float8_experimental.config as config

m = Model(...)
# before converting to `Float8Linear`, turn on weight cache buffer allocation
config.allocate_float8_weight_cache_buffers = True

# in the training loop, manually control the global weight caching setting
for idx in range(N_ITER):
...
if idx % n_microbatch == 0:
# if we are in the first pass of a new microbatch, repopulate the cache
config.weight_cache_enabled = False
elif idx % n_microbatch == 1:
# if we are in the second pass of a new microbatch, use cached weight
# this persists until `idx % n_microbatch == 0` again
config.weight_cache_enabled = True
...
```

# high level technical design

## UX
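For context on the README flow being deleted above: toggling `config.weight_cache_enabled` controls whether `Float8Linear` reuses its previously cast weight instead of recomputing the cast on every forward. The toy module below is a minimal sketch of that reuse-or-repopulate pattern, not the `Float8Linear` implementation; `CachedCastLinear` is a hypothetical name and a bfloat16 cast stands in for the real float8 cast.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class CachedCastLinear(nn.Linear):
    """Toy illustration of the removed pattern: cache the low-precision copy
    of the weight and reuse it on later passes over the same microbatch."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.cache_enabled = False  # analogous to config.weight_cache_enabled
        self.register_buffer(
            "cached_weight", torch.empty_like(self.weight, dtype=torch.bfloat16)
        )

    def forward(self, x):
        if self.cache_enabled:
            w = self.cached_weight                   # later pass: reuse the cached cast
        else:
            w = self.weight.to(torch.bfloat16)       # first pass: recompute the cast
            self.cached_weight.copy_(w.detach())     # and repopulate the cache
        return F.linear(x.to(torch.bfloat16), w, self.bias)


m = CachedCastLinear(8, 8, bias=False)
x = torch.randn(2, 8)
y1 = m(x)
m.cache_enabled = True
y2 = m(x)  # uses the cached cast; identical output while the weight is unchanged
torch.testing.assert_close(y1, y2)
```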
float8_experimental/config.py (17 changes: 0 additions & 17 deletions)

@@ -4,23 +4,6 @@
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

#
# Weight caching.
#

# If True, allocates buffers for float8 weight cache
allocate_float8_weight_cache_buffers = False

# A global flag for controlling the weight cache, off by default. Intended
# usage is for users to modify this from their training loop directly
# according to their microbatching/pipeline parallel setup.
# Note: this is currently a global flag for simplicity and dynamo performance.
weight_cache_enabled = False

#
# Other
#

# If True, on the first iteration of Float8Linear the amaxes will be
# initialized with the incoming data. As of 2023-12-30, this doesn't work
# with autocast + torch.compile + FSDP. Enabling this option is nice for
float8_experimental/float8_linear.py (34 changes: 1 addition & 33 deletions)

@@ -20,10 +20,7 @@

import torch

from float8_experimental.float8_tensor import (
calculate_amax_and_cast_to_float8,
Float8Tensor,
)
from float8_experimental.float8_tensor import Float8Tensor

from float8_experimental.float8_utils import (
amax_history_to_scale,
@@ -182,15 +179,6 @@ def __init__(self, *args, **kwargs):
# and torch.compile, this option can disable them
self.enable_pre_and_post_forward = config.enable_pre_and_post_forward

if config.allocate_float8_weight_cache_buffers:
# this is a buffer to get `to(dtype)` for free
# TODO(future): hide this from serialization
# TODO(future): force this to stay in float8_e4m3fn
self.register_buffer(
"cached_fp8_weight",
torch.empty(self.weight.shape, dtype=torch.float8_e4m3fn),
)

def register_always_float32_buffer(
self, name: str, tensor: Optional[torch.Tensor], persistent: bool = True
) -> None:
@@ -247,32 +235,12 @@ def cast_w_to_float8(
is_amax_initialized,
)

if config.weight_cache_enabled:
assert config.allocate_float8_weight_cache_buffers, (
"float8 weight cache buffer must be allocated using "
+ "`allocate_float8_weight_cache_buffers` to use the weight cache"
)
w_bits_fp8 = self.cached_fp8_weight
else:
# manual calculation of fp8 bits:
# 1. calculate the bits without Float8Tensor, without grad
# 2. store the bits here
# 3. create Float8Tensor from the bits calculated in 2
# motivation: this will take care of saving the bits without
# interacting with tensor subclasses, as w_fp8._data is not
# currently traceable by dynamo
w_bits_fp8 = calculate_amax_and_cast_to_float8(
self.weight, self.fp8_scale_w, torch.float8_e4m3fn, self.fp8_amax_w
)
if config.allocate_float8_weight_cache_buffers:
self.cached_fp8_weight.copy_(w_bits_fp8)
w_fp8 = Float8Tensor.to_float8(
w,
self.fp8_scale_w,
torch.float8_e4m3fn,
self.fp8_amax_w,
self.emulate,
cached_casted_weight=w_bits_fp8,
)
return w_fp8

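For reference, the "manual calculation of fp8 bits" removed above amounts to an amax-based scale followed by a saturated cast. Below is a rough standalone sketch of that math, assuming a PyTorch build with `torch.float8_e4m3fn`; `cast_to_float8_e4m3` is an illustrative helper, and the exact scale computation in `float8_utils` may differ in detail.

```python
import torch

FP8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0


def cast_to_float8_e4m3(w: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # the tensor's amax drives the per-tensor scale (what `tensor_to_amax` measures)
    amax = w.abs().max()
    scale = FP8_E4M3_MAX / torch.clamp(amax, min=1e-12)
    # scale, then clamp into the representable range before casting ("saturated" cast)
    w_scaled = torch.clamp(w * scale, -FP8_E4M3_MAX, FP8_E4M3_MAX)
    return w_scaled.to(torch.float8_e4m3fn), scale


bits, scale = cast_to_float8_e4m3(torch.randn(32, 32))
w_approx = bits.to(torch.float32) / scale  # dequantize back to the original range
```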
float8_experimental/float8_linear_utils.py (2 changes: 0 additions & 2 deletions)

@@ -156,8 +156,6 @@ def sync_float8_amax_and_scale_history(

for idx in range(len(fp8_layers)):
child = fp8_layers[idx]
# TODO(future): enable skipping weight related syncing if weight cache
# is on

#
# 1. in distributed contexts, syncs amax values across workers
float8_experimental/float8_tensor.py (37 changes: 8 additions & 29 deletions)

@@ -12,16 +12,6 @@
aten = torch.ops.aten


@torch.no_grad()
def calculate_amax_and_cast_to_float8(tensor, scale, float8_dtype, amax_buffer):
if amax_buffer is not None:
amax_buffer.fill_(tensor_to_amax(tensor))

tensor_scaled = tensor * scale
bits_fp8 = to_fp8_saturated(tensor_scaled, float8_dtype)
return bits_fp8


@torch._dynamo.allow_in_graph
class ToFloat8ConstrFunc(torch.autograd.Function):
"""
@@ -36,23 +26,20 @@ def forward(
float8_dtype=torch.float8_e4m3fn,
amax_buffer=None,
emulate: bool = False,
cached_casted_weight=None,
):
if cached_casted_weight is not None:
return Float8Tensor(
cached_casted_weight, scale, tensor.dtype, emulate=emulate
)
bits_fp8 = calculate_amax_and_cast_to_float8(
tensor, scale, float8_dtype, amax_buffer
)
if amax_buffer is not None:
amax_buffer.fill_(tensor_to_amax(tensor))

tensor_scaled = tensor * scale
bits_fp8 = to_fp8_saturated(tensor_scaled, float8_dtype)
return Float8Tensor(bits_fp8, scale, tensor.dtype, emulate=emulate)

@staticmethod
def backward(ctx, g):
if isinstance(g, Float8Tensor):
return g.to_original_precision(), None, None, None, None, None
return g.to_original_precision(), None, None, None, None
else:
return g, None, None, None, None, None
return g, None, None, None, None


@torch._dynamo.allow_in_graph
@@ -147,14 +134,7 @@ def to_original_precision(self):

@staticmethod
@torch._dynamo.allow_in_graph
def to_float8(
tensor,
scale,
float8_dtype,
amax_buffer=None,
emulate: bool = False,
cached_casted_weight=None,
):
def to_float8(tensor, scale, float8_dtype, amax_buffer=None, emulate: bool = False):
"""Converts a higher precision tensor to float8 in a differentiable way.

Args:
@@ -172,7 +152,6 @@ def to_float8(
float8_dtype,
amax_buffer,
emulate,
cached_casted_weight,
)

@classmethod
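The change to `backward` in this file follows a general `torch.autograd.Function` rule: `backward` must return one gradient (or `None`) per `forward` argument after `ctx`, so dropping the `cached_casted_weight` argument also drops one trailing `None`. A small self-contained illustration of that rule (the `ScaleBy` function is a made-up example, not library code):

```python
import torch


class ScaleBy(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, factor, label=None):
        # three inputs after ctx -> backward must return three values
        ctx.factor = factor
        return x * factor

    @staticmethod
    def backward(ctx, grad_out):
        # gradient for `x`, then None for the non-tensor `factor` and `label`
        return grad_out * ctx.factor, None, None


x = torch.ones(4, requires_grad=True)
ScaleBy.apply(x, 3.0, "demo").sum().backward()
assert torch.equal(x.grad, torch.full((4,), 3.0))
```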
test/test_base.py (33 changes: 0 additions & 33 deletions)

@@ -10,9 +10,6 @@
import warnings
from enum import Enum

import float8_experimental.config as config
import float8_experimental.float8_linear as float8_linear

import pytest

import torch
@@ -234,36 +231,6 @@ def test_type_cast(self, linear_type: LinearType, linear_dtype: torch.dtype):
y.dtype == torch.bfloat16
), f"y.dtype is {y.dtype}, expected {torch.bfloat16}"

@pytest.mark.parametrize("use_compile", [False, True])
def test_weight_caching(self, use_compile):
M, K, N = 16, 32, 64
dtype = torch.bfloat16
config.allocate_float8_weight_cache_buffers = True

x = torch.randn(M, K, device="cuda", dtype=dtype)
m_ref = nn.Linear(K, N, bias=True, device="cuda", dtype=dtype)
m = Float8Linear.from_float(copy.deepcopy(m_ref), emulate=False)

if use_compile:
m = torch.compile(m)

config.weight_cache_enabled = False

y1 = m(x)
y1.sum().backward()
grad1 = m.weight.grad.clone().detach()

config.weight_cache_enabled = True
sync_float8_amax_and_scale_history(m)

y2 = m(x)
y2.sum().backward()
grad2 = m.weight.grad.clone().detach()

torch.testing.assert_close(grad2, grad1 * 2)

config.allocate_float8_weight_cache_buffers = False


class TestScaledMM:
@unittest.skipIf(
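On the removed `test_weight_caching` above: the `grad2 == 2 * grad1` assertion relies on the fact that gradients accumulate across `.backward()` calls when they are not zeroed, so a second pass that uses the cached weight cast and yields the same gradient as the first ends up exactly doubling it. The snippet below illustrates just that accumulation behavior with a plain `nn.Linear`, no float8 involved.

```python
import torch
import torch.nn as nn

m = nn.Linear(32, 64)
x = torch.randn(16, 32)

m(x).sum().backward()
grad1 = m.weight.grad.clone()

# no zero_grad() in between and the same weights: the second backward
# adds an identical gradient on top of the first one
m(x).sum().backward()
torch.testing.assert_close(m.weight.grad, grad1 * 2)
```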