
Converting models to JAX


PyTorch is the machine learning framework of choice for many researchers and developers due to its flexibility, ease of use, and dynamic computation graph. It has a large ecosystem of established libraries that facilitate distributed training and hyperparameter tuning, and it includes high-level abstractions that simplify the development of neural networks. TorchScript, a statically typed subset of Python supported by PyTorch, enables the creation of serializable and optimizable models that can run independently of Python, making it suitable for deployment in production environments.

JAX, on the other hand, is a library designed for high-performance numerical computing. Its use of pure functions and explicit state management, along with its efficient XLA (Accelerated Linear Algebra) compiler, allows for highly optimized code execution on CPUs, GPUs, and TPUs.

Early in the project we decided to implement the neural network potentials in PyTorch and to develop the MCMC code in JAX. This requires that we either use compiled PyTorch models in JAX or translate trained models between the frameworks.

First, let's discuss the most obvious approach that doesn't require any translation:

Overview

PyTorch models in JAX

  • Pytorch2Jax: wraps a PyTorch model in JAX/Flax. It uses dlpack to convert between PyTorch and JAX tensors in memory (and on the same device). The wrapper functions are compatible with JAX autodiff.

Translate between PyTorch and JAX

There are multiple viable possibilities to translate between the frameworks.

Added implementations

This approach uses the PyTorch backend inside a JAX function or Flax module. It uses dlpack to avoid copying PyTorch tensors and JAX arrays between devices when converting from one framework to the other, and we can still use JAX's autograd capability.

To use this approach, only minimal modifications were necessary: every neural network's input is the NNPInput data class. We opted for a class to wrap our input because it allows extensive input checking and modification. To ensure that we pass input that is usable by JAX, we added a method that converts the data class fields to JAX arrays and returns the data as a NamedTuple.
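
A minimal sketch of what such a conversion method can look like (the field names below are illustrative, not the actual NNPInput fields; the idea is simply to turn every tensor field into a JAX array and return an immutable NamedTuple):

from dataclasses import dataclass, fields
from typing import NamedTuple

import torch
import jax.numpy as jnp

class JAXInput(NamedTuple):
    positions: jnp.ndarray
    atomic_numbers: jnp.ndarray

@dataclass
class ExampleInput:
    # stand-in for NNPInput; the real fields differ
    positions: torch.Tensor
    atomic_numbers: torch.Tensor

    def as_jax_namedtuple(self) -> JAXInput:
        # convert every tensor field to a JAX array, preserving field order
        return JAXInput(
            *[jnp.asarray(getattr(self, f.name).detach().cpu().numpy()) for f in fields(self)]
        )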

To wrap a PyTorch model, we have a slightly modified version of the PyTorch2Jax code implemented in a converter class. This class is responsible for transforming the input and output between PyTorch tensors and JAX arrays on the device in use. It also handles the conversion of the model parameters and buffers so that JAX can take derivatives.
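
At the core of that conversion is a zero-copy exchange of buffers via dlpack. A minimal, self-contained sketch of the round trip (independent of modelforge, using the public torch.utils.dlpack and jax.dlpack APIs) looks like this:

import torch
import torch.utils.dlpack
import jax.dlpack

# PyTorch tensor -> JAX array, sharing the underlying buffer (no copy)
t = torch.arange(4, dtype=torch.float32)
x = jax.dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(t))

# JAX array -> PyTorch tensor, again sharing memory on the same device
t_back = torch.utils.dlpack.from_dlpack(jax.dlpack.to_dlpack(x))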

To obtain a JAX-wrapped PyTorch model, you need to specify JAX as the simulation_environment in the NeuralNetworkPotentialFactory:

    # inference model
    model = NeuralNetworkPotentialFactory.create_nnp(
        use="inference",
        nnp_type=model_name,
        simulation_environment="JAX",
    )
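
Because the converter also exposes the model parameters and buffers to JAX, the wrapped model can be used inside JAX transformations. The following is a hypothetical sketch only: the call signature of the wrapped model, the positions field of the NamedTuple input, and the structure of its output are assumptions, not the verified modelforge API (the as_jax_namedtuple method is used further below on this page):

    import jax

    # nnp_input is an NNPInput instance; as_jax_namedtuple() is the conversion
    # method described above; the positions field and the E attribute of the
    # output are assumptions for illustration
    jax_input = nnp_input.as_jax_namedtuple()

    def energy_fn(positions):
        # swap the (assumed) positions field, keep all other fields fixed
        return model(jax_input._replace(positions=positions)).E.sum()

    # forces as the negative gradient of the energy w.r.t. the positions
    forces = -jax.grad(energy_fn)(jax_input.positions)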

Ivy (developed by Unify) implements two functions that are of interest:

  • ivy.trace_graph: similar in purpose to torch.jit.trace: it records the compute graph and strips away the Python wrapping overhead (a minimal sketch follows this list)
  • ivy.transpile: converts a function written in any framework into your framework of choice, preserving all the logic between frameworks. A general introduction can be found here.
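
A minimal sketch of the tracing use case (the exact ivy.trace_graph signature is assumed here to mirror the ivy.transpile calls shown below; consult the Ivy documentation before relying on it):

import ivy
import torch

def fn(x):
    # pure PyTorch math with some Python-side overhead around it
    return torch.sin(x) + torch.cos(x)

x = torch.rand(10)
# record the compute graph of fn for this input signature (assumed call pattern)
traced_fn = ivy.trace_graph(fn, args=(x,))
print(traced_fn(x))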

We will start with the transpile capability. This approach translates the PyTorch model to a Flax model (Flax is a neural network library built on top of JAX) for further use in chiron.

To demonstrate the approach, the following example translates a simple PyTorch model to JAX:

import ivy
import torch
import torch.nn as nn

ivy.set_backend("jax")

class MultiLayerNetwork(nn.Module):
    def __init__(self, input_size, hidden_sizes, output_size):
        """
        Initializes the MultiLayerNetwork.
        
        Parameters:
            input_size (int): The size of the input features.
            hidden_sizes (list of int): A list containing the sizes of the hidden layers.
            output_size (int): The size of the output layer.
        """
        super().__init__()
        layers = []
        
        current_size = input_size
        for hidden_size in hidden_sizes:
            layers.append(nn.Linear(current_size, hidden_size))
            layers.append(nn.ReLU())
            current_size = hidden_size
        
        # Output layer
        layers.append(nn.Linear(current_size, output_size))
        
        # Register all layers
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        """
        Forward pass through the network.

        """
        return self.layers(x)

# Example usage:
net = MultiLayerNetwork(input_size=10, hidden_sizes=[50, 30], output_size=1)
inp = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0])
torch_result = net(inp)
# Transpile it into a flax model with corresponding parameters
flax_net = ivy.transpile(net, source='torch', to="flax", args=[inp])

import jax.numpy as jnp
from jax import random

# input for jax
inp = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0])
# get variables
variables = flax_net.init(random.key(0), inp)

jax_result = flax_net.apply(variables, inp)
if jnp.allclose(torch_result.detach().numpy(), jax_result):
    print('the same')
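
Since the transpiled flax_net behaves like a regular Flax module, it also composes with the usual JAX transformations. A minimal sketch, reusing the module, variables, and input from the example above:

from jax import grad, jit

# jit-compile the forward pass
fast_apply = jit(flax_net.apply)
print(fast_apply(variables, inp))

# gradient of the scalar output with respect to the input
d_out_d_inp = grad(lambda x: flax_net.apply(variables, x).sum())(inp)
print(d_out_d_inp.shape)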

This worked great. Only slight modifications, following the documentation, were necessary to get the converted model. Let's see how it performs with the modelforge NNPs. To test this, we start with a SchNet model:

# import the models implemented in modelforge
from modelforge.potential import NeuralNetworkPotentialFactory
from modelforge.dataset.qm9 import QM9Dataset
from modelforge.dataset.dataset import TorchDataModule

import torch

# obtain a single batch
data = QM9Dataset(for_unit_testing=True)
dataset = TorchDataModule(
    data, batch_size=1
)

dataset.prepare_data(remove_self_energies=True, normalize=False)
batch = next(iter(dataset.train_dataloader()))

# initialize the model
model = NeuralNetworkPotentialFactory.create_nnp("inference", "SchNet")
model = model.to(torch.float32)

# predict energy
torch_result = model(batch.nnp_input)

# transpile the model
import ivy
ivy.set_backend("jax")

jax_model = ivy.transpile(model, source='torch', to="flax", args=[batch.nnp_input])

# predict energy for the same batch
import jax.numpy as jnp
from jax import random

# convert batch to jax input
jax_input = batch.nnp_input.as_jax_namedtuple()
variables = jax_model.init(random.key(0), jax_input)

jax_result = jax_model.apply(variables, jax_input)
if jnp.allclose(torch_result.E.detach().numpy(), jax_result.E):
    print('the same')

Again, this worked out of the box and produces an NNP that is implemented in Flax.

NOTE: Since we want to use the neighbor list implemented in chiron, we don't want to convert the whole model. For that reason, we separate each model into an InputPreparation module and a CoreModel module, as outlined here: https://github.com/choderalab/modelforge/wiki/Neural-network-potentials#module-organization.