This repository was archived by the owner on Aug 7, 2024. It is now read-only.

support float8 weight caching for gradient accumulation/PP #164

Closed
wants to merge 1 commit

Conversation

Contributor

@vkuzo vkuzo commented Dec 17, 2023

Summary:

In cases where the optimizer update does not happen after every forward, such as microbatching/PP, we can cache the cast float8 weight to trade memory for time.
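
A rough sketch of the idea, for illustration only (this is not the PR's actual Float8Linear code; the class and attribute names below are made up, per-tensor scaling is skipped, and gradient flow through the weight cast is omitted):

```
import torch
import torch.nn as nn
import torch.nn.functional as F

# Illustration: cache the float8-cast weight and reuse it until the optimizer
# steps. Real float8 training also needs amax-based scaling and a custom
# autograd cast so gradients reach the high-precision weight.
class CachedFp8WeightLinear(nn.Linear):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.weight_cache_enabled = False  # toggled by the user around optimizer steps
        self._cached_w_fp8 = None

    def _get_w_fp8(self):
        if self.weight_cache_enabled and self._cached_w_fp8 is not None:
            return self._cached_w_fp8      # weight unchanged since the last cast
        w_fp8 = self.weight.detach().to(torch.float8_e4m3fn)
        self._cached_w_fp8 = w_fp8         # refresh the cache on every real cast
        return w_fp8

    def forward(self, x):
        # upcast for the matmul; a real implementation would use a float8 gemm
        return F.linear(x, self._get_w_fp8().to(x.dtype), self.bias)
```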

For now I'm just testing out performance and accuracy; we can improve the API in future PRs. The current code is torch.compile-friendly, which is nice.

In terms of accuracy there should be no change; I will validate this further if we want to land this.

For performance, on @drisspg's LLaMa 7B pretrain script, with bsz==128 and micro_bsz == 1:

  1. baseline bf16 + compile: 2.38 it/s
  2. delayed scaling + compile: 2.80 it/s (1.18x over baseline)
  3. delayed scaling + compile + this PR: 3.04 it/s (1.28x over baseline)

Test Plan:

pytest test/test_base.py -s -k test_weight_caching

Reviewers:

Subscribers:

Tasks:

Tags:

@facebook-github-bot facebook-github-bot added the CLA Signed label (this label is managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed) on Dec 17, 2023
@vkuzo vkuzo force-pushed the 20231215_weight_caching branch from e4b126a to f506964 on December 17, 2023 23:45
@@ -123,6 +132,9 @@ def __tensor_flatten__(self):

@staticmethod
def __tensor_unflatten__(inner_tensors: Dict, metadata):
# TODO(TBD): this seems unused, and it's out of date after
Contributor

I was looking into this today, and I do think this is used; it's just that in full-graph compile land we get away with it not being a problem. We should update this to the new signature.

Contributor Author

hmm, do you have more context?

If I understand the current state correctly:

  1. the currently checked-in code is broken because the signature of `__tensor_unflatten__` is no longer valid after pytorch/pytorch#114311 (Expand dynamic dims support for traceable subclasses)
  2. if this code path is exercised by dynamo, it will crash because of (1)
  3. all tests pass, so I'm guessing this is dead code at the moment

Contributor

@y-sq y-sq Dec 18, 2023

When running the multi-GPU test with compile (FSDP(torch.compile(model))), I did get some errors around __tensor_unflatten__, P913908807. Might be related to this discussion?
Also, the model couldn't be compiled with fullgraph=True when wrapped by FSDP.

Contributor

@bdhirsh I imagine that although this is broken, we are not encountering it because we can full-graph compile and the tensor subclass vanishes during AOT; is that correct, or am I missing something? So although this path is not exercised today, if we decided to expose the Float8Tensor subclass as more of a first-class UX and instances entered compiled functions, we would need this again.

Contributor

Yep - I think that if your subclass is just a temporary that poofs during tracing (so not a graph input or output), then we won't end up hitting the flatten/unflatten calls at all. We'll definitely need it though if/when the subclass ever becomes a graph input/output.

This can happen if you make the subclass a module parameter directly, but it could also just happen via graph breaks, if there happens to be a graph break after we create the float8 tensor but before it vanishes (which to be fair would probably be bad for the level of perf that float8 wants anyway).

Contributor Author

Put up #166; it doesn't work yet. @bdhirsh, any thoughts on the error message in that PR?

Contributor

@vkuzo is the error message linked somewhere? (I checked the PR but I just see a ufmt lint failure in the CI)

Contributor Author

Yep, it's in the test plan (not in CI)!

Contributor

missed it the first time https://gist.github.com/vkuzo/ba98a01a459fb9c966f167d8ecca1780#file-gistfile0-txt-L137.

Hmm, this looks like the same error @drisspg and I noticed a month or two ago, where (somehow) we end up with multiple fake modes floating around. This probably happens somewhere in dynamo, since dynamo is responsible for creating the FakeTensorMode.

This bug might be pretty involved, so just checking: is "get Float8Tensor working with graph breaks" a blocker? (Regardless, I agree that we should make it work.)

Contributor Author

not technically a blocker, but would be really nice for it to work

# this is a buffer to get `to(dtype)` for free
# TODO(future): hide this from serialization
# TODO(future): force this to stay in float8_e4m3fn
self.register_buffer(
Contributor

Contributor Author

yep, we can fix that in a future PR!
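
For reference, a small sketch of how the two TODOs above could eventually be handled; the buffer name and shapes here are hypothetical, `persistent=False` covers the serialization TODO, and keeping the buffer in float8 across a later `module.to(dtype)` would still need extra handling:

```
import torch
import torch.nn as nn

class Example(nn.Module):
    def __init__(self, out_features=8, in_features=4):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_features, in_features))
        # persistent=False keeps the cache buffer out of the state_dict;
        # allocating it directly in float8_e4m3fn sets the intended dtype,
        # although module.to(dtype) could still convert it later
        self.register_buffer(
            "cached_w_fp8",
            torch.empty(out_features, in_features, dtype=torch.float8_e4m3fn),
            persistent=False,
        )
```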

# LICENSE file in the root directory of this source tree.

# If True, allocates buffers for float8 weight cache
allocate_float8_weight_cache_buffers = False
Contributor

@drisspg drisspg Dec 20, 2023

I think some more comments on how users are expected to use this in code would help, e.g.:

sync_float_amax()
if accumulate_grad:
    weight_cache_enabled = True
else:
    optimizer.step()
    weight_cache_enabled = False

Contributor Author

added to readme
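
For concreteness, a runnable plain-PyTorch sketch of that toggle pattern; `weight_cache_enabled` below is only a stand-in for the config flag this PR adds (a plain nn.Linear ignores it), and the point is just when the flag flips relative to optimizer.step():

```
import torch
import torch.nn as nn

weight_cache_enabled = False  # stand-in for config.weight_cache_enabled

model = nn.Linear(16, 16)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
micro_batches_per_step = 4    # gradient accumulation factor

for step in range(2):
    for i in range(micro_batches_per_step):
        # only the first microbatch after an optimizer step needs a fresh cast;
        # later microbatches could reuse the cached float8 weight
        weight_cache_enabled = i > 0
        x = torch.randn(8, 16)
        loss = model(x).pow(2).mean()
        loss.backward()           # grads accumulate across microbatches
    optimizer.step()              # weights change here...
    optimizer.zero_grad()
    weight_cache_enabled = False  # ...so the next cast must not reuse the cache
```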

@@ -204,8 +218,30 @@ def cast_w_to_float8(
torch.float8_e4m3fn,
is_amax_initialized,
)

if config.weight_cache_enabled:
assert config.allocate_float8_weight_cache_buffers
Contributor

Also, maybe a better error message, e.g. "if you are using weight caching, you need to set allocate_float8_weight_cache_buffers = True before performing module swap or constructing a Float8Linear module".
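
Something along these lines, as a sketch only (the wording and the `config` stand-in are illustrative, not this PR's code):

```
from types import SimpleNamespace

# stand-in for the repo's config flags referenced in the diff above
config = SimpleNamespace(
    allocate_float8_weight_cache_buffers=True,
    weight_cache_enabled=True,
)

if config.weight_cache_enabled:
    # flipping allocate_float8_weight_cache_buffers to False shows the message
    assert config.allocate_float8_weight_cache_buffers, (
        "weight caching is enabled but no cache buffers were allocated; set "
        "config.allocate_float8_weight_cache_buffers = True before module swap "
        "or Float8Linear construction"
    )
```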

Contributor

@drisspg drisspg left a comment

Left some comments, mostly around documentation, but this seems good.

@vkuzo vkuzo force-pushed the 20231215_weight_caching branch from f506964 to ff61a58 on December 21, 2023 06:06
@vkuzo vkuzo changed the title from "[wip] support float8 weight caching for gradient accumulation/PP" to "support float8 weight caching for gradient accumulation/PP" on Dec 21, 2023
@facebook-github-bot
Contributor

@vkuzo has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@vkuzo vkuzo force-pushed the 20231215_weight_caching branch from ff61a58 to 96f7ccc on December 21, 2023 06:10
@@ -122,7 +131,7 @@ def __tensor_flatten__(self):
return ["_data", "_scale"], ctx

@staticmethod
def __tensor_unflatten__(inner_tensors: Dict, metadata):
def __tensor_unflatten__(inner_tensors: Dict, metadata, outer_size, outer_stride):
Contributor Author

not related to this PR, but fixing it to adhere to the changes in https://github.com/pytorch/pytorch/pull/114311/files
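
For reference, a minimal self-contained subclass showing the flatten/unflatten contract after pytorch/pytorch#114311; it mirrors the ["_data", "_scale"] layout above but is not the repo's Float8Tensor, and it omits the `__torch_dispatch__` handler a real traceable subclass also needs. The round-trip at the bottom is just a manual demonstration of the two hooks:

```
import torch

class ScaledTensor(torch.Tensor):
    def __new__(cls, data, scale, orig_dtype):
        # wrapper subclass whose outer metadata mirrors the inner data tensor
        return torch.Tensor._make_wrapper_subclass(
            cls, data.shape, dtype=orig_dtype, device=data.device
        )

    def __init__(self, data, scale, orig_dtype):
        self._data = data
        self._scale = scale
        self._orig_dtype = orig_dtype

    def __tensor_flatten__(self):
        ctx = {"_orig_dtype": self._orig_dtype}
        return ["_data", "_scale"], ctx

    @staticmethod
    def __tensor_unflatten__(inner_tensors, metadata, outer_size, outer_stride):
        # outer_size/outer_stride describe the outer wrapper; a wrapper whose
        # shape matches _data can ignore them
        return ScaledTensor(
            inner_tensors["_data"], inner_tensors["_scale"], metadata["_orig_dtype"]
        )

# manual round-trip through the two hooks
t = ScaledTensor(torch.randn(4, 4).to(torch.float8_e4m3fn), torch.tensor(1.0), torch.float32)
attrs, ctx = t.__tensor_flatten__()
t2 = ScaledTensor.__tensor_unflatten__(
    {name: getattr(t, name) for name in attrs}, ctx, t.shape, t.stride()
)
```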

@facebook-github-bot
Contributor

@vkuzo has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@facebook-github-bot
Contributor

@vkuzo merged this pull request in f4812ee.

vkuzo added a commit that referenced this pull request Jan 10, 2024
Summary:

Removes most of
#164

This isn't useful in the short term since it doesn't compose
with FSDP + compile, and memory overhead is high.  We can bring it back later if needed.

Test Plan:

```
./test/test_everything.sh
```

Reviewers:

Subscribers:

Tasks:

Tags:
@vkuzo vkuzo mentioned this pull request Jan 10, 2024
facebook-github-bot pushed a commit that referenced this pull request Jan 10, 2024
Summary:
Removes most of
#164

This isn't useful in the short term since it doesn't compose with FSDP + compile, and memory overhead is high.  We can bring it back later if needed.

Pull Request resolved: #181

Test Plan:
```
./test/test_everything.sh
```

Reviewed By: drisspg

Differential Revision: D52648603

Pulled By: vkuzo

fbshipit-source-id: f956337264fd28fa0bc50d151c316cde7c3d28de