It seems that `torch.compile` doesn't like dtype views (`Tensor.view(dtype)`). This causes the `PYTORCH_COMPILE` backend and `model = torch.compile(model)` to break when `view_as_float` is set to `True`:
```
BackendCompilerFailed: backend='inductor' raised:
LoweringException: NotImplementedError: bitcast torch.float16 to different bitwidth type torch.uint8 is not supported yet.
```
Wrapping the view with `torch.jit.ignore` doesn't work in this case.
Minimal code to reproduce the issue:
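The original snippet wasn't preserved here; the following is a minimal sketch of the failing pattern (names and shapes are illustrative): a float16 tensor bitcast to uint8 inside a compiled function, mirroring what `dequantize` does when `view_as_float=True`. On affected PyTorch versions this should raise the lowering error above:

```python
import torch

def dequantize(w_f16: torch.Tensor) -> torch.Tensor:
    # float16 -> uint8 bitcast: different bitwidths (16-bit -> 8-bit)
    w_q = w_f16.view(torch.uint8)
    # Placeholder for the real unpacking/dequantization arithmetic
    return w_q.to(torch.float16)

w_f16 = torch.randn(128, 128, dtype=torch.float16)

# Eager mode works fine; the compiled version fails at lowering time:
torch.compile(dequantize, backend="inductor")(w_f16)
# -> LoweringException: NotImplementedError: bitcast torch.float16 to
#    different bitwidth type torch.uint8 is not supported yet.
```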
A workaround would be to move the view call outside `dequantize`, but this would make the code more complicated and require another call to revert back to float bitpacking.
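A sketch of one way to hoist the bitcast out of the compiled region (the split into `dequantize_core` is hypothetical, not the actual code): the view runs eagerly, so Inductor never sees the cross-bitwidth bitcast, at the cost of an extra eager-mode call per invocation and a second view later to restore the float bitpacking.

```python
import torch

@torch.compile(backend="inductor")
def dequantize_core(w_q: torch.Tensor) -> torch.Tensor:
    # Only the arithmetic on the already-bitcast uint8 data is compiled
    return w_q.to(torch.float16)

def dequantize(w_f16: torch.Tensor) -> torch.Tensor:
    w_q = w_f16.view(torch.uint8)  # eager-mode view, never lowered by Inductor
    return dequantize_core(w_q)
```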
This is mainly a PyTorch bug, so I created an issue there as well: pytorch/pytorch#120998
Note: it looks like this only happens between types that don't have the same bitwidth (uint8 <-> float16). If you force a float32 view with 3-bit quantization, which uses int32 bitpacking, it works fine.
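For contrast, a small sketch of the same-bitwidth case (illustrative, not from the original report), which compiles without error since both dtypes are 32-bit:

```python
import torch

def view_same_width(w_f32: torch.Tensor) -> torch.Tensor:
    # float32 -> int32 bitcast: same bitwidth, lowers fine under Inductor
    return w_f32.view(torch.int32)

w_f32 = torch.randn(128, 128, dtype=torch.float32)
torch.compile(view_same_width, backend="inductor")(w_f32)  # no error
```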
@KeremTurgutlu fyi