Closed
Description
FP8 Linear does not work for me:
- torch == 2.4.0 + cu121
- torchao == 0.4.0
- cuda_arch == 8.9 (nvidia L40)
import torch
import torch.nn as nn
from torchao.float8 import convert_to_float8_training
class FFN(nn.Module):
def __init__(self, in_feature, hidden_feature, bias=True):
super().__init__()
self.fc1 = nn.Linear(in_feature, hidden_feature, bias)
self.fc2 = nn.Linear(hidden_feature, in_feature, bias)
self.gelu = nn.GELU()
def forward(self, x):
x = self.fc1(x)
x = self.gelu(x)
x = self.fc2(x)
return x
bs, seq, dim = 32, 512, 1024
m = FFN(dim, dim * 4).cuda()
convert_to_float8_training(m)
# m = torch.compile(m)
x = torch.randn((bs, seq, dim), device="cuda")
with torch.inference_mode(mode=True):
y = m(x)
/usr/local/lib/python3.10/dist-packages/torchao/ops.py:12: FutureWarning: `torch.library.impl_abstract` was renamed to `torch.library.register_fake`. Please use that instead; we will remove `torch.library.impl_abstract` in a future version of PyTorch.
return torch.library.impl_abstract(f"{name}")(func)
Traceback (most recent call last):
File "/root/erdos/ops/triton/t.py", line 28, in <module>
y = m(x)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "/root/erdos/ops/triton/t.py", line 14, in forward
x = self.fc1(x)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torchao/float8/float8_linear.py", line 363, in forward
output = manual_float8_matmul.apply(input_fp8, weight_fp8.t())
File "/usr/local/lib/python3.10/dist-packages/torch/autograd/function.py", line 574, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
File "/usr/local/lib/python3.10/dist-packages/torchao/float8/float8_linear.py", line 59, in forward
input_fp8_reshaped = input_fp8.reshape(-1, orig_shape[-1])
File "/usr/local/lib/python3.10/dist-packages/torchao/float8/float8_tensor.py", line 360, in __torch_dispatch__
raise NotImplementedError(f"attempting to run {func}, this is not supported")
NotImplementedError: attempting to run aten.reshape.default, this is not supported