
Use joint trace in transform_for_execution #2102


Open · wants to merge 22 commits into main

Conversation

@beverlylytle (Collaborator) commented May 20, 2025

This PR aims to use a joint forward-backward trace in transform_for_execution while jitting, instead of separately processing a forward trace and a backward trace. This change is behind the compile option flag delay_trace_split, which currently defaults to True. Provided no performance or memory issues appear, this will allow for a follow-up PR which can remove the flag and delete ~300 lines from torch_autograd.py and ~300 lines from rematerialization.py along with the relevant tests.

@IvanYashchuk (Collaborator) left a comment

In thunder/tests/distributed/test_ddp.py it's okay to skip the grad-bucketing test for Thunder's DDP. I only get two failures with TransformerEngine (the same error as in #2060):

FAILED thunder/tests/distributed/test_ddp.py::test_ddp_transformer_engine_torch_cuda_thunder.dtypes.float32 - ValueError: not enough values to unpack (expected 5, got 4)
FAILED thunder/tests/distributed/test_ddp.py::test_ddp_transformer_engine_llama_sanity_torch_cuda_thunder.dtypes.float32 - ValueError: not enough values to unpack (expected 5, got 4)

Similarly, in thunder/tests/distributed/test_fsdp.py it's okay to skip failing tests for Thunder's FSDP bucketing and no_sync.

assert len(backward_execution_trace.bound_symbols) == 14
assert len(backward_execution_trace.bound_symbols) == 17
Collaborator

New bound symbols here are for clearing the C1 collection from the saved_for_backward input:

-  C0, _, = saved_for_backward
-  # C0: "Collection"
-  # None
+  C0, C1, = saved_for_backward
+  # C0: "Collection"
+  # C1: "Collection"
...
+  # C1 (empty sequence)
+  clear_mutable_collection(C1)
+  del C1

How is the C1 output from saved_for_backward unpacking removed on the main branch? Is it a DCE pass somewhere?
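
For context, the kind of DCE pass being asked about can be sketched as follows. This is a minimal toy model, and the `(name, outputs, inputs)` op representation is illustrative, not Thunder's actual bound-symbol API: walk the trace backwards, keep an op only if some output of it is still needed, and anonymize unused outputs the way `C0, _, = saved_for_backward` does on main.

```python
# Toy DCE sketch (illustrative op representation, not Thunder's API).

def dce(ops, returned):
    """ops: list of (name, outputs, inputs) triples; returned: names the trace returns."""
    needed = set(returned)
    kept = []
    for name, outputs, inputs in reversed(ops):
        if any(o in needed for o in outputs):
            needed.update(inputs)
            # Outputs nothing downstream consumes become `_`.
            pruned = [o if o in needed or o in returned else "_" for o in outputs]
            kept.append((name, pruned, inputs))
    kept.reverse()
    return kept

trace = [
    ("unpack", ["C0", "C1"], ["saved_for_backward"]),
    ("grad", ["t0"], ["C0"]),
]
print(dce(trace, ["t0"]))
# -> [('unpack', ['C0', '_'], ['saved_for_backward']), ('grad', ['t0'], ['C0'])]
```

Here C1 is never consumed, so the unpack op survives (C0 is needed) but C1 is anonymized rather than kept as a live name.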

Collaborator

This assert change can be reverted with

diff --git a/thunder/transforms/autodiff.py b/thunder/transforms/autodiff.py
index 131e666a..700607e4 100644
--- a/thunder/transforms/autodiff.py
+++ b/thunder/transforms/autodiff.py
@@ -3,6 +3,7 @@ from thunder.core.transforms import ForwardBackwardTraces
 
 from thunder.core import prims, utils
 from thunder.core.transforms import (
+    dce,
     is_constant_for_vjp,
     _get_gradfn_and_executor,
     augmented_forward_impls,
@@ -529,6 +530,7 @@ def split_into_forward_and_backward(joint_trace):
     with thunder.core.trace.tracectx(backward_trace):
         prims.python_return(tuple(return_bsym.args[0]["grad_flat_args"]))
 
+    backward_trace = dce(backward_trace)
     return forward_trace, backward_trace

Collaborator

Added in a27f390.

Collaborator Author

If all these tests are being xfailed, do you still want the compile option flag delay_trace_split to default to True?

@IvanYashchuk IvanYashchuk requested a review from Copilot May 27, 2025 12:30
@Copilot Copilot AI (Contributor) left a comment

Pull Request Overview

This pull request updates several components of the autodiff and distributed transforms, along with the executor implementations and corresponding tests, to support the new joint-trace behavior. Key changes include:

  • Introducing a new utility function (_group_get_grad_bsyms) in the autodiff transform.
  • Adjusting test expectations and xfail markers for distributed traces.
  • Refining conditionals in passes and executors to appropriately handle get_grad operations.

Reviewed Changes

Copilot reviewed 11 out of 11 changed files in this pull request and generated 1 comment.

Summary per file:

  • thunder/transforms/autodiff.py: Added _group_get_grad_bsyms and updated gradient grouping logic.
  • thunder/tests/test_examine_memory.py: Updated test expectations for memory estimates.
  • thunder/tests/distributed/test_fsdp.py: Updated unshard parameter names and corrected trace index usage.
  • thunder/tests/distributed/test_ddp.py: Added xfail marker for grad bucketing test.
  • thunder/executors/torchex.py: Minor whitespace addition before shallow_copy registration.
  • thunder/executors/torch_compile.py: Excluded GET_GRAD from implementation mapping in executor.
  • thunder/executors/torch_autograd.py: Added early return if bw_trace is None.
  • thunder/executors/passes.py: Extended condition to pass through GET_GRAD symbols.
  • thunder/core/transform_common.py: Skipping further processing for GET_GRAD symbols.
  • thunder/core/rematerialization.py: Enhanced filtering of parameter names during rematerialization.
  • thunder/__init__.py: Revised trace-split logic under the delay_trace_split branch.

computation_traces.extend(extraces)
computation_trc = computation_traces[-1]
computation_trc = thunder.executors.passes.del_last_used(computation_trc)
computation_trc = extraces[-1]
Collaborator Author

Note that the joint traces are not being appended here, and that is also delayed until after the splitting. I'm unsure whether this is desired behavior or not.

Comment on lines 343 to 345
# group get_grad symbols together for torch compile fusions
# !!! is it preferable to do this here or in the torch compile fusion pass?
_group_get_grad_bsyms(trace)
Collaborator Author

I was running into trouble with the get_grad bsyms breaking fusion regions for the torch compile executor. Grouping all the get_grad bsyms together, with all the backward bsyms following, allowed the same fusion regions as with the separate traces. Is this the "Right Way" to do this?

Collaborator

What does the trace look like when it breaks the fusions? Could this be because the get_grad output appears not to have any dependency?

Collaborator Author

The trace looks like:

<get_grad>
<fusible_op>
<fusible_op>
<get_grad>
<fusible_op>
<fusible_op>
<cat_op>

When the torch compile fusion pass encounters a cat op, it makes a fusion region of all the fusible ops around it. get_grad isn't (and shouldn't be) fusible. So, to make a maximal fusion region, I rearrange the get_grad ops so that the trace looks like:

<get_grad>
<get_grad>
<fusible_op>
<fusible_op>
<fusible_op>
<fusible_op>
<cat_op>

I don't believe that get_grad output dependencies are at play here.
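
The reordering described above can be sketched as a stable partition over the backward bsyms. This toy model uses string op names and is not the actual `_group_get_grad_bsyms` implementation; the hoist is safe in this setting because get_grad only reads forward outputs, so moving it earlier cannot break a data dependency.

```python
# Toy sketch of the reordering: hoist every get_grad bsym in front of the
# fusible ops, preserving relative order, so the fusion pass sees one
# unbroken fusible region ending at the cat op.

def group_get_grads(bsyms):
    get_grads = [b for b in bsyms if b == "get_grad"]
    rest = [b for b in bsyms if b != "get_grad"]
    return get_grads + rest

trace = ["get_grad", "fusible", "fusible", "get_grad", "fusible", "fusible", "cat"]
print(group_get_grads(trace))
# -> ['get_grad', 'get_grad', 'fusible', 'fusible', 'fusible', 'fusible', 'cat']
```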

@lantiga (Collaborator) commented May 27, 2025

@IvanYashchuk about skipping bucketing with ddp and fsdp, we are actually counting on that to work properly for our distributed work. Do you think this can be tackled on your end?

@lantiga (Collaborator) commented May 27, 2025

Thanks for the clarification @IvanYashchuk. I agree we can forego bucketing for now and eventually circle back to it at a later stage.

@beverlylytle beverlylytle changed the title [WIP2] Use joint trace in transform_for_execution May 28, 2025
@beverlylytle beverlylytle marked this pull request as ready for review May 28, 2025 07:40
@beverlylytle beverlylytle mentioned this pull request May 28, 2025
4 tasks
Comment on lines +289 to +291
if bsym.sym.id == prims.PrimIDs.GET_GRAD:
return new_bsym

Collaborator

Interesting addition here, what is it needed for? Might #2143 solve this?

Collaborator Author

Consider a function that looks like x -> sin(x), sin(x). In the fusion passes, there is a CSE pass which causes the two get_grad bsyms to be smashed into one, so a trace that looks like

import thunder
import thunder.core.prims as prims
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation(a):
  # a: "cuda:0 f64[2, 3]"
  [t2, t6] = nvFusion0(a)
    # t2 = prims.sin(a)  # t2: "cuda:0 f64[2, 3]"
    # t6 = prims.sin(a)  # t6: "cuda:0 f64[2, 3]"
  t13 = prims.get_grad(t6)  # t13: "cuda:0 f64[2, 3]"
  nvFusion1(a, t13)
    # t8 = prims.cos(a)  # t8: "cuda:0 f64[2, 3]"
    # t9 = prims.mul(t13, t8)  # t9: "cuda:0 f64[2, 3]"
  t14 = prims.get_grad(t2)  # t14: "cuda:0 f64[2, 3]"
  [t10] = nvFusion2(a, t14, t13)
    # t4 = prims.cos(a)  # t4: "cuda:0 f64[2, 3]"
    # t8 = prims.cos(a)  # t8: "cuda:0 f64[2, 3]"
    # t5 = prims.mul(t14, t4)  # t5: "cuda:0 f64[2, 3]"
    # t9 = prims.mul(t13, t8)  # t9: "cuda:0 f64[2, 3]"
    # t10 = prims.add(t9, t5)  # t10: "cuda:0 f64[2, 3]"
  return {'output': (t2, t6), 'flat_args': [a], 'grad_flat_args': [t10]}

turns into

import thunder
import thunder.core.prims as prims
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation(a):
  # a: "cuda:0 f64[2, 3]"
  [t2] = nvFusion0(a)
    # t2 = prims.sin(a)  # t2: "cuda:0 f64[2, 3]"
  t15 = prims.get_grad(t2)  # t15: "cuda:0 f64[2, 3]"
  [t10] = nvFusion2(a, t15)
    # t4 = prims.cos(a)  # t4: "cuda:0 f64[2, 3]"
    # t5 = prims.mul(t15, t4)  # t5: "cuda:0 f64[2, 3]"
    # t10 = prims.add(t5, t5)  # t10: "cuda:0 f64[2, 3]"
  return {'output': (t2, t2), 'flat_args': [a], 'grad_flat_args': [t10]}

after the fusion's CSE pass. This is incorrect, as the backward should have two distinct cotangents, t14 and t13, not just the one t15. Unfortunately your suggested changes don't help this.
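
The guarded behavior described here amounts to treating get_grad as impure during CSE. A minimal value-numbering sketch of that guard follows; the `(op, args, result)` representation is illustrative, not Thunder's bound-symbol API.

```python
# Toy CSE that deduplicates pure ops but never merges get_grad, so the
# two cotangents stay distinct even when their arguments become equal
# after earlier merging.

IMPURE = {"get_grad"}

def cse(ops):
    seen = {}    # (op, canonical args) -> earlier result name
    rename = {}  # merged result name -> canonical result name
    out = []
    for op, args, result in ops:
        args = tuple(rename.get(a, a) for a in args)
        if op not in IMPURE:
            key = (op, args)
            if key in seen:
                rename[result] = seen[key]  # reuse the earlier identical op
                continue
            seen[key] = result
        out.append((op, args, result))
    return out

trace = [
    ("sin", ("a",), "t2"),
    ("sin", ("a",), "t6"),         # pure duplicate: merged into t2
    ("get_grad", ("t2",), "t14"),
    ("get_grad", ("t6",), "t13"),  # same canonical arg, but must stay distinct
]
print(cse(trace))
# -> [('sin', ('a',), 't2'), ('get_grad', ('t2',), 't14'), ('get_grad', ('t2',), 't13')]
```

Without the `IMPURE` guard, the second get_grad would be merged away, reproducing the incorrect single-cotangent trace shown above.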

Collaborator

Ohh so it looks like this before dce:

  nvFusion1(a, t13)
    # t8 = prims.cos(a)  # t8: "cuda:0 f64[2, 3]"
    # t9 = prims.mul(t13, t8)  # t9: "cuda:0 f64[2, 3]"

So some DCE has already happened earlier. The fact that it has no output suggests to me that it suffers from #2132 or a similar issue. Hmm, interesting.

@beverlylytle (Collaborator Author) commented May 28, 2025

Oh, the lack of output comes from the rematerialization that happened just before:

if self._use_rematerialization:


@beverlylytle (Collaborator Author) commented Jun 3, 2025

I've benchmarked this change with the following baseline and result:

@main
Model name: Llama-3-8B
Seq Length: 8192
Micro BS: 1
Global BS: 8
Number of Layers: 32
Number of parameters: 1.00B
Distributed Mode: fsdp
Sharding Mode: zero2
Bucketing: block
Compiler: dynamo_thunder
Low Precision Mode: none
Average iter time: 782.73 ms
Memory used: 72.61 GB
Tokens/s: 83707.48
Tokens/s/GPU: 10463.43
TFLOP/s: 4847.74

@reautograd2
Model name: Llama-3-8B
Seq Length: 8192
Micro BS: 1
Global BS: 8
Number of Layers: 32
Number of parameters: 1.00B
Distributed Mode: fsdp
Sharding Mode: zero2
Bucketing: block
Compiler: dynamo_thunder
Low Precision Mode: none
Average iter time: 781.49 ms
Memory used: 72.61 GB
Tokens/s: 83855.93
Tokens/s/GPU: 10481.99
TFLOP/s: 4856.34

@t-vi (Collaborator) commented Jun 3, 2025

Does it work with thunder.jit? Could we also benchmark with "thunder" as the compiler?

@beverlylytle (Collaborator Author) commented Jun 3, 2025

@main
Model name: Llama-3-8B
Seq Length: 8192
Micro BS: 1
Global BS: 8
Number of Layers: 32
Number of parameters: 1.00B
Distributed Mode: fsdp
Sharding Mode: zero2
Bucketing: none
Compiler: thunder
Low Precision Mode: none
Average iter time: 799.59 ms
Memory used: 75.75 GB
Saved for backward size: 58448.60 MiB
Saved for backward number of tensors: 775
Tokens/s: 81988.63
Tokens/s/GPU: 10248.58
TFLOP/s: 4748.20

@retrograd2
Model name: Llama-3-8B
Seq Length: 8192
Micro BS: 1
Global BS: 8
Number of Layers: 32
Number of parameters: 1.00B
Distributed Mode: fsdp
Sharding Mode: zero2
Bucketing: none
Compiler: thunder
Low Precision Mode: none
Average iter time: 808.52 ms
Memory used: 71.53 GB
Saved for backward size: 62988.09 MiB
Saved for backward number of tensors: 712
Tokens/s: 81032.54
Tokens/s/GPU: 10129.07
TFLOP/s: 4692.83
