
Refactor torchao and tests to use model architectures from torchao.testing.model_architectures #2078

Open
@jainapurva

Description


PR #2036 adds standard model architectures to `torchao/testing/model_architectures.py`. Replace the existing model definitions in torchao and its tests with the shared definitions from `model_architectures.py`. If a model definition that is not yet in `model_architectures.py` is found, add it there.

For example, replace this local definition:

```python
import torch


class ToyLinearModel(torch.nn.Module):
    def __init__(self, m=64, n=32, k=64, bias=False):
        super().__init__()
        self.linear1 = torch.nn.Linear(m, n, bias=bias).to(torch.float)
        self.linear2 = torch.nn.Linear(n, k, bias=bias).to(torch.float)

    def example_inputs(self, batch_size=1, dtype=torch.float, device="cpu"):
        # One input tensor of shape (batch_size, m) for linear1.
        return (
            torch.randn(
                batch_size, self.linear1.in_features, dtype=dtype, device=device
            ),
        )

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(x)
        return x
```
so that the code uses the shared `torchao.testing.model_architectures.ToyLinearModel` instead.

Metadata


Assignees: No one assigned

Type: No type

Projects: No projects

Milestone: No milestone

Relationships: None yet

Development: No branches or pull requests
