
torch.compile fails with view_as_float=True #18

Closed
@mobicham

Description


It seems torch.compile doesn't support bitcast views between dtypes of different bitwidths. This breaks both the PYTORCH_COMPILE backend and model = torch.compile(model) when view_as_float is set to True:

BackendCompilerFailed: backend='inductor' raised:
LoweringException: NotImplementedError: bitcast torch.float16 to different bitwidth type torch.uint8 is not supported yet.

Wrapping the view with torch.jit.ignore doesn't work in this case.
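For reference, the bitcast view itself is legal in eager mode; it is only the inductor lowering that is missing. A minimal sketch with plain torch (shapes here are illustrative):

```python
import torch

# Eager-mode bitcast views work as long as the byte count per row is
# preserved: two uint8 elements reinterpret as one float16, and back.
packed = torch.randint(0, 256, (4, 2), dtype=torch.uint8)
as_float = packed.view(torch.float16)   # 2 bytes -> 1 float16 per pair
assert as_float.shape == (4, 1)
roundtrip = as_float.view(torch.uint8)  # bitcast back, bytes unchanged
assert torch.equal(roundtrip, packed)
```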
Minimal code to reproduce the issue:

```python
import torch
from hqq.core.quantize import *
import hqq_aten  # compiled ATen extension shipped with hqq

HQQLinear.set_backend(HQQBackend.ATEN_BACKPROP)

#######################################################################################
batch_size    = 1
context_size  = 512
compute_dtype = torch.float16
linear_layer  = torch.nn.Linear(4096, 4096)

quant_config  = BaseQuantizeConfig(nbits=4, group_size=64, quant_scale=False, quant_zero=False, offload_meta=False, view_as_float=True)
hqq_linear    = HQQLinear(linear_layer, quant_config, compute_dtype=compute_dtype, del_orig=False)

@torch.jit.ignore
def dequantize_Wq_aten(W_q, meta):
    if meta['view_as_float']:
        W_q = W_q.view(meta['unpack_view_dtype'])  # bitcast float16 -> uint8
    return hqq_aten.dequantize(W_q, meta['scale'], meta['zero'], meta['shape'],
                               meta['group_size'] if meta['group_size'] else -1,
                               meta['nbits'], meta['axis'], meta['packing'])

@torch.compile()
def dequantize(hqq_layer):
    return dequantize_Wq_aten(hqq_layer.W_q, hqq_layer.meta)

######################################################################################

# This works: W_q already holds the integer (uint8) view
hqq_linear.W_q.data = hqq_linear.W_q.data.view(hqq_linear.meta['unpack_view_dtype'])
W_r = dequantize(hqq_linear)

# This breaks: the float16 view forces a bitcast inside the compiled region
hqq_linear.W_q.data = hqq_linear.W_q.data.view(compute_dtype)
W_r = dequantize(hqq_linear)
```

A workaround would be to move the view call outside dequantize, but that complicates the code and requires an extra call afterwards to restore the float bitpacking.
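That workaround could look roughly like this. This is only a sketch: `compiled_dequant_body` is a hypothetical stand-in for the compiled dequantize body (the real code would call hqq_aten.dequantize there), not hqq's actual API:

```python
import torch

def compiled_dequant_body(W_q_int, meta):
    # Hypothetical placeholder for the compiled dequantize kernel; it only
    # ever sees the integer view, so no bitcast is lowered by inductor.
    return W_q_int.to(torch.float16)

def dequantize_with_outer_view(W_q_float, meta):
    # Bitcast back to the integer packing *outside* the compiled region.
    W_q_int = W_q_float.view(meta['unpack_view_dtype'])
    W_r = compiled_dequant_body(W_q_int, meta)
    # The extra call needed to restore the float bitpacking afterwards.
    return W_r, W_q_int.view(torch.float16)

meta = {'unpack_view_dtype': torch.uint8}
W_q = torch.randint(0, 256, (4, 4), dtype=torch.uint8)
W_r, W_q_float = dequantize_with_outer_view(W_q.view(torch.float16), meta)
assert torch.equal(W_q_float.view(torch.uint8), W_q)  # packing restored intact
```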

This is mainly a PyTorch bug, so I filed an issue there as well: pytorch/pytorch#120998

@KeremTurgutlu fyi
