|
| 1 | +""" |
| 2 | +This is an example for a tensor subclass representing a simple dtype |
| 3 | +that can be used in training. |
| 4 | +
|
| 5 | +We extend our previous example of `MyDTypeTensor` with a few extra steps |
| 6 | +needed to ensure proper gradient updates during training: |
| 7 | +
|
| 8 | + 1. Define a differentiable constructor |
| 9 | + 2. Define backward pass for ops of interest (e.g. torch.nn.functional.linear) |
| 10 | + 3. Handle special ops used by the optimizer (e.g. aten.add, aten.add_) |
| 11 | +""" |
| 12 | + |
| 13 | +import torch |
| 14 | + |
| 15 | +from torch.utils._python_dispatch import return_and_correct_aliasing |
| 16 | +from torchao.quantization.quant_primitives import choose_qparams_affine, MappingType |
| 17 | +from torchao.dtypes.utils import LayoutType, PlainLayoutType |
| 18 | +from my_dtype_tensor_subclass import MyDTypeLayout, MyDTypeTensor |
| 19 | + |
| 20 | +aten = torch.ops.aten |
| 21 | + |
| 22 | + |
| 23 | +############################## |
| 24 | +# Tensor Subclass Definition # |
| 25 | +############################## |
| 26 | + |
| 27 | +class MyTrainableDTypeTensor(MyDTypeTensor): |
| 28 | + """ |
| 29 | + Example tensor subclass that extends `MyDTypeTensor` to support training. |
| 30 | + """ |
| 31 | + |
| 32 | + @classmethod |
| 33 | + def _quantize( |
| 34 | + cls, |
| 35 | + input_float: torch.Tensor, |
| 36 | + layout_type: LayoutType, |
| 37 | + ) -> MyDTypeLayout: |
| 38 | + """ |
| 39 | + Convert from a floating point tensor (fp32/fp16/bf16) to the desired dtype. |
| 40 | + """ |
| 41 | + mapping_type = MappingType.SYMMETRIC |
| 42 | + block_size = input_float.shape |
| 43 | + dtype = torch.int16 |
| 44 | + scale, _ = choose_qparams_affine(input_float, mapping_type, block_size, dtype) |
| 45 | + int_data = (input_float / scale).to(torch.int8) |
| 46 | + layout_tensor_ctr = cls.get_layout_tensor_constructor(type(layout_type)) |
| 47 | + return layout_tensor_ctr(int_data, scale, layout_type) |
| 48 | + |
| 49 | + @classmethod |
| 50 | + def from_float( |
| 51 | + cls, |
| 52 | + input_float: torch.Tensor, |
| 53 | + layout_type: LayoutType = PlainLayoutType(), |
| 54 | + ) -> "MyTrainableDTypeTensor": |
| 55 | + """ |
| 56 | + Main entry point for creating a `MyTrainableDTypeTensor`. |
| 57 | +
|
| 58 | + This instantiates the tensor subclass in a differentiable constructor |
| 59 | + to ensure gradients are passed to the tensor subclass properly during training. |
| 60 | + """ |
| 61 | + return _ToMyTrainableDTypeTensor.apply(input_float, layout_type) |
| 62 | + |
| 63 | +class _ToMyTrainableDTypeTensor(torch.autograd.Function): |
| 64 | + """ |
| 65 | + Differentiable constructor for `MyTrainableDTypeTensor`. |
| 66 | + """ |
| 67 | + |
| 68 | + @staticmethod |
| 69 | + def forward( |
| 70 | + ctx: torch.autograd.function.FunctionCtx, |
| 71 | + input_float: torch.Tensor, |
| 72 | + layout_type: LayoutType, |
| 73 | + ) -> "MyTrainableDTypeTensor": |
| 74 | + layout_tensor = MyTrainableDTypeTensor._quantize(input_float, layout_type) |
| 75 | + return MyTrainableDTypeTensor( |
| 76 | + layout_tensor, |
| 77 | + input_float.shape, |
| 78 | + requires_grad=True, |
| 79 | + ) |
| 80 | + |
| 81 | + @staticmethod |
| 82 | + def backward(ctx, gy): |
| 83 | + return gy, None |
| 84 | + |
| 85 | +to_my_trainable_dtype = MyTrainableDTypeTensor.from_float |
| 86 | + |
| 87 | + |
| 88 | +##################################################### |
| 89 | +# torch functional and aten operator implementation # |
| 90 | +##################################################### |
| 91 | + |
| 92 | +implements = MyTrainableDTypeTensor.implements |
| 93 | + |
| 94 | +class _QuantizedLinearOp(torch.autograd.Function): |
| 95 | + """ |
| 96 | + Forward and backward definition for linear with quantized weights. |
| 97 | + Weights are dequantized during both the forward and the backward passes. |
| 98 | + """ |
| 99 | + |
| 100 | + @staticmethod |
| 101 | + def forward( |
| 102 | + ctx: torch.autograd.function.FunctionCtx, |
| 103 | + input_tensor: torch.Tensor, |
| 104 | + weight_tensor: torch.Tensor, |
| 105 | + ) -> torch.Tensor: |
| 106 | + assert isinstance(weight_tensor, MyTrainableDTypeTensor) |
| 107 | + ctx.save_for_backward(input_tensor, weight_tensor) |
| 108 | + weight_tensor = weight_tensor.dequantize() |
| 109 | + return torch.nn.functional.linear(input_tensor, weight_tensor) |
| 110 | + |
| 111 | + @staticmethod |
| 112 | + def backward(ctx, grad_output): |
| 113 | + input_tensor, weight_tensor = ctx.saved_tensors |
| 114 | + grad_input = torch.matmul(grad_output, weight_tensor.dequantize()) |
| 115 | + grad_weight = torch.matmul( |
| 116 | + grad_output.view(-1, weight_tensor.shape[0]).T, |
| 117 | + input_tensor.view(-1, weight_tensor.shape[1]), |
| 118 | + ) |
| 119 | + return grad_input, grad_weight |
| 120 | + |
| 121 | +@implements(torch.nn.functional.linear) |
| 122 | +def _(func, types, args, kwargs): |
| 123 | + """ |
| 124 | + Handle the linear op with quantized weights. |
| 125 | + For simplicity, we run both the forward and backward passes entirely in float. |
| 126 | + """ |
| 127 | + assert isinstance(args[1], MyTrainableDTypeTensor) |
| 128 | + if len(args) > 2 and args[2] is not None: |
| 129 | + raise NotImplementedError("linear bias not yet supported") |
| 130 | + return _QuantizedLinearOp.apply(args[0], args[1]) |
| 131 | + |
| 132 | +@implements(aten.add_.Tensor) |
| 133 | +def _(func, types, args, kwargs): |
| 134 | + """ |
| 135 | + Handle the in-place add op, called by the optimizer to update |
| 136 | + the quantized weight during training. |
| 137 | + """ |
| 138 | + assert len(args) == 2 |
| 139 | + assert isinstance(args[0], MyTrainableDTypeTensor) |
| 140 | + assert args[0].layout_tensor.int_data.dtype == torch.int8 |
| 141 | + float0 = args[0].dequantize() |
| 142 | + float1 = args[1].dequantize() if isinstance(args[1], MyTrainableDTypeTensor) else args[1] |
| 143 | + new_value = torch.add(float0, float1, **kwargs) |
| 144 | + new_layout_tensor = MyTrainableDTypeTensor._quantize( |
| 145 | + new_value, |
| 146 | + args[0].layout_tensor.get_layout_type(), |
| 147 | + ) |
| 148 | + args[0].layout_tensor = new_layout_tensor |
| 149 | + return return_and_correct_aliasing(func, args, kwargs, args[0]) |
| 150 | + |
| 151 | +@implements(aten.add.Tensor) |
| 152 | +def _(func, types, args, kwargs): |
| 153 | + """Handle the add op, called by the optimizer during training.""" |
| 154 | + assert len(args) == 2 |
| 155 | + assert not isinstance(args[0], MyTrainableDTypeTensor) |
| 156 | + assert isinstance(args[1], MyTrainableDTypeTensor) |
| 157 | + out = torch.add(args[0], args[1].dequantize(), **kwargs) |
| 158 | + return return_and_correct_aliasing(func, args, kwargs, out) |
| 159 | + |
| 160 | + |
| 161 | +######## |
| 162 | +# Test # |
| 163 | +######## |
| 164 | + |
| 165 | +class M(torch.nn.Module): |
| 166 | + def __init__(self, *args, **kwargs) -> None: |
| 167 | + super().__init__(*args, **kwargs) |
| 168 | + self.linear = torch.nn.Linear(512, 1024, bias=False) |
| 169 | + |
| 170 | + def forward(self, x: torch.Tensor) -> torch.Tensor: |
| 171 | + return self.linear(x) |
| 172 | + |
| 173 | +def main(): |
| 174 | + m = M().cuda() |
| 175 | + NUM_TRAIN_STEPS = 10 |
| 176 | + VERBOSE = True |
| 177 | + |
| 178 | + # Convert weights to quantized weights |
| 179 | + m.linear.weight = torch.nn.Parameter( |
| 180 | + to_my_trainable_dtype(m.linear.weight), requires_grad=True, |
| 181 | + ) |
| 182 | + |
| 183 | + # Dummy training loop |
| 184 | + optimizer = torch.optim.SGD(m.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-5) |
| 185 | + loss_fn = torch.nn.CrossEntropyLoss() |
| 186 | + for i in range(NUM_TRAIN_STEPS): |
| 187 | + example_inputs = (torch.randn(512).cuda(),) |
| 188 | + target = torch.randn(1024).cuda() |
| 189 | + output = m(*example_inputs) |
| 190 | + loss = loss_fn(output, target) |
| 191 | + loss.backward() |
| 192 | + if VERBOSE: |
| 193 | + weight = m.linear.weight.layout_tensor.int_data.flatten()[:3] |
| 194 | + weight_grad = m.linear.weight.grad.flatten()[:3] |
| 195 | + print(" * step %s: weight grad = %s, weight value = %s" % (i, weight_grad, weight)) |
| 196 | + optimizer.step() |
| 197 | + optimizer.zero_grad() |
| 198 | + |
| 199 | +if __name__ == "__main__": |
| 200 | + main() |
0 commit comments