Use joint trace in transform_for_execution #2102
Conversation
In thunder/tests/distributed/test_ddp.py it's okay to skip the grad bucketing test for Thunder's DDP. I only get two failures with TransformerEngine (the same error as in #2060):
FAILED thunder/tests/distributed/test_ddp.py::test_ddp_transformer_engine_torch_cuda_thunder.dtypes.float32 - ValueError: not enough values to unpack (expected 5, got 4)
FAILED thunder/tests/distributed/test_ddp.py::test_ddp_transformer_engine_llama_sanity_torch_cuda_thunder.dtypes.float32 - ValueError: not enough values to unpack (expected 5, got 4)
Similarly, in thunder/tests/distributed/test_fsdp.py it's okay to skip the failing tests for Thunder's FSDP bucketing and no_sync.
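For reference, a minimal sketch of how such a test can be marked as an expected failure; the test name and reason string below are hypothetical, not the ones used in this PR:

import pytest

# Hypothetical example: mark a DDP grad-bucketing test as an expected failure
# while bucketing is unsupported on the joint-trace path.
@pytest.mark.xfail(
    reason="grad bucketing is not yet supported with delay_trace_split",
    strict=False,  # do not fail the suite if the test unexpectedly passes
)
def test_ddp_grad_bucketing():
    ...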
- assert len(backward_execution_trace.bound_symbols) == 14
+ assert len(backward_execution_trace.bound_symbols) == 17
The new bound symbols here are for clearing the C1 collection from the saved_for_backward input:
- C0, _, = saved_for_backward
- # C0: "Collection"
- # None
+ C0, C1, = saved_for_backward
+ # C0: "Collection"
+ # C1: "Collection"
...
+ # C1 (empty sequence)
+ clear_mutable_collection(C1)
+ del C1
How is the C1 output from saved_for_backward unpacking removed on the main branch? Is it a DCE pass somewhere?
This assert change can be reverted with
diff --git a/thunder/transforms/autodiff.py b/thunder/transforms/autodiff.py
index 131e666a..700607e4 100644
--- a/thunder/transforms/autodiff.py
+++ b/thunder/transforms/autodiff.py
@@ -3,6 +3,7 @@ from thunder.core.transforms import ForwardBackwardTraces
from thunder.core import prims, utils
from thunder.core.transforms import (
+ dce,
is_constant_for_vjp,
_get_gradfn_and_executor,
augmented_forward_impls,
@@ -529,6 +530,7 @@ def split_into_forward_and_backward(joint_trace):
with thunder.core.trace.tracectx(backward_trace):
prims.python_return(tuple(return_bsym.args[0]["grad_flat_args"]))
+ backward_trace = dce(backward_trace)
return forward_trace, backward_trace
Added in a27f390.
If all these tests are being xfailed, do you want the compile option flag delay_trace_split to continue defaulting to True?
Pull Request Overview
This pull request updates several components of the autodiff and distributed transforms, along with the executor implementations and corresponding tests, to support the new behavior in CI. Key changes include:
- Introducing a new utility function (_group_get_grad_bsyms) in the autodiff transform.
- Adjusting test expectations and xfail markers for distributed traces.
- Refining conditionals in passes and executors to appropriately handle get_grad operations.
Reviewed Changes
Copilot reviewed 11 out of 11 changed files in this pull request and generated 1 comment.
File | Description |
---|---|
thunder/transforms/autodiff.py | Added _group_get_grad_bsyms and updated gradient grouping logic. |
thunder/tests/test_examine_memory.py | Updated test expectations for memory estimates. |
thunder/tests/distributed/test_fsdp.py | Updated unshard parameter names and corrected trace index usage. |
thunder/tests/distributed/test_ddp.py | Added xfail marker for grad bucketing test. |
thunder/executors/torchex.py | Minor whitespace addition before shallow_copy registration. |
thunder/executors/torch_compile.py | Excluded GET_GRAD from implementation mapping in executor. |
thunder/executors/torch_autograd.py | Early return if bw_trace is None added. |
thunder/executors/passes.py | Extended condition to pass through GET_GRAD symbols. |
thunder/core/transform_common.py | Skipping further processing for GET_GRAD symbols now. |
thunder/core/rematerialization.py | Enhanced filtering of parameter names during rematerialization. |
thunder/__init__.py | Revised trace-split logic under the delay_trace_split branch. |
computation_traces.extend(extraces)
computation_trc = computation_traces[-1]
computation_trc = thunder.executors.passes.del_last_used(computation_trc)
computation_trc = extraces[-1]
Note that the joint traces are not being appended here, and that is also delayed until after the splitting. I'm unsure whether this is desired behavior or not.
thunder/transforms/autodiff.py
# group get_grad symbols together for torch compile fusions
# !!! is it preferable to do this here or in the torch compile fusion pass?
_group_get_grad_bsyms(trace)
I was running into trouble with the get_grad bsyms breaking fusion regions for the torch compile executor. Grouping all the get_grad bsyms together, with all the backward bsyms following, allows the same fusion regions as with the separate traces. Is this the "Right Way" to do this?
What does the trace look like when it breaks the fusions? Can this be because the get_grad output appears not to have any dependency?
The trace looks like:
<get_grad>
<fusible_op>
<fusible_op>
<get_grad>
<fusible_op>
<fusible_op>
<cat_op>
When the torch compile fusion pass encounters a cat op, it makes a fusion region of all the fusible ops around it. get_grad isn't (and shouldn't be) fusible. So, to make a maximal fusion region, I rearrange the get_grad ops so that the trace looks like:
<get_grad>
<get_grad>
<fusible_op>
<fusible_op>
<fusible_op>
<fusible_op>
<cat_op>
I don't believe that get_grad output dependencies are at play here.
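For illustration, a rough sketch of one way such a reordering could work; the actual _group_get_grad_bsyms in this PR may differ, and the bound-symbol attribute names used here (flat_proxy_outs, flat_proxy_args) are assumptions:

from thunder.core import prims

def group_get_grad_bsyms(trace):
    # Separate get_grad bound symbols from everything else.
    bsyms = list(trace.bound_symbols)
    get_grads = [b for b in bsyms if b.sym.id == prims.PrimIDs.GET_GRAD]
    rest = [b for b in bsyms if b.sym.id != prims.PrimIDs.GET_GRAD]

    # Map each proxy name to the index of the bound symbol in `rest` that produces it.
    produced_by = {}
    for i, b in enumerate(rest):
        for out in b.flat_proxy_outs:
            produced_by[out.name] = i

    # Attach each get_grad to the producer of its argument. Emitting each
    # producer followed by its attached get_grads keeps data dependencies
    # intact while making the get_grads adjacent, so the backward ops after
    # them form one contiguous fusible region.
    attached = {}
    for g in get_grads:
        attached.setdefault(produced_by.get(g.flat_proxy_args[0].name, -1), []).append(g)

    new_bsyms = list(attached.get(-1, []))  # get_grads of trace inputs, if any
    for i, b in enumerate(rest):
        new_bsyms.append(b)
        new_bsyms.extend(attached.get(i, []))
    trace.bound_symbols = new_bsyms
    return trace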
@IvanYashchuk, about skipping bucketing with DDP and FSDP: we are actually counting on that working properly for our distributed work. Do you think this can be tackled on your end?
Thanks for the clarification @IvanYashchuk. I agree we can forgo bucketing for now and circle back to it at a later stage.
if bsym.sym.id == prims.PrimIDs.GET_GRAD:
    return new_bsym
Interesting addition here, what is it needed for? Might #2143 solve this?
Consider a function that looks like x -> sin(x), sin(x). In the fusion passes there is a CSE step, which causes the two get_grad bsyms to be merged into one, and a trace that looks like
import thunder
import thunder.core.prims as prims
import torch
from thunder.executors.torchex import no_autocast
@torch.no_grad()
@no_autocast
def computation(a):
# a: "cuda:0 f64[2, 3]"
[t2, t6] = nvFusion0(a)
# t2 = prims.sin(a) # t2: "cuda:0 f64[2, 3]"
# t6 = prims.sin(a) # t6: "cuda:0 f64[2, 3]"
t13 = prims.get_grad(t6) # t13: "cuda:0 f64[2, 3]"
nvFusion1(a, t13)
# t8 = prims.cos(a) # t8: "cuda:0 f64[2, 3]"
# t9 = prims.mul(t13, t8) # t9: "cuda:0 f64[2, 3]"
t14 = prims.get_grad(t2) # t14: "cuda:0 f64[2, 3]"
[t10] = nvFusion2(a, t14, t13)
# t4 = prims.cos(a) # t4: "cuda:0 f64[2, 3]"
# t8 = prims.cos(a) # t8: "cuda:0 f64[2, 3]"
# t5 = prims.mul(t14, t4) # t5: "cuda:0 f64[2, 3]"
# t9 = prims.mul(t13, t8) # t9: "cuda:0 f64[2, 3]"
# t10 = prims.add(t9, t5) # t10: "cuda:0 f64[2, 3]"
return {'output': (t2, t6), 'flat_args': [a], 'grad_flat_args': [t10]}
turns into
import thunder
import thunder.core.prims as prims
import torch
from thunder.executors.torchex import no_autocast
@torch.no_grad()
@no_autocast
def computation(a):
# a: "cuda:0 f64[2, 3]"
[t2] = nvFusion0(a)
# t2 = prims.sin(a) # t2: "cuda:0 f64[2, 3]"
t15 = prims.get_grad(t2) # t15: "cuda:0 f64[2, 3]"
[t10] = nvFusion2(a, t15)
# t4 = prims.cos(a) # t4: "cuda:0 f64[2, 3]"
# t5 = prims.mul(t15, t4) # t5: "cuda:0 f64[2, 3]"
# t10 = prims.add(t5, t5) # t10: "cuda:0 f64[2, 3]"
return {'output': (t2, t2), 'flat_args': [a], 'grad_flat_args': [t10]}
after the fusion's CSE pass. This is incorrect, as the backward should have two distinct cotangents, t14 and t13, not just the single t15. Unfortunately, your suggested changes don't help with this.
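A small, self-contained PyTorch illustration (not code from this PR) of why the backward needs two distinct cotangents in this example:

import torch

# x -> sin(x), sin(x): the two outputs receive independent cotangents, so the
# backward cannot reuse a single grad tensor for both.
x = torch.randn(2, 3, dtype=torch.float64, requires_grad=True)
y1, y2 = torch.sin(x), torch.sin(x)
g1, g2 = torch.ones_like(y1), 2 * torch.ones_like(y2)  # distinct cotangents
torch.autograd.backward((y1, y2), (g1, g2))
# Chain rule: x.grad = g1 * cos(x) + g2 * cos(x), which is wrong if g1 and g2
# are collapsed into a single value by CSE.
assert torch.allclose(x.grad, (g1 + g2) * torch.cos(x))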
Ohh so it looks like this before dce:
nvFusion1(a, t13)
  # t8 = prims.cos(a) # t8: "cuda:0 f64[2, 3]"
  # t9 = prims.mul(t13, t8) # t9: "cuda:0 f64[2, 3]"
So some DCE has already happened earlier. The fact that it has no output suggests to me that it suffers from #2132 or a similar issue. Hmm, interesting.
Oh, the lack of output comes from the rematerialization that happened just before:
if self._use_rematerialization:
I've benchmarked this change with the following baseline and result:
Does it work with thunder.jit? Could we also benchmark with "thunder" as the compiler?
This PR aims to use a joint forward-backward trace in transform_for_execution while jitting, instead of separately processing a forward trace and a backward trace. This change is behind the compile option flag delay_trace_split, which currently defaults to True. Provided no performance or memory issues appear, this will allow a follow-up PR to remove the flag and delete ~300 lines from torch_autograd.py and ~300 lines from rematerialization.py, along with the relevant tests.
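A hypothetical usage sketch; passing the option as a keyword argument to thunder.jit is an assumption here, not something taken from this PR:

import torch
import thunder

def f(x):
    return torch.sin(x) * x

# Joint-trace path (the behavior this PR puts behind the flag).
jf = thunder.jit(f, delay_trace_split=True)
# Legacy path: forward and backward traces are processed separately.
jf_legacy = thunder.jit(f, delay_trace_split=False)

x = torch.randn(2, 3, requires_grad=True)
jf(x).sum().backward()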