v1.4.12
What's Changed
- dev(hansbug): try adapt torch2 by @HansBug in #85
- dev(narugo): torch compile's test by @HansBug in #87
Version v1.4.12 adds support for torch 2.x (torch >= 2.0), including support for torch.compile,
which enables faster inference and backpropagation. Here's an example:
from typing import Tuple, Mapping
import torch
from torch import nn
from treevalue import FastTreeValue
# A simple MLP
class MLP(nn.Module):
    """A plain stack of fully connected layers.

    Consecutive ``nn.Linear`` layers are chained with ``nn.Sequential``; no
    activation is inserted between them.

    :param in_features: Input width of the first linear layer.
    :param out_features: Output width of the last linear layer.
    :param layers: Widths of the hidden layers, in order. An empty tuple
        yields a single ``in_features -> out_features`` linear layer.
    """

    def __init__(self, in_features: int, out_features: int, layers: Tuple[int, ...] = (1024,)):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.layers = layers

        # Full list of layer widths: input, hidden..., output. Each adjacent
        # pair gives the (in, out) sizes of one Linear layer.
        ios = [self.in_features, *self.layers, self.out_features]
        self.mlp = nn.Sequential(
            *(
                nn.Linear(in_, out_, bias=True)
                for in_, out_ in zip(ios[:-1], ios[1:])
            )
        )

    def forward(self, x):
        """Pass ``x`` through the stacked linear layers and return the result."""
        return self.mlp(x)
# Multiple headed MLP
class MultiHeadMLP(nn.Module):
    """Several independent ``MLP`` heads that share the same input.

    One ``MLP`` is built per entry of ``out_features``. The heads are held
    twice: in an ``nn.ModuleDict`` so that they are registered as child
    modules, and in a ``FastTreeValue`` so that all of them can be invoked
    with a single call.

    :param in_features: Input width shared by every head.
    :param out_features: Mapping from head name to that head's output width.
    :param layers: Hidden layer widths, passed unchanged to every head.
    """

    def __init__(self, in_features: int, out_features: Mapping[str, int], layers: Tuple[int, ...] = (1024,)):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.layers = layers

        _networks = {
            o_name: MLP(in_features, o_feat, layers)
            for o_name, o_feat in self.out_features.items()
        }
        # Both containers hold the very same MLP objects built above.
        self.mlps = nn.ModuleDict(_networks)  # use nn.ModuleDict to register child MLPs
        self._t_mlps = FastTreeValue(_networks)  # use TreeValue for batch inferring

    def forward(self, x):
        """Call the tree of heads on ``x`` and return the resulting tree."""
        return self._t_mlps(x)
if __name__ == '__main__':
    # Four heads named 'a'..'d' with different output widths, all fed by a
    # 20-wide input.
    net = MultiHeadMLP(
        20,
        {'a': 10, 'b': 20, 'c': 14, 'd': 3},
    )
    # Wrap the network with torch.compile (requires torch >= 2).
    net = torch.compile(net)
    print(net)

    # The last dimension of the input must match the network's in_features (20).
    input_ = torch.randn(1, 10, 20)
    output = net(input_)
    # The output is a tree, so .shape prints one shape per head.
    print(output.shape)
The compiled version of the MultiHeadMLP above will have the following network structure:
OptimizedModule(
(_orig_mod): MultiHeadMLP(
(mlps): ModuleDict(
(a): MLP(
(mlp): Sequential(
(0): Linear(in_features=20, out_features=1024, bias=True)
(1): Linear(in_features=1024, out_features=10, bias=True)
)
)
(b): MLP(
(mlp): Sequential(
(0): Linear(in_features=20, out_features=1024, bias=True)
(1): Linear(in_features=1024, out_features=20, bias=True)
)
)
(c): MLP(
(mlp): Sequential(
(0): Linear(in_features=20, out_features=1024, bias=True)
(1): Linear(in_features=1024, out_features=14, bias=True)
)
)
(d): MLP(
(mlp): Sequential(
(0): Linear(in_features=20, out_features=1024, bias=True)
(1): Linear(in_features=1024, out_features=3, bias=True)
)
)
)
)
)
Passing a float32 tensor of shape [1, 10, 20] as input,
the inference output will have the following shapes:
<FastTreeValue 0x7fe9b197e6a0>
├── 'a' --> torch.Size([1, 10, 10])
├── 'b' --> torch.Size([1, 10, 20])
├── 'c' --> torch.Size([1, 10, 14])
└── 'd' --> torch.Size([1, 10, 3])
Full Changelog: v1.4.11...v1.4.12