-
Notifications
You must be signed in to change notification settings - Fork 241
Open
Description
CANINE is a great candidate for being exported to ONNX, as it operates over raw Unicode and thus doesn't require a tokenizer.
However, the model currently fails to analyze with:
━━━ ..,F32
[2025-02-18T04:59:02.196544000Z ERROR tract] Error at stage "analyse"
Caused by:
0: ModelBuildingError
1: Failed analyse for node #295 "/model/If" If
2: Infering facts
3: Failed analyse for node #1 "/model/Squeeze" Squeeze13
4: Infering facts
5: Applying rule GivenRule { (inputs[0].shape, inputs[1]) }
6: Attempt to squeeze an axis which dimension is not one Squeeze { axes: Some([-1]) }, [Sym(batch_size), Val(1), Val(16)]
See attached code to export (unfortunately the zip was too large to export):
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer
import os
import onnx
from pathlib import Path
from transformers import CanineModel
class TokenizerWrapper(nn.Module):
    """Thin ``nn.Module`` around ``CanineModel`` with tuple outputs.

    ``forward`` unpacks the model result into a plain
    ``(pooler_output, last_hidden_state)`` tuple, which is the form handed
    to ``torch.onnx.export`` elsewhere in this file.
    """

    def __init__(self, model_name: str = "google/canine-c"):
        """Load the pretrained CANINE checkpoint.

        Args:
            model_name: Name or path passed to ``CanineModel.from_pretrained``.
                Defaults to the checkpoint this script has always used.
        """
        super().__init__()
        self.model = CanineModel.from_pretrained(model_name)

    def forward(self, text):
        """Run the wrapped model on a batch of ids.

        Args:
            text: Tensor of ids; the export path in this file traces it with
                shape ``[batch_size, seq_length]``. It is the only argument
                given to the model, so no attention mask is supplied.

        Returns:
            Tuple ``(pooler_output, last_hidden_state)`` taken from the
            model output.
        """
        res = self.model(text)
        return res.pooler_output, res.last_hidden_state
def export_tokenizer_as_onnx(tokenizer_path, onnx_path, max_length=64):
    """Export a ``TokenizerWrapper`` to ONNX and validate the written file.

    Args:
        tokenizer_path: Name or path passed to
            ``AutoTokenizer.from_pretrained``. The tokenizer is used only for
            its ``vocab_size``, which bounds the random dummy ids.
        onnx_path: Destination path of the exported ONNX file.
        max_length: Sequence length of the dummy input used for tracing.

    Returns:
        The model read back with ``onnx.load`` after it has passed
        ``onnx.checker.check_model``.
    """
    # Honour the caller-supplied tokenizer location instead of a module global.
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

    # Build the module to trace and switch it to inference mode.
    wrapper = TokenizerWrapper()
    wrapper.eval()

    # A single random sequence is enough to trace the graph.
    dummy_ids = torch.randint(0, tokenizer.vocab_size, (1, max_length))

    torch.onnx.export(
        wrapper,
        dummy_ids,
        onnx_path,
        input_names=['text'],
        # Names are positional: 'output_ids' is the wrapper's first return
        # value (pooler_output), 'output_mask' its second (last_hidden_state).
        output_names=['output_ids', 'output_mask'],
        # Only the batch dimension is dynamic; sequence length stays fixed
        # at max_length.
        dynamic_axes={
            'text': {0: 'batch_size'},
            'output_ids': {0: 'batch_size'},
            'output_mask': {0: 'batch_size'},
        },
        opset_version=13,
    )

    # Reload from disk so the check covers the file that was actually written.
    onnx_model = onnx.load(onnx_path)
    onnx.checker.check_model(onnx_model)

    onnx_size = Path(onnx_path).stat().st_size / (1024 * 1024)  # bytes -> MB
    print("\nFile sizes:")
    print(f"ONNX model: {onnx_size:.2f} MB")
    return onnx_model
def preprocess(text, tokenizer_path, max_length=64):
    """Run already-encoded ids through a freshly built ``TokenizerWrapper``.

    No text-to-id conversion happens here: ``text`` is handed to the wrapper
    unchanged.

    Args:
        text: Tensor of ids forwarded to the wrapper (the script in this file
            passes a ``(1, 64)`` integer tensor).
        tokenizer_path: Unused; kept so existing call sites keep working.
        max_length: Unused; kept so existing call sites keep working.

    Returns:
        The ``(pooler_output, last_hidden_state)`` tuple produced by the
        wrapper.
    """
    wrapper = TokenizerWrapper()
    # Same inference-mode setup as the export path.
    wrapper.eval()
    # Call the module rather than .forward() so nn.Module hooks are honoured.
    return wrapper(text)
# Checkpoint id; same value TokenizerWrapper loads its model from.
TOKENIZER = "google/canine-c"
# Loaded at module scope only for vocab_size, which bounds the random ids below.
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)

# Export the model
onnx_model = export_tokenizer_as_onnx(
    tokenizer_path=TOKENIZER,
    onnx_path="tokenizer.onnx",
)
print("ONNX model exported successfully!")

# Run a random batch of ids through the PyTorch wrapper
pooler_output, last_hidden_state = preprocess(
    torch.randint(0, tokenizer.vocab_size, (1, 64)),
    "elon_tokenizer",
)
print("\nTokenization results:")
print(f"pooler_output shape: {pooler_output.shape}")
print(f"last_hidden_state shape: {last_hidden_state.shape}")
Metadata
Assignees
Labels
No labels