This repository was archived by the owner on Aug 7, 2024. It is now read-only.
torch.inference_mode switches to aten.linear.default, this is not supported
#238
Closed
Description
First, congrats on the repo, it looks great.
I discovered that switching from torch.no_grad to torch.inference_mode leads to a switch to aten.linear.default, which raises a NotImplementedError. Feel free to use this feedback; it's likely expected.
# dependencies
# torch 2.4.0.dev20240316+cu121 (nightly)
# float8_experimental commit 88e9e507c56e59c5f17edf513ecbf621b46fc67d
from transformers import AutoModel, AutoTokenizer, PreTrainedModel
import torch
from float8_experimental.float8_linear_utils import (
    swap_linear_with_float8_linear,
)
from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
"""
pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 --upgrade
"""
model_name = "michaelfeil/bge-small-en-v1.5"
model: PreTrainedModel = AutoModel.from_pretrained(model_name)
model_orig: PreTrainedModel = AutoModel.from_pretrained(model_name).cuda()
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokens = tokenizer.batch_encode_plus(["hi","there is no"], return_tensors="pt", padding=True, truncation=True)
tokens = {k: v.cuda() for k,v in tokens.items()}
model = model.to("cpu")
swap_linear_with_float8_linear(model, Float8DynamicLinear)
model = model.to("cuda")
with torch.no_grad():
    out1 = model_orig.forward(**tokens)["last_hidden_state"]
    out2 = model.forward(**tokens)["last_hidden_state"]
    print(out1 - out2, out1, out2, (out1 - out2).mean())

with torch.inference_mode():
    # breaks with NotImplementedError: attempting to run aten.linear.default, this is not supported
    out2 = model.forward(**tokens)["last_hidden_state"]
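Since the repro above runs fine under torch.no_grad and only fails under torch.inference_mode, one possible stopgap is to route inference through no_grad. The sketch below is a minimal illustration of that idea, not a fix from the float8_experimental maintainers. The helper name `float8_safe_no_grad` is hypothetical, and a plain `torch.nn.Linear` stands in for a swapped `Float8DynamicLinear` so the snippet runs without a GPU.

```python
from contextlib import contextmanager

import torch


@contextmanager
def float8_safe_no_grad():
    """Hypothetical helper: disable autograd without entering inference mode."""
    # torch.no_grad() is the path the report shows working;
    # torch.inference_mode() is the one that raised NotImplementedError
    # for aten.linear.default.
    with torch.no_grad():
        yield


if __name__ == "__main__":
    # Plain stand-in layer (assumption): not a Float8DynamicLinear.
    layer = torch.nn.Linear(4, 2)
    with float8_safe_no_grad():
        out = layer(torch.randn(3, 4))
    # Print whether grad tracking and inference mode are active for the result.
    print(out.requires_grad, torch.is_inference_mode_enabled())
```

With the model from the report, the call would presumably look like `with float8_safe_no_grad(): out2 = model(**tokens)["last_hidden_state"]`. This trades away whatever extra overhead savings inference_mode offers over no_grad.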