
TE: Fix redundant compute for PEFT using transform #2138


Draft
@kshitij12345 wants to merge 9 commits into main
Conversation

@kshitij12345 (Collaborator) commented May 26, 2025

Fixes: #2076

Use a pass based on the backward trace to determine whether wgrad and bgrad are computed, and update the forward trace accordingly.

Added a test for this (and verified the existing single-GPU and distributed tests). Tested with pjnl-20250505 (didn't check with the latest due to #2060).

NOTE:

Example Program:

with torch.device("cuda"):
    model = torch.nn.Sequential(*(torch.nn.Linear(32, 32, bias=False) for _ in range(4)))
    x = torch.randn(32, 32, requires_grad=True)

for idx, parameters in enumerate(model.parameters()):
    # Every even linear layer's weight is frozen.
    if idx % 2 == 0:
        parameters.requires_grad = False
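
The traces below can be reproduced roughly as follows. This is a minimal sketch: it assumes the model and input defined above, and it omits the TransformerEngine executor and FP8 recipe wiring that the actual traces were generated with.

import thunder

jmodel = thunder.jit(model)  # the TE executor / FP8 recipe used for the traces below is not shown here
out = jmodel(x)
out.sum().backward()

print(thunder.last_traces(jmodel)[-1])           # final forward (computation) trace
print(thunder.last_backward_traces(jmodel)[-1])  # final backward trace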

Forward Trace

@transformer_engine.fp8_autocast(fp8_recipe=te_fp8_recipe)
@torch.no_grad()
@no_autocast
def computation(input, t_0_weight, t_1_weight, t_2_weight, t_3_weight):
  # input: "cuda:0 f32[32, 32]"
  # t_0_weight: "cuda:0 f32[32, 32]"
  # t_1_weight: "cuda:0 f32[32, 32]"
  # t_2_weight: "cuda:0 f32[32, 32]"
  # t_3_weight: "cuda:0 f32[32, 32]"

  # /usr/local/lib/python3.12/dist-packages/torch/nn/modules/linear.py:125: 	        return F.linear(input, self.weight, self.bias)
  (t24, (t19, t20, t21, t22, t23), ctx_te_1118) = te_linear_11(input, t_0_weight, None, False, False)
  (t36, (t31, t32, t33, t34, t35), ctx_te_1230) = te_linear_12(t24, t_1_weight, None, True, False)
  del t24

  # /usr/local/lib/python3.12/dist-packages/torch/nn/modules/linear.py:125: 	        return F.linear(input, self.weight, self.bias)
  (t48, (t43, t44, t45, t46, t47), ctx_te_1342) = te_linear_13(t36, t_2_weight, None, False, False)
  del t36

  # /usr/local/lib/python3.12/dist-packages/torch/nn/modules/linear.py:125: 	        return F.linear(input, self.weight, self.bias)
  (t60, (t55, t56, t57, t58, t59), ctx_te_1454) = te_linear_14(t48, t_3_weight, None, True, False)
  del t48
  return {'output': (t60,), 'flat_args': [input, t_0_weight, t_1_weight, t_2_weight, t_3_weight], 'flat_output': (t60,)}, ((t19, t20, t21, t22, t23, t31, t32, t33, t34, t35, t43, t44, t45, t46, t47, t55, t56, t57, t58, t59), (ctx_te_1118, ctx_te_1230, ctx_te_1342, ctx_te_1454))

Backward Trace

@torch.no_grad()
@no_autocast
def backward_fn(saved_for_backward, cotangents):
  # saved_for_backward: "Collection"
  # cotangents: "Collection"
  C0, C1, = saved_for_backward
  # C0: "Collection"
  # C1: "Collection"
  clear_mutable_collection(saved_for_backward)
  del saved_for_backward
  t61, = cotangents
  # t61: "cuda:0 f32[32, 32]"
  clear_mutable_collection(cotangents)
  del cotangents
  t19, t20, t21, t22, t23, t31, t32, t33, t34, t35, t43, t44, t45, t46, t47, t55, \
  t56, t57, t58, t59, = C0
 
  clear_mutable_collection(C0)
  del C0
  ctx_te_1118, ctx_te_1230, ctx_te_1342, ctx_te_1454, = C1
  # ctx_te_1118: "<class 'object'>"
  # ctx_te_1230: "<class 'object'>"
  # ctx_te_1342: "<class 'object'>"
  # ctx_te_1454: "<class 'object'>"
  clear_mutable_collection(C1)
  del C1
  (bw_t62, grad_for_t_3_weight, _) = te_functional_linear_backward((32, 32), (32, 32), None, True, False, ctx_te_1454, (t55, t56, t57, t58, t59), t61)
  del ctx_te_1454, t55, t56, t57, t58, t59, t61
  (bw_t50, _, _) = te_functional_linear_backward((32, 32), (32, 32), None, False, False, ctx_te_1342, (t43, t44, t45, t46, t47), bw_t62)
  del ctx_te_1342, t43, t44, t45, t46, t47, bw_t62
  (bw_t38, grad_for_t_1_weight, _) = te_functional_linear_backward((32, 32), (32, 32), None, True, False, ctx_te_1230, (t31, t32, t33, t34, t35), bw_t50)
  del ctx_te_1230, t31, t32, t33, t34, t35, bw_t50
  (grad_for_input, _, _) = te_functional_linear_backward((32, 32), (32, 32), None, False, False, ctx_te_1118, (t19, t20, t21, t22, t23), bw_t38)
  del ctx_te_1118, t19, t20, t21, t22, t23, bw_t38
  te_sync_fp8_meta_bwd()
  return (grad_for_input, None, grad_for_t_1_weight, None, grad_for_t_3_weight)
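
To make the pass described above concrete, here is a minimal sketch (not the PR's implementation) of reading the computed gradients off the backward trace. It reuses the RETURN bound-symbol convention from the snippet quoted later in this thread; the PrimIDs import path and the helper name are assumptions.

from thunder.core.prims import PrimIDs

def returned_grad_mask(bw_trace):
    # The last bound symbol of the backward trace is the RETURN primitive;
    # its first argument is the flat tuple of gradients, one per flat input.
    return_bsym = bw_trace.bound_symbols[-1]
    assert return_bsym.sym.id == PrimIDs.RETURN
    flat_grads = return_bsym.args[0]
    # True where a gradient is actually produced, False where it is None
    # (e.g. for the frozen weights t_0_weight and t_2_weight above).
    return tuple(g is not None for g in flat_grads)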

@kshitij12345 marked this pull request as draft May 26, 2025 13:25
@riccardofelluga (Collaborator) left a comment

This is a nice fix; however, we cannot rely on the assumption that requires_grad is always propagated throughout the trace, and I think we should move away from that assumption unless we can make sure that propagation is always guaranteed.

A more involved alternative would be to pick up on the runtime proxy idea.

Comment on lines +268 to +270
with torch.device("cuda"):
    model = torch.nn.Sequential(*(torch.nn.Linear(32, 32, bias=False) for _ in range(6)))
    x = torch.randn(32, 32, requires_grad=True)

Is it possible to try it with inputs and outputs of different sizes? This risks triggering caching issue #1905. Something like the following would have inputs of different sizes, to test whether requires_grad is propagated and set correctly:

Suggested change
-with torch.device("cuda"):
-    model = torch.nn.Sequential(*(torch.nn.Linear(32, 32, bias=False) for _ in range(6)))
-    x = torch.randn(32, 32, requires_grad=True)
+with torch.device("cuda"):
+    model = torch.nn.Sequential(
+        torch.nn.Linear(32, 64, bias=False),
+        torch.nn.Linear(64, 128, bias=False),
+        torch.nn.Linear(128, 64, bias=False),
+        torch.nn.Linear(64, 32, bias=False),
+    )
+    x = torch.randn(32, 32, requires_grad=True)

Comment on lines +599 to +601
dgrad, wgrad, bgrad = bsym.output
w_requires_grad = True if wgrad is not None else False
b_requires_grad = True if bgrad is not None else False

Interesting hack for requires_grad propagation, though if the symbol before the one captured by the TE executor did not propagate requires_grad, this might not work as intended.

@kshitij12345 (Collaborator, Author)

This is a nice fix; however, we cannot rely on the assumption that requires_grad is always propagated throughout the trace, and I think we should move away from that assumption unless we can make sure that propagation is always guaranteed.

This fix doesn't rely on requires_grad being propagated correctly, but on whether or not the gradient is returned from the backward trace.

# Update the backward trace to only compute gradients for the
# inputs that require gradients
assert bw_trace.bound_symbols[-1].sym.id == PrimIDs.RETURN
filtered_grads = tuple(
    (arg_grad if requires_grad else None)
    for arg_grad, requires_grad in utils.safe_zip(bw_trace.bound_symbols[-1].args[0], requires_grad_mask)
)
# autograd.Function.backward expects a flat tuple of gradients
bw_trace.bound_symbols[-1] = replace(bw_trace.bound_symbols[-1], args=(filtered_grads,))

If the gradient is not returned from the backward trace, then we update both the forward and the backward trace: the forward trace so that the FP8 copy is not saved for backward, and the backward trace so that wgrad is not computed.
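
As an illustration only (not the PR's code), the forward-trace side of such an update might look roughly like the sketch below. The assumption that the fourth and fifth arguments of te_linear are the weight/bias requires-grad flags is inferred from the forward trace above; the helper name and the trace-mutation style are hypothetical.

from dataclasses import replace

def drop_frozen_wgrad_flags(fw_trace, weight_requires_grad):
    # weight_requires_grad: maps a weight proxy name (e.g. "t_0_weight") to
    # whether its gradient is actually returned by the backward trace.
    new_bsyms = []
    for bsym in fw_trace.bound_symbols:
        if bsym.sym.name.startswith("te_linear"):
            inp, weight, bias, _w_flag, b_flag = bsym.args
            w_flag = weight_requires_grad.get(weight.name, False)
            # With w_flag=False the executor does not need to stash the FP8
            # copy of the input that is only used for the wgrad GEMM.
            bsym = replace(bsym, args=(inp, weight, bias, w_flag, b_flag))
        new_bsyms.append(bsym)
    fw_trace.bound_symbols = new_bsyms
    return fw_trace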

A more involved alternative would be to pick up on the runtime proxy idea.

As far as I can tell, the RuntimeProxy idea (#1599) will just prevent us from fetching requires_grad from intermediate TensorProxys. However, it won't fix the problem of correctly propagating it (#1768). I could be wrong though; cc @IvanYashchuk as the author of #1599 to clarify.

Successfully merging this pull request may close these issues: TE: Redundant backward computation in PEFT setting.