Description
This is mostly to keep track of a problem that has been around for a while:
if you ever do something like 1) quantize a model with int4 on CPU, 2) move it to CUDA,
then the output of the model will be nonsense.
e.g. if in https://github.com/pytorch/ao/blob/main/torchao/_models/llama/generate.py#L231
you did
quantize_(model.cpu(), int4_weight_only(group_size=groupsize))
model.cuda()
the output of the model is nonsensical:
Hello, my name is♠ zewnętrz zewnętrz@{ zewnętrz zewnętrz zewnętrz))] ord zewnętrzŻ zewnętrz zewnętrz zewnętrz zewnętrzŻ zewnętrz zewnętrz Хронологи
because .cuda() simply moves the same packed weight from CPU to CUDA without addressing the fact that the packed format is numerically different for each backend.
Despite the different packing paths, there is no metadata recording which backend's packing algorithm was actually used, so we can't even error out intelligently.
We could manually keep track of this in AffineQuantizedTensor and add code to unpack and repack when someone calls .to(device), but that alone doesn't fully solve the issue because, again, we can't detect which packing was used. Users can do things like serialize the model on CUDA and reload it on CPU, and then we're in the same situation: when they call .cuda() we would want to unpack->repack, but we would use the CPU unpacking, which won't work since the original packing was done on CUDA. You'd have to further add a field tracking which device the packed weight was most recently packed on, and when someone calls .to(device), check that original device; if it differs from the device the tensor currently sits on, first move it back there before doing the unpack->repack. We should either implement such a solution or identify whether this is going to be rectified in some other way.
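To make that bookkeeping concrete, here is a minimal self-contained sketch, not a proposal for the actual torchao internals: pack_int4/unpack_int4 are hypothetical toy stand-ins for the real backend-specific packing kernels (they just flip nibble order to mimic the layout divergence), and Int4PackedWeight is a hypothetical stand-in for the packed-weight tensor impl inside AffineQuantizedTensor.

import torch

# Hypothetical toy stand-ins for the backend-specific int4 packing kernels.
# The real kernels pack into tiled layouts; here we just flip nibble order
# between "cpu" and "cuda" to mimic the fact that the layouts differ.
def pack_int4(unpacked, backend):
    # assumes uint8 values in [0, 16) and an even last dimension
    lo, hi = unpacked[..., 0::2], unpacked[..., 1::2]
    return lo | (hi << 4) if backend == "cpu" else hi | (lo << 4)

def unpack_int4(packed, backend):
    a, b = packed & 0xF, packed >> 4
    lo, hi = (a, b) if backend == "cpu" else (b, a)
    return torch.stack([lo, hi], dim=-1).flatten(-2)

class Int4PackedWeight:
    """Toy container for a packed int4 weight plus packing metadata."""

    def __init__(self, packed, packed_backend):
        self.packed = packed
        # Records which backend's layout `packed` is currently in; this is
        # the extra field described above.
        self.packed_backend = packed_backend

    def to(self, device):
        target = torch.device(device).type
        if target == self.packed_backend:
            # Same layout, a plain device move is enough (this covers the
            # serialize-on-cuda / load-on-cpu / move-back-to-cuda case).
            return Int4PackedWeight(self.packed.to(device), target)
        # Layouts differ: unpack using the layout that produced the bytes,
        # then repack for the target backend.  (With the real kernels you
        # would also have to move the packed tensor back to its packing
        # device first, since the unpack op only exists on that backend.)
        unpacked = unpack_int4(self.packed, self.packed_backend)
        return Int4PackedWeight(pack_int4(unpacked, target).to(device), target)

# round-trip check
w = torch.randint(0, 16, (4, 8), dtype=torch.uint8)
pw = Int4PackedWeight(pack_int4(w, "cpu"), "cpu")
if torch.cuda.is_available():
    pw = pw.to("cuda")
assert torch.equal(unpack_int4(pw.packed.cpu(), pw.packed_backend), w)

The same packed_backend field is also what would let .to(device) at least raise a clear error instead of silently producing garbage, if we decided not to implement the repack path.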
small repro:
import torch
import torchao
from torchao.quantization import quantize_, int4_weight_only
import copy
model = torch.nn.Linear(1024, 1024, dtype=torch.bfloat16)
input = torch.randn(1024, 1024, device="cuda", dtype=torch.bfloat16)
model_q_cpu = copy.deepcopy(model)
model_q_cuda = copy.deepcopy(model.cuda())
quantize_(model_q_cpu, int4_weight_only())   # packed with the CPU layout
quantize_(model_q_cuda, int4_weight_only())  # packed with the CUDA layout
out = model_q_cpu.to("cuda")(input)  # AQT doesn't let you run the int4 model on CPU, so move it to CUDA first
out2 = model_q_cuda(input)
print(out - out2)  # large mismatch: the CPU-packed weight is reinterpreted with the CUDA layout