support float8 weight caching for gradient accumulation/PP #164
Conversation
e4b126a to f506964 (Compare)
float8_experimental/float8_tensor.py
Outdated
@@ -123,6 +132,9 @@ def __tensor_flatten__(self):

@staticmethod
def __tensor_unflatten__(inner_tensors: Dict, metadata):
    # TODO(TBD): this seems unused, and it's out of date after
I was looking into this today, and I do think this is used; it's just that in full-graph compile land we get away with it not being a problem. We should update this to the new signature.
hmm, do you have more context?
If I understand the current state correctly:
- the currently checked-in code is broken because the signature of __tensor_unflatten__ is no longer valid after "Expand dynamic dims support for traceable subclasses" (pytorch/pytorch#114311)
- if this code path is exercised by dynamo, it will crash because of (1)
- all tests pass, so I'm guessing this is dead code at the moment
When running the multi-GPU test with compile (FSDP(torch.compile(model))), I did get some errors around __tensor_unflatten__, P913908807. Might be related to this discussion?
Also, the model couldn't be compiled with fullgraph=True when wrapped by FSDP.
@bdhirsh I imagine that although this is broken, we are not encountering it because we can full-graph compile and the tensor subclass vanishes in AOT. Is that correct, or am I missing something? So although this is not exercised today, if we decided to expose the Float8Tensor subclass as more of a first-class UX and it entered compiled functions, we would need this again.
Yep - I think that if your subclass is just a temporary that poofs during tracing (so not a graph input or output), then we won't end up hitting the flatten/unflatten calls at all. We'll definitely need it though if/when the subclass ever becomes a graph input/output.
This can happen if you make the subclass a module parameter directly, but it could also just happen via graph breaks, if there happens to be a graph break after we create the float8 tensor but before it vanishes (which to be fair would probably be bad for the level of perf that float8 wants anyway).
@vkuzo is the error message linked somewhere? (I checked the PR but I just see a ufmt lint failure in the CI)
yep, it's in the test plan (not in CI)!
missed it the first time https://gist.github.com/vkuzo/ba98a01a459fb9c966f167d8ecca1780#file-gistfile0-txt-L137.
Hmm, this looks like the same error @drisspg and I noticed a month or two ago, where (somehow) we end up with multiple fake modes floating around. This probably happens somewhere in dynamo, since dynamo is responsible for creating the FakeTensorMode.
This bug might be pretty involved, so just checking: is "get Float8Tensor working with graph breaks" a blocker? (Regardless, I agree that we should make it work.)
not technically a blocker, but would be really nice for it to work
# this is a buffer to get `to(dtype)` for free
# TODO(future): hide this from serialization
# TODO(future): force this to stay in float8_e4m3fn
self.register_buffer(
Could probably use something like this for the second todo
yep, we can fix that in a future PR!
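One possible way to handle the second TODO (keeping the cache buffer pinned to float8_e4m3fn) is sketched below with a toy module rather than the PR's actual Float8Linear. The _apply override is just one option and may not be what the suggestion above had in mind:

```python
import torch
import torch.nn as nn


class CachedFp8Buffer(nn.Module):
    # Toy module (hypothetical, not the PR's Float8Linear) showing one way to
    # keep a cache buffer pinned to float8_e4m3fn even after module.to(dtype).

    def __init__(self, out_features: int, in_features: int):
        super().__init__()
        # registered as a buffer so it follows the module across devices
        self.register_buffer(
            "cached_w_fp8",
            torch.zeros(out_features, in_features, dtype=torch.float8_e4m3fn),
        )

    def _apply(self, fn):
        # Module.to(dtype) / .half() / .bfloat16() route through _apply and
        # would also cast this buffer; re-pin its dtype afterwards so only
        # the device can change, not the dtype.
        super()._apply(fn)
        if self.cached_w_fp8.dtype != torch.float8_e4m3fn:
            self.cached_w_fp8 = self.cached_w_fp8.to(torch.float8_e4m3fn)
        return self
```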
# LICENSE file in the root directory of this source tree.

# If True, allocates buffers for float8 weight cache
allocate_float8_weight_cache_buffers = False
I think we need some more comments on how users are expected to use this in code, e.g.:

sync_float_amax()
if accumulate_grad:
    weight_cache_enabled = True
if not accumulate_grad:
    optimizer.step()
    weight_cache_enabled = False
added to readme
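For reference, a hedged sketch of that usage pattern. The config flag names come from this PR's diff; the import path of the config module is an assumption, and sync_float_amax, model, optimizer, and micro_batches are caller-supplied placeholders:

```python
# assumed import path for the config flags shown in the diff above
import float8_experimental.config as config

# NOTE: config.allocate_float8_weight_cache_buffers must be set to True before
# the module swap / Float8Linear construction so the cache buffers exist.


def train_step_with_weight_cache(model, optimizer, micro_batches, sync_float_amax):
    """One optimizer step with gradient accumulation over micro_batches.

    model, optimizer, and sync_float_amax are placeholders supplied by the
    caller; sync_float_amax stands in for the library's amax/scale sync call.
    """
    for micro_batch_idx, batch in enumerate(micro_batches):
        is_last_micro_batch = micro_batch_idx == len(micro_batches) - 1

        sync_float_amax()
        loss = model(batch).sum()
        loss.backward()

        if not is_last_micro_batch:
            # gradient accumulation: the weight does not change until the next
            # optimizer step, so the cached fp8 weight can be reused
            config.weight_cache_enabled = True
        else:
            optimizer.step()
            optimizer.zero_grad()
            # the weight just changed; re-cast on the next forward
            config.weight_cache_enabled = False
```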
float8_experimental/float8_linear.py
Outdated
@@ -204,8 +218,30 @@ def cast_w_to_float8(
    torch.float8_e4m3fn,
    is_amax_initialized,
)

if config.weight_cache_enabled:
    assert config.allocate_float8_weight_cache_buffers
Also, maybe a better error message: "if you are using weight caching, you need to set allocate_float8_weight_cache_buffers = True before performing the module swap or construction of the Float8Linear module".
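Something along those lines, perhaps; a sketch of the assertion with the suggested message, assuming config is the same config module referenced in the diff above:

```python
import float8_experimental.config as config  # assumed path, as above

if config.weight_cache_enabled:
    assert config.allocate_float8_weight_cache_buffers, (
        "weight caching is enabled, but the cache buffers were never allocated; "
        "set config.allocate_float8_weight_cache_buffers = True before the "
        "module swap / construction of the Float8Linear module"
    )
```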
Left some comments, mostly around documentation but seems good
f506964 to ff61a58 (Compare)
@vkuzo has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
Summary:

In the cases where the optimizer update does not happen after every forward, such as microbatching/PP, we can save the casted weight to trade some memory for time.

For now I'm just testing out performance + accuracy. We can improve on the API in future PRs.

In terms of accuracy this should be no change; I will validate this further if we want to land this.

For performance, on @drisspg's LLaMa 7B pretrain script, with bsz == 128 and micro_bsz == 1:
1. baseline bf16 + compile: 2.38 it/s
2. delayed scaling + compile: 2.80 it/s (1.18x over baseline)
3. delayed scaling + compile + this PR: 3.04 it/s (1.28x over baseline)

Test Plan:
```
pytest test/test_base.py -s -k test_weight_caching
```
ff61a58 to 96f7ccc (Compare)
@@ -122,7 +131,7 @@ def __tensor_flatten__(self):
return ["_data", "_scale"], ctx

@staticmethod
- def __tensor_unflatten__(inner_tensors: Dict, metadata):
+ def __tensor_unflatten__(inner_tensors: Dict, metadata, outer_size, outer_stride):
not related to this PR, but fixing it to adhere to the changes in https://github.com/pytorch/pytorch/pull/114311/files
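For context, a minimal sketch of the traceable-subclass flatten/unflatten protocol with the post-#114311 signature, using a simplified stand-in class rather than the actual Float8Tensor (the _data/_scale field names mirror the diff above; everything else is illustrative):

```python
from typing import Dict

import torch


class SimpleFp8Tensor(torch.Tensor):
    # Simplified stand-in for illustration; not the real Float8Tensor.

    @staticmethod
    def __new__(cls, data: torch.Tensor, scale: torch.Tensor, orig_dtype: torch.dtype):
        # wrapper subclass: the outer tensor advertises the original dtype
        return torch.Tensor._make_wrapper_subclass(
            cls, data.shape, dtype=orig_dtype, device=data.device
        )

    def __init__(self, data: torch.Tensor, scale: torch.Tensor, orig_dtype: torch.dtype):
        self._data = data
        self._scale = scale
        self._orig_dtype = orig_dtype

    def __tensor_flatten__(self):
        # names of inner tensor attributes, plus non-tensor metadata
        ctx = {"_orig_dtype": self._orig_dtype}
        return ["_data", "_scale"], ctx

    @staticmethod
    def __tensor_unflatten__(inner_tensors: Dict, metadata, outer_size, outer_stride):
        # post-#114311 signature: outer_size/outer_stride carry the (possibly
        # symbolic) shape/stride of the outer wrapper during tracing
        return SimpleFp8Tensor(
            inner_tensors["_data"], inner_tensors["_scale"], metadata["_orig_dtype"]
        )

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # op handling is elided; a real implementation would dispatch on func
        raise NotImplementedError(f"{func} not supported in this sketch")
```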
@vkuzo has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
Summary: Removes most of #164. This isn't useful in the short term since it doesn't compose with FSDP + compile, and memory overhead is high. We can bring it back later if needed.

Test Plan:
```
./test/test_everything.sh
```
Summary: Removes most of #164. This isn't useful in the short term since it doesn't compose with FSDP + compile, and memory overhead is high. We can bring it back later if needed.

Pull Request resolved: #181

Test Plan:
```
./test/test_everything.sh
```

Reviewed By: drisspg
Differential Revision: D52648603
Pulled By: vkuzo
fbshipit-source-id: f956337264fd28fa0bc50d151c316cde7c3d28de
Summary:
In the cases where the optimizer update does not happen after every forward, such as microbatching/PP, we can save the casted weight to trade some memory for time.
For now I'm just testing out performance + accuracy. We can improve on the API in future PRs. The current code is torch.compile-friendly, which is nice.
In terms of accuracy this should be no change; I will validate this further if we want to land this.
For performance, on @drisspg's LLaMa 7B pretrain script, with bsz == 128 and micro_bsz == 1:
1. baseline bf16 + compile: 2.38 it/s
2. delayed scaling + compile: 2.80 it/s (1.18x over baseline)
3. delayed scaling + compile + this PR: 3.04 it/s (1.28x over baseline)
Test Plan:
```
pytest test/test_base.py -s -k test_weight_caching
```