Description
Seems like `torch.compile` doesn't like using views on dtypes. This causes the `PYTORCH_COMPILE` backend and `model = torch.compile(model)` to break when `view_as_float` is set to `True`:
```
BackendCompilerFailed: backend='inductor' raised:
LoweringException: NotImplementedError: bitcast torch.float16 to different bitwidth type torch.uint8 is not supported yet.
```
Wrapping the view with `torch.jit.ignore` doesn't work in this case.
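For reference, the dtype-view bitcast alone seems to be enough to trigger this, even without hqq (a minimal standalone sketch, not part of the original repro):

```python
import torch

# Minimal sketch: inductor cannot lower a bitcast view between dtypes of
# different bitwidths, e.g. float16 (16-bit) -> uint8 (8-bit).
@torch.compile()
def view_bitcast(x):
    return x.view(torch.uint8)

x = torch.randn(8, 8, dtype=torch.float16)
y = view_bitcast(x)  # raises BackendCompilerFailed / LoweringException as above
```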
Minimal code to reproduce the issue:
```python
import torch
from hqq.core.quantize import *

HQQLinear.set_backend(HQQBackend.ATEN_BACKPROP)

#######################################################################################
batch_size    = 1
context_size  = 512
compute_dtype = torch.float16

linear_layer = torch.nn.Linear(4096, 4096)
quant_config = BaseQuantizeConfig(nbits=4, group_size=64, quant_scale=False, quant_zero=False, offload_meta=False, view_as_float=True)
hqq_linear   = HQQLinear(linear_layer, quant_config, compute_dtype=compute_dtype, del_orig=False)

@torch.jit.ignore
def dequantize_Wq_aten(W_q, meta):
    if meta['view_as_float']:
        W_q = W_q.view(meta['unpack_view_dtype'])
    return hqq_aten.dequantize(W_q, meta['scale'], meta['zero'], meta['shape'], meta['group_size'] if (meta['group_size']) else -1, meta['nbits'], meta['axis'], meta['packing'])

@torch.compile()
def dequantize(hqq_layer):
    return dequantize_Wq_aten(hqq_layer.W_q, hqq_layer.meta)

######################################################################################
# This works: W_q viewed as the integer unpack dtype
hqq_linear.W_q.data = hqq_linear.W_q.data.view(hqq_linear.meta['unpack_view_dtype'])
W_r = dequantize(hqq_linear)

# This breaks: W_q viewed as float16
hqq_linear.W_q.data = hqq_linear.W_q.data.view(compute_dtype)
W_r = dequantize(hqq_linear)
```
A workaround would be to move the view call outside `dequantize`, but this makes the code more complicated and requires an extra call to revert back to float bitpacking.
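For illustration, a rough sketch of that workaround (hypothetical helper name, reusing `dequantize` and `compute_dtype` from the snippet above):

```python
def dequantize_with_float_view(hqq_layer):
    # Bitcast back to the integer packing dtype eagerly, outside the compiled region
    if hqq_layer.meta['view_as_float']:
        hqq_layer.W_q.data = hqq_layer.W_q.data.view(hqq_layer.meta['unpack_view_dtype'])

    W_r = dequantize(hqq_layer)  # the compiled path only ever sees an integer-typed W_q

    # Extra call needed to restore the float bitpacking view afterwards
    if hqq_layer.meta['view_as_float']:
        hqq_layer.W_q.data = hqq_layer.W_q.data.view(compute_dtype)
    return W_r
```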
This is mainly a PyTorch bug, so I created the issue there as well: pytorch/pytorch#120998
@KeremTurgutlu fyi