
Commit 33f77d8

andrewor14 authored and jainapurva committed
Add tutorial for trainable tensor subclass (#908)
Summary:

The new tutorial provides an example of how to implement a trainable tensor subclass that wraps quantized data. This extends the existing `MyDTypeTensor` with a few necessary steps to ensure proper gradient updates, namely:

1. Define a differentiable constructor
2. Define backward pass for ops of interest (e.g. torch.nn.functional.linear)
3. Handle special ops used by the optimizer (e.g. aten.add, aten.add_)

Test Plan:

python tutorials/developer_api_guide/my_trainable_tensor_subclass.py
1 parent 8835ccc commit 33f77d8
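For orientation before reading the diff, the sketch below illustrates the core idea behind step 1, the differentiable constructor: construction is wrapped in a torch.autograd.Function whose backward is the identity, so gradients flow back to the original float tensor. This is a minimal, self-contained sketch; the class name _ToWrapped and the clone-only forward are placeholders and not part of this commit, which instead quantizes the input and returns the tensor subclass (see my_trainable_tensor_subclass.py below).

import torch

class _ToWrapped(torch.autograd.Function):
    """Hypothetical stand-in for the commit's _ToMyTrainableDTypeTensor."""

    @staticmethod
    def forward(ctx, input_float: torch.Tensor) -> torch.Tensor:
        # The real tutorial quantizes here and returns the tensor subclass
        # constructed with requires_grad=True; this sketch only copies the data.
        return input_float.clone()

    @staticmethod
    def backward(ctx, grad_output):
        # Identity backward: the gradient flows straight through to input_float.
        return grad_output

x = torch.randn(4, requires_grad=True)
y = _ToWrapped.apply(x)
y.sum().backward()
print(x.grad)  # tensor of ones: the gradient reached the original float tensor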

File tree

3 files changed: +244 -36 lines changed

tutorials/developer_api_guide/__init__.py

Whitespace-only changes.

tutorials/developer_api_guide/my_dtype_tensor_subclass.py

Lines changed: 44 additions & 36 deletions
@@ -77,6 +77,7 @@ def __new__(
         layout_tensor: MyDTypeLayout,
         shape: torch.Size,
         dtype: Optional[torch.dtype] = None,
+        requires_grad: bool = False,
     ):
         kwargs = {}
         kwargs["device"] = layout_tensor.device
@@ -86,14 +87,15 @@ def __new__(
             else layout_tensor.layout
         )
         kwargs["dtype"] = dtype
-        kwargs["requires_grad"] = False
+        kwargs["requires_grad"] = requires_grad
         return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]

     def __init__(
         self,
         layout_tensor: MyDTypeLayout,
         shape: torch.Size,
         dtype: Optional[torch.dtype] = None,
+        requires_grad: bool = False,
     ):
         self.layout_tensor = layout_tensor

@@ -108,7 +110,7 @@ def __tensor_flatten__(self):
         The first one contains any tensor fields such as int_data and scale as keys to a dictionary
         The second one contains all other non tensor type fields as values of a list
         """
-        return ["layout_tensor"], [self.shape, self.dtype]
+        return ["layout_tensor"], [self.shape, self.dtype, self.requires_grad]

     @classmethod
     def __tensor_unflatten__(
@@ -120,11 +122,12 @@ def __tensor_unflatten__(
         tensor_attributes contains all other non tensor type fields
         """
         layout_tensor = tensor_data_dict["layout_tensor"]
-        shape, dtype = tensor_attributes
+        shape, dtype, requires_grad = tensor_attributes
         return cls(
             layout_tensor,
             shape if outer_size is None else outer_size,
             dtype=dtype,
+            requires_grad=requires_grad,
         )

     """classmethod that converts from a floating point Tensor (fp32/fp16/bf16) to the current dtype
@@ -330,37 +333,42 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 ########
 # Test #
 ########
-from torchao.utils import benchmark_model
-
-m = M()
-example_inputs = (100 * torch.randn(1024, 1024),)
-NUM_WARMUPS = 10
-NUM_RUNS = 100
-
-for _ in range(NUM_WARMUPS):
-    m(*example_inputs)
-print("before quantization:", benchmark_model(m, NUM_RUNS, example_inputs))
-
-compiled = torch.compile(m, mode="max-autotune")
-for _ in range(NUM_WARMUPS):
-    compiled(*example_inputs)
-print("after compile:", benchmark_model(compiled, NUM_RUNS, example_inputs))
-
-# convert weights to quantized weights
-m.linear.weight = torch.nn.Parameter(
-    to_my_dtype(m.linear.weight), requires_grad=False
-)

-for _ in range(NUM_WARMUPS):
-    m(*example_inputs)
-
-print("after quantization:", benchmark_model(m, NUM_RUNS, example_inputs))
-
-m = torch.compile(m, mode="max-autotune")
-
-for _ in range(NUM_WARMUPS):
-    m(*example_inputs)
-
-# NOTE: currently there is no speedup because we just dequantize the weight in the _quantized_linear op
-# we plan to add custom op example in the future and that will help us to get speedup
-print("after quantization and compile:", benchmark_model(m, NUM_RUNS, example_inputs))
+def test():
+    from torchao.utils import benchmark_model
+
+    m = M()
+    example_inputs = (100 * torch.randn(1024, 1024),)
+    NUM_WARMUPS = 10
+    NUM_RUNS = 100
+
+    for _ in range(NUM_WARMUPS):
+        m(*example_inputs)
+    print("before quantization:", benchmark_model(m, NUM_RUNS, example_inputs))
+
+    compiled = torch.compile(m, mode="max-autotune")
+    for _ in range(NUM_WARMUPS):
+        compiled(*example_inputs)
+    print("after compile:", benchmark_model(compiled, NUM_RUNS, example_inputs))
+
+    # convert weights to quantized weights
+    m.linear.weight = torch.nn.Parameter(
+        to_my_dtype(m.linear.weight), requires_grad=False
+    )
+
+    for _ in range(NUM_WARMUPS):
+        m(*example_inputs)
+
+    print("after quantization:", benchmark_model(m, NUM_RUNS, example_inputs))
+
+    m = torch.compile(m, mode="max-autotune")
+
+    for _ in range(NUM_WARMUPS):
+        m(*example_inputs)
+
+    # NOTE: currently there is no speedup because we just dequantize the weight in the _quantized_linear op
+    # we plan to add custom op example in the future and that will help us to get speedup
+    print("after quantization and compile:", benchmark_model(m, NUM_RUNS, example_inputs))
+
+if __name__ == "__main__":
+    test()
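The reason requires_grad is threaded through __tensor_flatten__ / __tensor_unflatten__ above is that torch.compile and other subsystems reconstruct wrapper subclasses via this pair, and any attribute not carried in tensor_attributes would be lost on the round trip. The sketch below is a hedged illustration of that round trip; it assumes the tutorial module is importable, that to_my_dtype is exposed at module level as in the test code above, and that outer_size/outer_stride may be passed as None, and it is not part of this commit.

import torch
from my_dtype_tensor_subclass import MyDTypeTensor, to_my_dtype

# Build a quantized tensor, then rebuild it with requires_grad=True using the
# constructor arguments added in this diff.
w = to_my_dtype(torch.randn(128, 128))
w = MyDTypeTensor(w.layout_tensor, w.shape, requires_grad=True)

# Flatten to (inner tensor names, non-tensor attributes), then unflatten again.
tensor_data_names, tensor_attributes = w.__tensor_flatten__()
inner_tensors = {name: getattr(w, name) for name in tensor_data_names}
w2 = MyDTypeTensor.__tensor_unflatten__(inner_tensors, tensor_attributes, None, None)

# requires_grad survives because it is now stored in tensor_attributes.
assert w2.requires_grad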
tutorials/developer_api_guide/my_trainable_tensor_subclass.py

Lines changed: 200 additions & 0 deletions

@@ -0,0 +1,200 @@
"""
This is an example for a tensor subclass representing a simple dtype
that can be used in training.

We extend our previous example of `MyDTypeTensor` with a few extra steps
needed to ensure proper gradient updates during training:

1. Define a differentiable constructor
2. Define backward pass for ops of interest (e.g. torch.nn.functional.linear)
3. Handle special ops used by the optimizer (e.g. aten.add, aten.add_)
"""

import torch

from torch.utils._python_dispatch import return_and_correct_aliasing
from torchao.quantization.quant_primitives import choose_qparams_affine, MappingType
from torchao.dtypes.utils import LayoutType, PlainLayoutType
from my_dtype_tensor_subclass import MyDTypeLayout, MyDTypeTensor

aten = torch.ops.aten


##############################
# Tensor Subclass Definition #
##############################

class MyTrainableDTypeTensor(MyDTypeTensor):
    """
    Example tensor subclass that extends `MyDTypeTensor` to support training.
    """

    @classmethod
    def _quantize(
        cls,
        input_float: torch.Tensor,
        layout_type: LayoutType,
    ) -> MyDTypeLayout:
        """
        Convert from a floating point tensor (fp32/fp16/bf16) to the desired dtype.
        """
        mapping_type = MappingType.SYMMETRIC
        block_size = input_float.shape
        dtype = torch.int16
        scale, _ = choose_qparams_affine(input_float, mapping_type, block_size, dtype)
        int_data = (input_float / scale).to(torch.int8)
        layout_tensor_ctr = cls.get_layout_tensor_constructor(type(layout_type))
        return layout_tensor_ctr(int_data, scale, layout_type)

    @classmethod
    def from_float(
        cls,
        input_float: torch.Tensor,
        layout_type: LayoutType = PlainLayoutType(),
    ) -> "MyTrainableDTypeTensor":
        """
        Main entry point for creating a `MyTrainableDTypeTensor`.

        This instantiates the tensor subclass in a differentiable constructor
        to ensure gradients are passed to the tensor subclass properly during training.
        """
        return _ToMyTrainableDTypeTensor.apply(input_float, layout_type)

class _ToMyTrainableDTypeTensor(torch.autograd.Function):
    """
    Differentiable constructor for `MyTrainableDTypeTensor`.
    """

    @staticmethod
    def forward(
        ctx: torch.autograd.function.FunctionCtx,
        input_float: torch.Tensor,
        layout_type: LayoutType,
    ) -> "MyTrainableDTypeTensor":
        layout_tensor = MyTrainableDTypeTensor._quantize(input_float, layout_type)
        return MyTrainableDTypeTensor(
            layout_tensor,
            input_float.shape,
            requires_grad=True,
        )

    @staticmethod
    def backward(ctx, gy):
        return gy, None

to_my_trainable_dtype = MyTrainableDTypeTensor.from_float


#####################################################
# torch functional and aten operator implementation #
#####################################################

implements = MyTrainableDTypeTensor.implements

class _QuantizedLinearOp(torch.autograd.Function):
    """
    Forward and backward definition for linear with quantized weights.
    Weights are dequantized during both the forward and the backward passes.
    """

    @staticmethod
    def forward(
        ctx: torch.autograd.function.FunctionCtx,
        input_tensor: torch.Tensor,
        weight_tensor: torch.Tensor,
    ) -> torch.Tensor:
        assert isinstance(weight_tensor, MyTrainableDTypeTensor)
        ctx.save_for_backward(input_tensor, weight_tensor)
        weight_tensor = weight_tensor.dequantize()
        return torch.nn.functional.linear(input_tensor, weight_tensor)

    @staticmethod
    def backward(ctx, grad_output):
        input_tensor, weight_tensor = ctx.saved_tensors
        grad_input = torch.matmul(grad_output, weight_tensor.dequantize())
        grad_weight = torch.matmul(
            grad_output.view(-1, weight_tensor.shape[0]).T,
            input_tensor.view(-1, weight_tensor.shape[1]),
        )
        return grad_input, grad_weight

@implements(torch.nn.functional.linear)
def _(func, types, args, kwargs):
    """
    Handle the linear op with quantized weights.
    For simplicity, we run both the forward and backward passes entirely in float.
    """
    assert isinstance(args[1], MyTrainableDTypeTensor)
    if len(args) > 2 and args[2] is not None:
        raise NotImplementedError("linear bias not yet supported")
    return _QuantizedLinearOp.apply(args[0], args[1])

@implements(aten.add_.Tensor)
def _(func, types, args, kwargs):
    """
    Handle the in-place add op, called by the optimizer to update
    the quantized weight during training.
    """
    assert len(args) == 2
    assert isinstance(args[0], MyTrainableDTypeTensor)
    assert args[0].layout_tensor.int_data.dtype == torch.int8
    float0 = args[0].dequantize()
    float1 = args[1].dequantize() if isinstance(args[1], MyTrainableDTypeTensor) else args[1]
    new_value = torch.add(float0, float1, **kwargs)
    new_layout_tensor = MyTrainableDTypeTensor._quantize(
        new_value,
        args[0].layout_tensor.get_layout_type(),
    )
    args[0].layout_tensor = new_layout_tensor
    return return_and_correct_aliasing(func, args, kwargs, args[0])

@implements(aten.add.Tensor)
def _(func, types, args, kwargs):
    """Handle the add op, called by the optimizer during training."""
    assert len(args) == 2
    assert not isinstance(args[0], MyTrainableDTypeTensor)
    assert isinstance(args[1], MyTrainableDTypeTensor)
    out = torch.add(args[0], args[1].dequantize(), **kwargs)
    return return_and_correct_aliasing(func, args, kwargs, out)


########
# Test #
########

class M(torch.nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.linear = torch.nn.Linear(512, 1024, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)

def main():
    m = M().cuda()
    NUM_TRAIN_STEPS = 10
    VERBOSE = True

    # Convert weights to quantized weights
    m.linear.weight = torch.nn.Parameter(
        to_my_trainable_dtype(m.linear.weight), requires_grad=True,
    )

    # Dummy training loop
    optimizer = torch.optim.SGD(m.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-5)
    loss_fn = torch.nn.CrossEntropyLoss()
    for i in range(NUM_TRAIN_STEPS):
        example_inputs = (torch.randn(512).cuda(),)
        target = torch.randn(1024).cuda()
        output = m(*example_inputs)
        loss = loss_fn(output, target)
        loss.backward()
        if VERBOSE:
            weight = m.linear.weight.layout_tensor.int_data.flatten()[:3]
            weight_grad = m.linear.weight.grad.flatten()[:3]
            print(" * step %s: weight grad = %s, weight value = %s" % (i, weight_grad, weight))
        optimizer.step()
        optimizer.zero_grad()

if __name__ == "__main__":
    main()
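For readers skimming the new file, the key behavior in the aten.add_.Tensor override is the cycle the optimizer triggers on every step: dequantize the int8 weight, apply the float update, and re-quantize the result back into the layout tensor. The following is a self-contained sketch of that cycle using a hypothetical helper (sgd_update_quantized) and plain per-tensor symmetric quantization; it simplifies the commit's choose_qparams_affine call and is not code from this commit.

import torch

def sgd_update_quantized(int_data: torch.Tensor, scale: torch.Tensor,
                         grad: torch.Tensor, lr: float = 0.1):
    """Dequantize -> float SGD step -> symmetric per-tensor re-quantization."""
    w_float = int_data.to(torch.float32) * scale                    # dequantize
    w_float = w_float - lr * grad                                   # optimizer update in float
    new_scale = w_float.abs().max() / 127.0                         # re-derive the scale
    new_int_data = torch.clamp((w_float / new_scale).round(), -128, 127).to(torch.int8)
    return new_int_data, new_scale

int_data = torch.randint(-128, 127, (4, 4), dtype=torch.int8)
scale = torch.tensor(0.05)
grad = torch.randn(4, 4)
new_int_data, new_scale = sgd_update_quantized(int_data, scale, grad)
print(new_int_data.dtype, float(new_scale))  # torch.int8, updated scale

Keeping the update in float and re-quantizing afterwards is what lets a standard optimizer such as torch.optim.SGD drive a weight whose storage stays int8, as the tutorial's training loop demonstrates.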
