
v1.4.12

@HansBug released this 14 Aug 05:18 · 24 commits to main since this release

What's Changed

The new version (v1.4.12) adds support for torch >= 2, including support for torch.compile for faster inference and backpropagation. Here's an example:

from typing import Tuple, Mapping

import torch
from torch import nn

from treevalue import FastTreeValue


# A simple MLP
class MLP(nn.Module):
    def __init__(self, in_features: int, out_features: int, layers: Tuple[int, ...] = (1024,)):
        nn.Module.__init__(self)
        self.in_features = in_features
        self.out_features = out_features
        self.layers = layers
        ios = [self.in_features, *self.layers, self.out_features]
        self.mlp = nn.Sequential(
            *(
                nn.Linear(in_, out_, bias=True)
                for in_, out_ in zip(ios[:-1], ios[1:])
            )
        )

    def forward(self, x):
        return self.mlp(x)


# Multiple headed MLP
class MultiHeadMLP(nn.Module):
    def __init__(self, in_features: int, out_features: Mapping[str, int], layers: Tuple[int, ...] = (1024,)):
        nn.Module.__init__(self)
        self.in_features = in_features
        self.out_features = out_features
        self.layers = layers
        _networks = {
            o_name: MLP(in_features, o_feat, layers)
            for o_name, o_feat in self.out_features.items()
        }
        self.mlps = nn.ModuleDict(_networks)  # use nn.ModuleDict to register child MLPs
        self._t_mlps = FastTreeValue(_networks)  # use TreeValue for batch inferring

    def forward(self, x):
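        # Calling the FastTreeValue applies each child MLP to x and
        # returns a FastTreeValue of per-head outputs.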
        return self._t_mlps(x)


if __name__ == '__main__':
    net = MultiHeadMLP(
        20,
        {'a': 10, 'b': 20, 'c': 14, 'd': 3},
    )
    net = torch.compile(net)
    print(net)

    input_ = torch.randn(1, 10, 20)
    output = net(input_)
    print(output.shape)

The compiled version of the MultiHeadMLP above will have the following network structure:

OptimizedModule(
  (_orig_mod): MultiHeadMLP(
    (mlps): ModuleDict(
      (a): MLP(
        (mlp): Sequential(
          (0): Linear(in_features=20, out_features=1024, bias=True)
          (1): Linear(in_features=1024, out_features=10, bias=True)
        )
      )
      (b): MLP(
        (mlp): Sequential(
          (0): Linear(in_features=20, out_features=1024, bias=True)
          (1): Linear(in_features=1024, out_features=20, bias=True)
        )
      )
      (c): MLP(
        (mlp): Sequential(
          (0): Linear(in_features=20, out_features=1024, bias=True)
          (1): Linear(in_features=1024, out_features=14, bias=True)
        )
      )
      (d): MLP(
        (mlp): Sequential(
          (0): Linear(in_features=20, out_features=1024, bias=True)
          (1): Linear(in_features=1024, out_features=3, bias=True)
        )
      )
    )
  )
)

And inference on a float32[1, 10, 20] input produces outputs with the following shapes:

<FastTreeValue 0x7fe9b197e6a0>
├── 'a' --> torch.Size([1, 10, 10])
├── 'b' --> torch.Size([1, 10, 20])
├── 'c' --> torch.Size([1, 10, 14])
└── 'd' --> torch.Size([1, 10, 3])
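
Since the result is an ordinary FastTreeValue of tensors, backpropagation works as usual. As a minimal sketch (continuing the example above; the squared-output "loss" is only a hypothetical placeholder, not part of the release), gradients can flow back through the compiled multi-head network like this:

# Hypothetical training step: build a simple per-head loss and backpropagate
# through the compiled module (the loss used here is purely illustrative).
loss = (output.a ** 2).mean() + (output.b ** 2).mean() \
       + (output.c ** 2).mean() + (output.d ** 2).mean()
loss.backward()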

Full Changelog: v1.4.11...v1.4.12