This repository was archived by the owner on Aug 7, 2024. It is now read-only.
support float8 weight caching for gradient accumulation/PP #164
Closed
New file (@@ -0,0 +1,18 @@), adding the float8_experimental.config module:

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

#
# Weight caching.
#

# If True, allocates buffers for float8 weight cache
allocate_float8_weight_cache_buffers = False

# A global flag for controlling the weight cache, off by default. Intended
# usage is for users to modify this from their training loop directly
# according to their microbatching/pipeline parallel setup.
# Note: this is currently a global flag for simplicity and dynamo performance.
weight_cache_enabled = False
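The comments above describe the intended usage pattern, but the diff itself does not include an example. The following is a hedged sketch of a gradient-accumulation loop driving the two flags; the model, optimizer, and microbatch names are illustrative assumptions, not part of this PR:

```python
import float8_experimental.config as config

# Allocate the cache buffers before constructing the float8 modules,
# since the buffer registration in __init__ is gated on this flag.
config.allocate_float8_weight_cache_buffers = True
model, optimizer, data = build_training_state()  # hypothetical helper

for step_batches in data:                         # one optimizer step per group
    for i, microbatch in enumerate(step_batches):
        # The weights only change at optimizer.step(), so the fp8 cast of the
        # weight computed on the first microbatch stays valid for the rest.
        config.weight_cache_enabled = i > 0
        loss = model(microbatch).sum()
        loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    # The next group's first microbatch runs with the cache disabled again,
    # which recomputes the cast and refills the buffer.
```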
In the float8 linear module:

@@ -16,9 +16,14 @@
 from typing import Optional

+import float8_experimental.config as config
+
 import torch

-from float8_experimental.float8_tensor import Float8Tensor
+from float8_experimental.float8_tensor import (
+    calculate_amax_and_cast_to_float8,
+    Float8Tensor,
+)

 from float8_experimental.float8_utils import (
     amax_history_to_scale,
@@ -172,6 +177,15 @@ def __init__(self, *args, **kwargs):
         # will access the scale when it has ensured that it is on GPU.
         self._float8_tensor_ctor = lambda *args, **kwargs: Float8Tensor(*args, **kwargs)

+        if config.allocate_float8_weight_cache_buffers:
+            # this is a buffer to get `to(dtype)` for free
+            # TODO(future): hide this from serialization
+            # TODO(future): force this to stay in float8_e4m3fn
+            self.register_buffer(
+                "cached_fp8_weight",
+                torch.empty(self.weight.shape, dtype=torch.float8_e4m3fn),
+            )
+
     def register_always_float32_buffer(
         self, name: str, tensor: Optional[torch.Tensor], persistent: bool = True
     ) -> None:

Review comment (on the TODOs above): Could probably use something like this for the second TODO.
Reply: yep, we can fix that in a future PR!
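Neither TODO is addressed in this PR. For context only, here is a hedged sketch of how they might be handled later; the subclass shown is an illustrative stand-in, not the repository's actual plan. `persistent=False` keeps a buffer out of `state_dict`, and an `_apply` override can force the cache back to float8_e4m3fn after a module-wide `.to(dtype)`:

```python
import torch
import torch.nn as nn

class LinearWithFp8WeightCache(nn.Linear):  # illustrative stand-in
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # TODO 1 ("hide this from serialization"): persistent=False keeps the
        # buffer out of state_dict while it still moves with .to(device)/.cuda().
        self.register_buffer(
            "cached_fp8_weight",
            torch.empty(self.weight.shape, dtype=torch.float8_e4m3fn),
            persistent=False,
        )

    def _apply(self, fn, *args, **kwargs):
        # TODO 2 ("force this to stay in float8_e4m3fn"): undo any dtype change
        # that a module-wide .to(dtype)/.half() applied to the cache buffer.
        out = super()._apply(fn, *args, **kwargs)
        if self.cached_fp8_weight.dtype != torch.float8_e4m3fn:
            self.cached_fp8_weight = self.cached_fp8_weight.to(torch.float8_e4m3fn)
        return out
```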
@@ -228,8 +242,33 @@ def cast_w_to_float8(
             torch.float8_e4m3fn,
             is_amax_initialized,
         )

+        if config.weight_cache_enabled:
+            assert config.allocate_float8_weight_cache_buffers, (
+                "float8 weight cache buffer must be allocated using "
+                + "`allocate_float8_weight_cache_buffers` to use the weight cache"
+            )
+            w_bits_fp8 = self.cached_fp8_weight
+        else:
+            # manual calculation of fp8 bits:
+            # 1. calculate the bits without Float8Tensor, without grad
+            # 2. store the bits here
+            # 3. create Float8Tensor from the bits calculated in 2
+            # motivation: this will take care of saving the bits without
+            # interacting with tensor subclasses, as w_fp8._data is not
+            # currently traceable by dynamo
+            w_bits_fp8 = calculate_amax_and_cast_to_float8(
+                self.weight, self.fp8_scale_w, torch.float8_e4m3fn, self.fp8_amax_w
+            )
+            if config.allocate_float8_weight_cache_buffers:
+                self.cached_fp8_weight.copy_(w_bits_fp8)
         w_fp8 = Float8Tensor.to_float8(
-            w, self.fp8_scale_w, torch.float8_e4m3fn, self.fp8_amax_w, self.emulate
+            w,
+            self.fp8_scale_w,
+            torch.float8_e4m3fn,
+            self.fp8_amax_w,
+            self.emulate,
+            cached_casted_weight=w_bits_fp8,
         )
         return w_fp8
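The contract implemented above is that `cached_fp8_weight` is only reused while the underlying weight is unchanged, and is refilled on any uncached forward. A hedged sanity-check sketch of that contract, assuming a module `m` with the attributes used in this diff (`weight`, `fp8_scale_w`, `cached_fp8_weight`) and with cache buffers allocated at construction time:

```python
import torch
import float8_experimental.config as config
from float8_experimental.float8_tensor import calculate_amax_and_cast_to_float8

def check_cache_matches_fresh_cast(m, x):
    # Uncached forward: cast_w_to_float8 computes the fp8 bits and copies them
    # into m.cached_fp8_weight.
    config.weight_cache_enabled = False
    m(x)
    # Recasting the (unchanged) weight with the same scale should reproduce the
    # bits that were cached above.
    fresh_bits = calculate_amax_and_cast_to_float8(
        m.weight, m.fp8_scale_w, torch.float8_e4m3fn, amax_buffer=None
    )
    torch.testing.assert_close(
        fresh_bits.to(torch.float32), m.cached_fp8_weight.to(torch.float32)
    )
```

After `optimizer.step()` changes the weight, this check would only hold again once another uncached forward refills the cache.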
In the float8_experimental.float8_tensor module:

@@ -12,6 +12,16 @@
 aten = torch.ops.aten


+@torch.no_grad()
+def calculate_amax_and_cast_to_float8(tensor, scale, float8_dtype, amax_buffer):
+    if amax_buffer is not None:
+        amax_buffer.fill_(tensor_to_amax(tensor))
+
+    tensor_scaled = tensor * scale
+    bits_fp8 = to_fp8_saturated(tensor_scaled, float8_dtype)
+    return bits_fp8
+
+
 class ToFloat8ConstrFunc(torch.autograd.Function):
     """
     A differentiable conversion to fp8
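For reference, the factored-out helper can be exercised on its own; the scale below is arbitrary and just for illustration (in the linear module it is derived from the amax history):

```python
import torch
from float8_experimental.float8_tensor import calculate_amax_and_cast_to_float8

x = torch.randn(16, 16, dtype=torch.float32)
scale = torch.tensor(1.0)  # illustrative; normally computed from amax history
amax = torch.zeros(1)

bits = calculate_amax_and_cast_to_float8(x, scale, torch.float8_e4m3fn, amax)
print(bits.dtype)   # torch.float8_e4m3fn
print(amax.item())  # filled with amax(x) as a side effect, under no_grad
```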
@@ -25,24 +35,23 @@ def forward(
         float8_dtype=torch.float8_e4m3fn,
         amax_buffer=None,
         emulate: bool = False,
+        cached_casted_weight=None,
     ):
         # In TransformerEngine, the casts to float8 are fused with calculating
         # the new amax value. In this codebase, the eager mode code for those
         # two things is colocated in this function. We expect PT2.0 to fuse it
         # for us.
-        if amax_buffer is not None:
-            amax_buffer.fill_(tensor_to_amax(tensor))
-
-        tensor_scaled = tensor * scale
-        bits_fp8 = to_fp8_saturated(tensor_scaled, float8_dtype)
+        if cached_casted_weight is not None:
+            return Float8Tensor(
+                cached_casted_weight, scale, tensor.dtype, emulate=emulate
+            )
+        bits_fp8 = calculate_amax_and_cast_to_float8(
+            tensor, scale, float8_dtype, amax_buffer
+        )
         return Float8Tensor(bits_fp8, scale, tensor.dtype, emulate=emulate)

     @staticmethod
     def backward(ctx, g):
         if isinstance(g, Float8Tensor):
-            return g.to_original_precision(), None, None, None, None
+            return g.to_original_precision(), None, None, None, None, None
         else:
-            return g, None, None, None, None
+            return g, None, None, None, None, None


 class FromFloat8ConstrFunc(torch.autograd.Function):
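A hedged sketch of the effect of the new `cached_casted_weight` argument in isolation: when it is provided, `forward` wraps the precomputed bits directly and skips both the amax update and the cast (the scale value here is illustrative):

```python
import torch
from float8_experimental.float8_tensor import (
    calculate_amax_and_cast_to_float8,
    Float8Tensor,
)

w = torch.randn(8, 8)
scale = torch.tensor(1.0)  # illustrative scale
amax = torch.zeros(1)

# Normal path: computes the bits and fills the amax buffer.
bits = calculate_amax_and_cast_to_float8(w, scale, torch.float8_e4m3fn, amax)

# Cached path: the precomputed bits are wrapped as-is; the amax buffer is not
# touched and no cast kernel runs, which is the point of the weight cache.
amax.zero_()
w_fp8 = Float8Tensor.to_float8(
    w, scale, torch.float8_e4m3fn, amax, emulate=True, cached_casted_weight=bits
)
assert amax.item() == 0.0                        # untouched on the cached path
assert w_fp8._data.dtype == torch.float8_e4m3fn  # bits reused directly
```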
@@ -122,7 +131,7 @@ def __tensor_flatten__(self):
         return ["_data", "_scale"], ctx

     @staticmethod
-    def __tensor_unflatten__(inner_tensors: Dict, metadata):
+    def __tensor_unflatten__(inner_tensors: Dict, metadata, outer_size, outer_stride):
         assert len(inner_tensors) == 2
         return Float8Tensor(
             inner_tensors["_data"],

Comment on this change: not related to this PR, but fixing it to adhere to the changes in https://github.com/pytorch/pytorch/pull/114311/files
@@ -136,7 +145,14 @@ def to_original_precision(self):
     @staticmethod
     @torch._dynamo.allow_in_graph
-    def to_float8(tensor, scale, float8_dtype, amax_buffer=None, emulate: bool = False):
+    def to_float8(
+        tensor,
+        scale,
+        float8_dtype,
+        amax_buffer=None,
+        emulate: bool = False,
+        cached_casted_weight=None,
+    ):
         """Converts a higher precision tensor to float8 in a differentiable way.

         Args:
@@ -149,7 +165,12 @@ def to_float8(tensor, scale, float8_dtype, amax_buffer=None, emulate: bool = False):
             Float8Tensor: a float8 tensor
         """
         return ToFloat8ConstrFunc.apply(
-            tensor, scale, float8_dtype, amax_buffer, emulate
+            tensor,
+            scale,
+            float8_dtype,
+            amax_buffer,
+            emulate,
+            cached_casted_weight,
         )

     @classmethod
Review comment: I think some more comments on how users are expected to use this in code would be helpful.
Reply: added to readme.