
package issues with functions under C extensions #44

Open
@d4l3k

Description

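```python
import torch
from torch.package import PackageExporter, PackageImporter

output_path = "/tmp/model.pt"


def save_load(model):
    # Package the model: keep torch external, bundle everything else.
    with PackageExporter(output_path) as e:
        e.extern("torch.**")
        e.intern("**")
        e.save_pickle("model", "model.pkl", model)

    # Load the model back out of the package; this is the step that fails.
    imp = PackageImporter(output_path)
    return imp.load_pickle("model", "model.pkl")


# activation='gelu' makes the layer store F.gelu as its activation function.
model = torch.nn.TransformerEncoderLayer(
    d_model=64,
    nhead=2,
    dim_feedforward=64,
    dropout=1.0,
    batch_first=True,
    activation="gelu",
    norm_first=True,
)
save_load(model)
print("pass")  # only reached if loading succeeds
```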
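For context: the reproduction pickles the whole layer, and `pickle` stores functions (including C builtins) by reference. It records the defining module and the name, and whoever loads the pickle must be able to import that module. The sketch below uses only the standard library, with `math.sqrt` standing in for `F.gelu` and a plain dict standing in for the layer's state; `referenced_globals` is a hypothetical helper written for illustration, not part of any API.

```python
import math
import pickle
import pickletools


def referenced_globals(payload: bytes) -> list[tuple[str, str]]:
    """Return the (module, name) pairs a pickle refers to via the GLOBAL opcode.

    Only protocols 0-3 emit GLOBAL (protocol 4+ uses STACK_GLOBAL instead),
    so this helper is only meant for payloads produced with protocol <= 3.
    """
    refs = []
    for opcode, arg, _pos in pickletools.genops(payload):
        if opcode.name == "GLOBAL":
            # pickletools reports the GLOBAL argument as "module name".
            module, name = arg.split(" ", 1)
            refs.append((module, name))
    return refs


# Stand-in for a layer whose state holds a C builtin as its activation.
state = {"activation": math.sqrt}
payload = pickle.dumps(state, protocol=3)

print(referenced_globals(payload))  # [('math', 'sqrt')]

# Unpickling re-imports "math" and looks up "sqrt" to restore the reference.
restored = pickle.loads(payload)
print(restored["activation"] is math.sqrt)  # True
```

With `activation='gelu'`, the layer's `F.gelu` is recorded the same way, and the error message indicates that the module recorded for it is `torch._C._nn`.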

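The problem is that `F.gelu` (a builtin that lives in the `torch._C._nn` extension module) cannot be loaded back from the package; unpickling the model fails with this import error: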

ModuleNotFoundError: No module named 'torch._C._nn'; 'torch._C' is not a package
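This is the message Python's import machinery produces whenever a dotted name's parent module has no `__path__`, which is the case for a C extension module such as `torch._C`. Below is a minimal standard-library sketch that reproduces the same message. It assumes a hypothetical extension-like module `fake_ext` whose `_nn` submodule is reachable as an attribute but is not registered under its dotted name in `sys.modules`.

```python
import importlib
import sys
import types

# Hypothetical stand-in for a C extension module: a plain module object with
# no __path__, so the import system does not treat it as a package.
fake_ext = types.ModuleType("fake_ext")
fake_ext._nn = types.ModuleType("fake_ext._nn")
fake_ext._nn.gelu = lambda x: x  # placeholder function, not a real GELU
sys.modules["fake_ext"] = fake_ext

# The submodule is reachable as an attribute...
assert fake_ext._nn.gelu(3) == 3

# ...but importing it by its dotted name fails.
error_message = None
try:
    importlib.import_module("fake_ext._nn")
except ModuleNotFoundError as exc:
    error_message = str(exc)
print(error_message)
# No module named 'fake_ext._nn'; 'fake_ext' is not a package

# Registering the submodule under its dotted name lets the import succeed.
sys.modules["fake_ext._nn"] = fake_ext._nn
print(importlib.import_module("fake_ext._nn") is fake_ext._nn)  # True
```

The final `sys.modules` registration only illustrates the mechanism with the standard importer; this sketch does not establish whether the same approach helps with `PackageImporter`.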

Metadata
Labels: bug (Something isn't working), package (torch.package and related)