This repository was archived by the owner on Aug 7, 2024. It is now read-only.

torch.inference_mode switches to aten.linear.default, this is not supported #238

Closed
@michaelfeil

Description


First, congrats on the repo, it looks great.

I discovered that switching from torch.no_grad to torch.inference_mode changes the op being dispatched to aten.linear.default, which is not supported. Feel free to use this feedback; it's likely expected.
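One way to check which aten ops are dispatched in each context is a small logging `TorchDispatchMode`. The following is only a diagnostic sketch on plain CPU tensors, not part of the original reproduction: `OpLogger` and `logged_linear_ops` are hypothetical names, and the ops seen here may not match exactly what a Float8 tensor subclass receives.

```python
import torch
import torch.nn.functional as F
from torch.utils._python_dispatch import TorchDispatchMode


class OpLogger(TorchDispatchMode):
    """Records the name of every aten op that reaches the Python dispatch layer."""

    def __init__(self):
        super().__init__()
        self.ops = []

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        self.ops.append(str(func))
        return func(*args, **(kwargs or {}))


def logged_linear_ops(grad_context):
    """Run one F.linear call inside `grad_context` and return the logged op names."""
    # Inputs are created before the logger is active so only F.linear is recorded.
    x, w, b = torch.randn(2, 4), torch.randn(3, 4), torch.randn(3)
    with grad_context, OpLogger() as logger:
        F.linear(x, w, b)
    return logger.ops


print("no_grad:       ", logged_linear_ops(torch.no_grad()))
print("inference_mode:", logged_linear_ops(torch.inference_mode()))
```

Comparing the two printed lists shows whether `linear` is decomposed into lower-level ops before reaching Python dispatch or arrives as a single op.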

# dependencies
# torch '2.4.0.dev20240316+cu121'
# float8_experimental commit 88e9e507c56e59c5f17edf513ecbf621b46fc67d
from transformers import AutoModel, AutoTokenizer, PreTrainedModel
import torch

from float8_experimental.float8_linear_utils import (
    swap_linear_with_float8_linear,
)
from float8_experimental.float8_dynamic_linear import Float8DynamicLinear

"""
pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 --upgrade
"""

model_name = "michaelfeil/bge-small-en-v1.5"
model: PreTrainedModel = AutoModel.from_pretrained(model_name)
model_orig: PreTrainedModel = AutoModel.from_pretrained(model_name).cuda()

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokens = tokenizer.batch_encode_plus(["hi","there is no"], return_tensors="pt", padding=True, truncation=True)
tokens = {k: v.cuda() for k,v in tokens.items()}

model = model.to("cpu")
swap_linear_with_float8_linear(model, Float8DynamicLinear)
model = model.to("cuda")

with torch.no_grad():
    out1 = model_orig.forward(**tokens)["last_hidden_state"]
    out2 = model.forward(**tokens)["last_hidden_state"]

print(out1 - out2, out1, out2, (out1-out2).mean())

with torch.inference_mode():
    # breaks with NotImplementedError: attempting to run aten.linear.default, this is not supported
    out2 = model.forward(**tokens)["last_hidden_state"]
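Since the `torch.no_grad()` path above runs, a possible workaround is to route inference through `torch.no_grad()` instead of `torch.inference_mode()`. This is a minimal sketch under that assumption: `forward_no_grad` is a hypothetical helper name, and a plain `torch.nn.Linear` stands in for the swapped Float8 model so the snippet runs without a GPU.

```python
import torch


def forward_no_grad(model, **inputs):
    """Call `model` with autograd disabled via torch.no_grad() rather than torch.inference_mode()."""
    with torch.no_grad():
        return model(**inputs)


# Stand-in module for illustration; in the report this would be the swapped Float8 model.
demo = torch.nn.Linear(4, 3)
out = forward_no_grad(demo, input=torch.randn(2, 4))
print(out.shape, out.requires_grad)
```

The result carries no autograd history, but unlike tensors produced under `torch.inference_mode()` it is not an inference tensor.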
