Description
I'm trying to compile a GNN model with TorchScript (torch.jit.script) and get the following error:
RuntimeError:
Unknown type name 'torch_geometric.nn.MLP':
  File "/usr/local/lib/python3.10/dist-packages/torch_geometric/nn/aggr/attention.py", line 55
        self.assert_two_dimensional_input(x, dim)
        if isinstance(self.gate_nn, torch_geometric.nn.MLP):
                                    ~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
            gate = self.gate_nn(x, index, dim_size)
        else:
Here is an example that reproduces the error:
import torch
from torch import nn
from torch_geometric.nn import aggr, TransformerConv
from torch_geometric.nn.norm import GraphNorm


class GNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.gnn = TransformerConv(10, 10, 1, dropout=0.1).jittable()
        self.normalization = GraphNorm(in_channels=10)
        self.act = nn.GELU()
        # Doesn't compile with torch.jit.script (see the traceback above)
        self.aggr = aggr.AttentionalAggregation(gate_nn=nn.Linear(10, 1), nn=nn.Linear(10, 10))

    def forward(self, x, edge_index, batch_ptr):
        hidden = x
        hidden = self.gnn(hidden, edge_index)
        hidden = self.normalization(hidden)
        hidden = self.act(hidden)
        # batch_ptr = [0, 100]: all 100 nodes belong to a single graph
        graph_rep = self.aggr(hidden, ptr=batch_ptr)
        return graph_rep


device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = GNN().to(device)
x = torch.randn(100, 10).to(device)
edge_index = torch.randint(100, size=(2, 20)).to(device)
batch_ptr = torch.tensor([0, 100]).to(device)

# The eager forward pass runs
res = model(x, edge_index, batch_ptr)
print(res)

# Scripting raises the RuntimeError shown above
model = torch.jit.script(model)
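For reference when comparing outputs, here is a dependency-free sketch of what AttentionalAggregation computes when called with ptr, following the formula in the PyG documentation (out_i = sum over the nodes n of graph i of softmax(gate_nn(x_n)) * nn(x_n)). A ptr such as [0, 100] marks one graph spanning nodes 0..99. The helper attentional_aggregation is hypothetical (not a PyG API), and it assumes a scalar gate per node (as with gate_nn=nn.Linear(10, 1)) and at least one node per graph.

```python
import math
from typing import Callable, List, Sequence

Vector = List[float]


def attentional_aggregation(
    x: Sequence[Vector],
    ptr: Sequence[int],
    gate_nn: Callable[[Vector], float],
    feat_nn: Callable[[Vector], Vector],
) -> List[Vector]:
    """Hypothetical reference sketch, not PyG's implementation.

    Assumes a scalar gate per node and that every graph in ptr is non-empty.
    """
    out: List[Vector] = []
    # Consecutive ptr entries delimit the node range [start, end) of each graph.
    for start, end in zip(ptr[:-1], ptr[1:]):
        nodes = x[start:end]
        gates = [gate_nn(v) for v in nodes]
        # Numerically stable softmax over the nodes of this graph only.
        m = max(gates)
        weights = [math.exp(g - m) for g in gates]
        total = sum(weights)
        feats = [feat_nn(v) for v in nodes]
        dim = len(feats[0])
        # Attention-weighted sum of the transformed node features.
        pooled = [sum(w / total * f[k] for w, f in zip(weights, feats)) for k in range(dim)]
        out.append(pooled)
    return out


# Constant gates reduce to a per-graph mean.
x = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]
print(attentional_aggregation(x, [0, 2, 3], gate_nn=lambda v: 0.0, feat_nn=lambda v: v))
# [[0.5, 0.5], [2.0, 2.0]]
```

With constant gates every node in a graph gets the same weight, so the result is a plain mean per graph. This is a quick sanity check on the semantics of ptr.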
Environment
- PyG version: 2.4.0
- PyTorch version: 2.1.0
- OS: Ubuntu 22
- Python version: 3.10
- CUDA/cuDNN version: 12.3
- How you installed PyTorch and PyG (conda, pip, source): pip/source