Skip to content

Commit a6f03c1

Browse files
committed
switch to save for backward since are now a tensor input
1 parent 338d87c commit a6f03c1

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

torchao/dtypes/nf4tensor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -863,13 +863,13 @@ class LinearNF4(torch.autograd.Function):
863863
@staticmethod
864864
def forward(ctx, input: torch.Tensor, weight: NF4Tensor):
865865
"""Save the quantized nf4 weight for backward pass"""
866-
ctx.nf4_weight = weight
866+
ctx.save_for_backward(weight)
867867
return F.linear(input, weight.to(input.dtype))
868868

869869
@staticmethod
870870
def backward(ctx, grad_output):
871871
"""The nf4 weight will never require grad so we can just return the grad_output @ weight.to(grad_output.dtype)"""
872-
weight: NF4Tensor = ctx.nf4_weight
872+
weight: NF4Tensor = ctx.saved_tensors[0]
873873
return grad_output @ weight.to(grad_output.dtype), None
874874

875875

0 commit comments

Comments
 (0)