Skip to content

[BUG] Float8Linear does not work with torch.inference_mode #643

Closed
@leeeizhang

Description

@leeeizhang

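Float8Linear does not work for me: calling a model converted with `convert_to_float8_training` inside `torch.inference_mode` raises `NotImplementedError` (script and full traceback below):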

  • torch == 2.4.0 + cu121
  • torchao == 0.4.0
  • cuda_arch == 8.9 (nvidia L40)
import torch
import torch.nn as nn
from torchao.float8 import convert_to_float8_training

class FFN(nn.Module):
    """Minimal feed-forward block: Linear -> GELU -> Linear.

    The first linear layer maps ``in_feature`` to ``hidden_feature`` and the
    second maps back to ``in_feature``, so the output's last dimension equals
    the input's.

    Args:
        in_feature: Size of the last dimension of the input and the output.
        hidden_feature: Size of the intermediate representation.
        bias: Forwarded to both linear layers as their ``bias`` argument.
    """

    def __init__(self, in_feature, hidden_feature, bias=True):
        super().__init__()
        # Submodules are created in the order fc1, fc2, gelu; the attribute
        # names determine the parameter names in the state dict.
        self.fc1 = nn.Linear(in_feature, hidden_feature, bias=bias)
        self.fc2 = nn.Linear(hidden_feature, in_feature, bias=bias)
        self.gelu = nn.GELU()

    def forward(self, x):
        """Apply fc1, the GELU activation, then fc2 to ``x``."""
        x = self.fc1(x)
        return self.fc2(self.gelu(x))


# Sizes used to build the input tensor below: (bs, seq, dim).
bs, seq, dim = 32, 512, 1024

# FFN whose hidden width is 4x the input width, moved to the GPU.
m = FFN(dim, dim * 4).cuda()
# The return value is discarded and `m` is used afterwards, so this call is
# relied on to modify `m` in place.
convert_to_float8_training(m)
# torch.compile is left disabled, so the model is called without compilation.
# m = torch.compile(m)

x = torch.randn((bs, seq, dim), device="cuda")

# Forward pass under inference mode; per the traceback that follows, this
# call is where NotImplementedError is raised.
with torch.inference_mode(mode=True):
    y = m(x)
/usr/local/lib/python3.10/dist-packages/torchao/ops.py:12: FutureWarning: `torch.library.impl_abstract` was renamed to `torch.library.register_fake`. Please use that instead; we will remove `torch.library.impl_abstract` in a future version of PyTorch.
  return torch.library.impl_abstract(f"{name}")(func)
Traceback (most recent call last):
  File "/root/erdos/ops/triton/t.py", line 28, in <module>
    y = m(x)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/erdos/ops/triton/t.py", line 14, in forward
    x = self.fc1(x)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torchao/float8/float8_linear.py", line 363, in forward
    output = manual_float8_matmul.apply(input_fp8, weight_fp8.t())
  File "/usr/local/lib/python3.10/dist-packages/torch/autograd/function.py", line 574, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/usr/local/lib/python3.10/dist-packages/torchao/float8/float8_linear.py", line 59, in forward
    input_fp8_reshaped = input_fp8.reshape(-1, orig_shape[-1])
  File "/usr/local/lib/python3.10/dist-packages/torchao/float8/float8_tensor.py", line 360, in __torch_dispatch__
    raise NotImplementedError(f"attempting to run {func}, this is not supported")
NotImplementedError: attempting to run aten.reshape.default, this is not supported

Metadata

Metadata

Assignees

No one assigned

    Labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions