TE: Fix redundant compute for PEFT using transform #2138
base: main
Conversation
This is a nice fix; however, we cannot rely on the assumption that requires_grad is always propagated throughout the trace, and I think we should move away from that assumption unless we make sure that propagation is always guaranteed.
A more involved alternative would be to pick up on the runtime proxy idea.
```python
with torch.device("cuda"):
    model = torch.nn.Sequential(*(torch.nn.Linear(32, 32, bias=False) for _ in range(6)))
    x = torch.randn(32, 32, requires_grad=True)
```
Is it possible to try it with inputs and outputs of different sizes? This risks triggering caching issue #1905. Something like this would have inputs of different sizes to test whether requires_grad is propagated and set correctly:
Suggested change:
```python
with torch.device("cuda"):
    model = torch.nn.Sequential(
        torch.nn.Linear(32, 64, bias=False),
        torch.nn.Linear(64, 128, bias=False),
        torch.nn.Linear(128, 64, bias=False),
        torch.nn.Linear(64, 32, bias=False),
    )
    x = torch.randn(32, 32, requires_grad=True)
```
```python
dgrad, wgrad, bgrad = bsym.output
w_requires_grad = True if wgrad is not None else False
b_requires_grad = True if bgrad is not None else False
```
Interesting hack for requires_grad propagation, though if the symbol before the one captured by the TE executor did not propagate requires_grad, this might not work as intended.
This fix doesn't rely on lightning-thunder/thunder/executors/torch_autograd.py, lines 271 to 280 (at 1ef2b94).
If the gradient is not returned from the backward trace, then we just update both the forward and backward traces so that we don't save the FP8 copy for backward and wgrad is not computed, respectively.
As far as I can tell, the RuntimeProxy idea (#1599) will just prevent us from fetching requires_grad from an intermediate TensorProxy. However, it won't fix the problem of correctly propagating it (#1768). I could be wrong though; cc @IvanYashchuk, as the author of #1599, to clarify.
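To make that idea concrete, here is a minimal sketch of such a backward-trace pass. It assumes only the `(dgrad, wgrad, bgrad)` output convention from the snippet quoted above and Thunder's `bound_symbols` / `BoundSymbol` attributes; the symbol-name check and the returned mapping are illustrative assumptions, not the actual code in this PR.

```python
# Illustrative sketch only -- not the PR's implementation.
# Walk the backward trace and record, per TE linear-backward symbol, whether
# wgrad/bgrad are actually produced. A None output means the gradient is dead,
# so the forward trace does not need to save the FP8 weight copy for backward.

def collect_te_grad_flags(backward_trace):
    flags = {}
    for bsym in backward_trace.bound_symbols:
        # The symbol name is an assumption; match however the TE backward symbol is named.
        if "te" not in bsym.sym.name or "backward" not in bsym.sym.name:
            continue
        dgrad, wgrad, bgrad = bsym.output  # convention from the quoted snippet
        flags[bsym] = {
            "w_requires_grad": wgrad is not None,
            "b_requires_grad": bgrad is not None,
        }
    return flags
```

A forward-trace update would then consume these flags to skip stashing the FP8 weight copy, while the backward trace drops the dead wgrad/bgrad computation.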
Fixes: #2076
Use a pass based on the `backward` trace to determine if `wgrad` and `bgrad` are computed, and update the `forward` trace accordingly. Added a test for the same (and verified the existing single-GPU and distributed tests). Tested with pjnl-20250505 (didn't check with the latest due to #2060).
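For reference, a PEFT-style repro of the scenario this pass targets might look like the sketch below. The TE executor import path and the exact `thunder.jit` arguments are assumptions based on current lightning-thunder conventions, not something verified against this PR.

```python
# Hedged repro sketch: frozen base weights (PEFT-style) with an input that
# requires grad, compiled with Thunder and the TransformerEngine executor.
import torch
import thunder
# Import path is an assumption; adjust to wherever the TE executor is exposed.
from thunder.executors.transformer_engineex import transformer_engine_ex

with torch.device("cuda"):
    model = torch.nn.Sequential(*(torch.nn.Linear(32, 32, bias=False) for _ in range(6)))
    x = torch.randn(32, 32, requires_grad=True)

for p in model.parameters():
    p.requires_grad_(False)  # freeze base weights; only the input needs gradients

jmodel = thunder.jit(model, executors=[transformer_engine_ex])
out = jmodel(x)
out.sum().backward()

# With the fix, the forward trace should no longer save FP8 weight copies for
# backward, and the backward trace should not compute wgrad for frozen weights.
```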
NOTE:
Example Program:
Forward Trace
Backward Trace