Converting models to JAX
PyTorch is the machine learning framework of choice for many researchers and developers due to its flexibility, ease of use, and dynamic computation graph. It boasts a large ecosystem of established frameworks that facilitate distributed training and hyperparameter tuning. Additionally, it includes high-level abstractions that simplify the development of neural networks. TorchScript, a subset of PyTorch, enables the creation of serializable and optimizable models which can run independently of Python, making it suitable for deployment in production environments.
JAX, on the other hand, is a library designed for high-performance numerical computing. Its use of pure functions and explicit state management, together with its efficient XLA (Accelerated Linear Algebra) compiler, allows for highly optimized code execution on CPUs, GPUs, and TPUs.
Early in the project, we decided to implement the neural network potentials in PyTorch and to develop the MCMC code in JAX. This requires that we either use compiled PyTorch models in JAX or translate trained models between the frameworks.
First, let's discuss the most obvious approach that doesn't require any translation:
- Pytorch2Jax: wraps a PyTorch model in JAX/Flax. It uses dlpack to convert between PyTorch and JAX tensors in memory (on the same device). The wrapper functions are compatible with JAX autodiff.
There are multiple viable possibilities to translate between the frameworks:
- Torch2Jax: uses tracing to move JAX values through PyTorch code. This results in a JAX-native computation graph that follows the PyTorch code (it literally maps PyTorch functions to JAX functions).
- Unify framework: lets you write code in Python and convert it to any of the supported AI frameworks (JAX, TensorFlow, PyTorch, and NumPy). It can convert models from PyTorch to JAX: https://unify.ai/docs/ivy/demos/learn_the_basics/04_transpile_code.html
- Open Neural Network Exchange (ONNX): ONNX export is built into PyTorch, so converting a PyTorch model to an ONNX model is trivial. The conversion from ONNX to JAX is doable but not straightforward: https://velog.io/@thdalwh3867/Convert-PyTorch-models-to-Flax
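To make "translating a trained model" concrete, the following minimal sketch maps the weights of a small PyTorch MLP onto an equivalent pure-JAX function by hand. It is not how Torch2Jax, Ivy, or ONNX work internally; it only relies on `nn.Linear` computing `x @ W.T + b`. The helper names `torch_mlp_to_jax_params` and `jax_mlp` are hypothetical, and the sketch assumes alternating `Linear`/`ReLU` layers that end in a `Linear` layer.

```python
import jax
import jax.numpy as jnp
import torch
import torch.nn as nn


def torch_mlp_to_jax_params(seq: nn.Sequential):
    """Copy the (weight, bias) of every nn.Linear in `seq` into JAX arrays.

    Hypothetical helper: assumes the Sequential alternates Linear and ReLU
    layers and ends with a Linear layer.
    """
    params = []
    for layer in seq:
        if isinstance(layer, nn.Linear):
            # Host copy: torch -> NumPy -> JAX
            W = jnp.asarray(layer.weight.detach().cpu().numpy())
            b = jnp.asarray(layer.bias.detach().cpu().numpy())
            params.append((W, b))
        elif not isinstance(layer, nn.ReLU):
            raise NotImplementedError(f"no JAX mapping for {type(layer).__name__}")
    return params


def jax_mlp(params, x):
    """Pure-JAX forward pass: Linear + ReLU for all but the last layer."""
    for W, b in params[:-1]:
        x = jax.nn.relu(x @ W.T + b)  # same math as nn.Linear followed by nn.ReLU
    W, b = params[-1]
    return x @ W.T + b


# Usage: compare both frameworks on the same random batch
torch.manual_seed(0)
seq = nn.Sequential(nn.Linear(10, 50), nn.ReLU(), nn.Linear(50, 30), nn.ReLU(), nn.Linear(30, 1))
x = torch.randn(4, 10)
params = torch_mlp_to_jax_params(seq)
jax_out = jax.jit(jax_mlp)(params, jnp.asarray(x.numpy()))
torch_out = seq(x).detach().numpy()
```

Because the translated model is a plain function of a parameter pytree, it composes with `jax.jit` and `jax.grad`; the price is that every PyTorch layer type needs its own hand-written mapping, which is exactly the work the tools listed above try to automate.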
The Pytorch2Jax approach uses the PyTorch backend inside a JAX function or Flax module. It uses dlpack to avoid moving PyTorch tensors and JAX arrays between devices when converting from one framework to the other, and we can still use JAX's automatic differentiation.
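A minimal sketch of the DLPack hand-off between the two frameworks, assuming a recent JAX version whose `jax.dlpack.from_dlpack` accepts any object implementing `__dlpack__` (older versions expected a DLPack capsule instead). The helper names `torch_to_jax` and `jax_to_torch` are hypothetical. DLPack lets the consumer reuse the producer's buffer where possible instead of copying, and the data stays on the device it started on.

```python
import jax
import jax.dlpack
import jax.numpy as jnp
import torch


def torch_to_jax(t: torch.Tensor) -> jax.Array:
    # DLPack refuses to export tensors that require grad, so detach first;
    # contiguous() avoids strided layouts that JAX may not import.
    return jax.dlpack.from_dlpack(t.detach().contiguous())


def jax_to_torch(x: jax.Array) -> torch.Tensor:
    # JAX arrays implement __dlpack__, so torch.from_dlpack can consume them.
    # JAX buffers are immutable: treat the result as read-only (or clone it).
    return torch.from_dlpack(x)


t = torch.linspace(0.0, 1.0, 5)
x = torch_to_jax(t)     # JAX array on the same device as t
y = jnp.sin(x)          # compute in JAX
back = jax_to_torch(y)  # hand the result back to PyTorch
```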
Only minimal modifications were necessary to use this approach. Every neural network potential takes an `NNPInput` data class as input; we opted for a class to wrap the input because it allows extensive input checking and modification. To ensure that we pass input that is usable by JAX, we added a method that converts the data class fields to JAX arrays and returns them as a `NamedTuple`.
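The data-class pattern described above can be sketched as follows. This is a simplified, hypothetical stand-in for `NNPInput`: the three fields, the `JaxNNPInput` NamedTuple, and the checks are assumptions chosen for illustration; only the method name `as_jax_namedtuple` comes from this page. Fields are copied to the host through NumPy for simplicity; the DLPack route avoids that copy.

```python
from dataclasses import dataclass, fields
from typing import NamedTuple

import jax
import jax.numpy as jnp
import torch


class JaxNNPInput(NamedTuple):
    """Hypothetical JAX-side container; the real field set may differ."""
    atomic_numbers: jax.Array
    positions: jax.Array
    atomic_subsystem_indices: jax.Array


@dataclass
class NNPInput:
    """Simplified, hypothetical stand-in for modelforge's NNPInput."""
    atomic_numbers: torch.Tensor
    positions: torch.Tensor
    atomic_subsystem_indices: torch.Tensor

    def __post_init__(self):
        # Input checking is the main reason for using a class here
        n_atoms = self.atomic_numbers.shape[0]
        if tuple(self.positions.shape) != (n_atoms, 3):
            raise ValueError(f"positions must have shape ({n_atoms}, 3)")
        if tuple(self.atomic_subsystem_indices.shape) != (n_atoms,):
            raise ValueError("atomic_subsystem_indices needs one entry per atom")

    def as_jax_namedtuple(self) -> JaxNNPInput:
        # Copy each field to a JAX array (host copy through NumPy for simplicity)
        return JaxNNPInput(**{
            f.name: jnp.asarray(getattr(self, f.name).detach().cpu().numpy())
            for f in fields(self)
        })


# Usage: a water-like molecule with three atoms in one subsystem
inp = NNPInput(
    atomic_numbers=torch.tensor([8, 1, 1]),
    positions=torch.zeros(3, 3),
    atomic_subsystem_indices=torch.tensor([0, 0, 0]),
)
jax_inp = inp.as_jax_namedtuple()
```

Returning a `NamedTuple` rather than the data class matters because JAX treats named tuples as pytrees, so the converted input can be passed directly to `jax.jit`- or `jax.grad`-transformed functions.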
To wrap a PyTorch model, we implemented a slightly modified version of the Pytorch2Jax code in a converter class. This class transforms the inputs and outputs between PyTorch tensors and JAX arrays on the device in use. It also converts the model parameters and buffers so that JAX can take derivatives.
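The converter idea can be sketched in a few lines. This is a hypothetical illustration, not the modelforge converter class or the Pytorch2Jax source: `wrap_torch_module`, `t2j`, and `j2t` are made-up names, buffers are simply kept as PyTorch tensors, and the wrapped function calls PyTorch eagerly, so it is meant for `jax.grad` but not for use under `jax.jit`. Parameters are exposed to JAX as a dict of arrays, the forward pass runs through `torch.func.functional_call`, and a `jax.custom_vjp` rule hands the backward pass to `torch.func.vjp`.

```python
import jax
import jax.dlpack
import jax.numpy as jnp
import torch
from torch.func import functional_call, vjp


def t2j(t: torch.Tensor) -> jax.Array:
    # Hypothetical helper: torch -> JAX through DLPack
    return jax.dlpack.from_dlpack(t.detach().contiguous())


def j2t(x: jax.Array) -> torch.Tensor:
    # Hypothetical helper: JAX -> torch through DLPack; clone() so PyTorch
    # never writes into an immutable JAX buffer
    return torch.from_dlpack(x).clone()


def wrap_torch_module(module: torch.nn.Module):
    """Return (params, apply) with apply(params, x) differentiable by jax.grad."""
    params = {name: t2j(p) for name, p in module.named_parameters()}
    buffers = dict(module.named_buffers())  # kept in PyTorch, not differentiated

    def torch_forward(p_t, x_t):
        # Run the module with externally supplied parameters
        return functional_call(module, {**p_t, **buffers}, (x_t,))

    @jax.custom_vjp
    def apply(p, x):
        p_t = {name: j2t(v) for name, v in p.items()}
        return t2j(torch_forward(p_t, j2t(x)))

    def apply_fwd(p, x):
        # Only call apply and save residuals; the PyTorch work runs inside apply
        return apply(p, x), (p, x)

    def apply_bwd(res, g):
        p, x = res
        p_t = {name: j2t(v) for name, v in p.items()}
        _, pullback = vjp(torch_forward, p_t, j2t(x))
        grad_p, grad_x = pullback(j2t(g))
        # Cotangents must mirror the primal arguments: (dict of params, x)
        return {name: t2j(grad_p[name]) for name in p}, t2j(grad_x)

    apply.defvjp(apply_fwd, apply_bwd)
    return params, apply


# Usage: differentiate a linear layer's output with respect to its parameters
torch.manual_seed(0)
lin = torch.nn.Linear(3, 2)
params, apply = wrap_torch_module(lin)
x = jnp.array([[1.0, 2.0, 3.0]])
out = apply(params, x)
grads = jax.grad(lambda p: apply(p, x).sum())(params)
```

Keeping the parameters as an explicit JAX pytree is what lets `jax.grad` differentiate with respect to them; PyTorch only ever sees them as inputs to `functional_call`.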
To obtain a JAX-wrapped PyTorch model, specify `JAX` as the `simulation_environment` in the `NeuralNetworkPotentialFactory`:
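```python
from modelforge.potential import NeuralNetworkPotentialFactory

model_name = "SchNet"  # any NNP type supported by the factory

# inference model
model = NeuralNetworkPotentialFactory.create_nnp(
    use="inference",
    nnp_type=model_name,
    simulation_environment="JAX",
)
```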
Unify implements two functions that are of interest:
- `ivy.trace_graph`: similar to `torch.jit.trace`; records the compute graph and removes all Python wrapping overhead.
- `ivy.transpile`: converts a function written in any framework into your framework of choice, preserving all the logic between frameworks. A general introduction can be found in the Ivy documentation.
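The PyTorch tracing that `ivy.trace_graph` is compared to above can be illustrated with `torch.jit.trace` alone (no Ivy involved): the module is executed once on an example input, the operations that actually ran are recorded, and the resulting graph can be replayed without the Python `forward` code. One caveat applies to tracing in general: data-dependent control flow is frozen to the branch taken for the example input.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
net = nn.Sequential(nn.Linear(10, 50), nn.ReLU(), nn.Linear(50, 1))
example = torch.randn(2, 10)

# Run net once on `example` and record the executed operations
traced = torch.jit.trace(net, example)

# print(traced.graph) or print(traced.code) shows what was recorded

# The traced module replays the graph on new inputs, including a new batch size
x = torch.randn(3, 10)
traced_out = traced(x)
same = torch.allclose(traced_out, net(x))
```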
We start with the transpile capability. This approach translates the PyTorch model into a Flax model (Flax is a neural network library built on top of JAX) for further use in chiron.
To demonstrate the approach, the following example translates a simple PyTorch model to JAX (Flax):
```python
import ivy
import jax.numpy as jnp
import torch
import torch.nn as nn
from jax import random

ivy.set_backend("jax")


class MultiLayerNetwork(nn.Module):
    def __init__(self, input_size, hidden_sizes, output_size):
        """
        Initializes the MultiLayerNetwork.

        Parameters:
            input_size (int): The size of the input features.
            hidden_sizes (list of int): A list containing the sizes of the hidden layers.
            output_size (int): The size of the output layer.
        """
        super().__init__()
        layers = []
        current_size = input_size
        for hidden_size in hidden_sizes:
            layers.append(nn.Linear(current_size, hidden_size))
            layers.append(nn.ReLU())
            current_size = hidden_size
        # Output layer
        layers.append(nn.Linear(current_size, output_size))
        # Register all layers
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        """
        Forward pass through the network.
        """
        return self.layers(x)


# Example usage:
net = MultiLayerNetwork(input_size=10, hidden_sizes=[50, 30], output_size=1)
inp = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0])
torch_result = net(inp)

# Transpile it into a Flax model with corresponding parameters
flax_net = ivy.transpile(net, source="torch", to="flax", args=[inp])

# Input for JAX
jax_inp = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0])
# Get the variables and run the transpiled model
variables = flax_net.init(random.key(0), jax_inp)
jax_result = flax_net.apply(variables, jax_inp)

# allclose reduces the element-wise comparison to a single bool
if jnp.allclose(torch_result.detach().numpy(), jax_result):
    print("the same")
```
This worked well: only slight modifications to the example in the documentation were necessary to obtain the converted model.
Let's see how it performs with the modelforge NNPs. To test this, we start with a SchNet model:
```python
import ivy
import jax.numpy as jnp
import torch
from jax import random

# import the models implemented in modelforge
from modelforge.potential import NeuralNetworkPotentialFactory
from modelforge.dataset.qm9 import QM9Dataset
from modelforge.dataset.dataset import TorchDataModule

# obtain a single batch
data = QM9Dataset(for_unit_testing=True)
dataset = TorchDataModule(data, batch_size=1)
dataset.prepare_data(remove_self_energies=True, normalize=False)
batch = next(iter(dataset.train_dataloader()))

# initialize the model
model = NeuralNetworkPotentialFactory.create_nnp("inference", "SchNet")
model = model.to(torch.float32)

# predict energy
torch_result = model(batch.nnp_input)

# transpile the model
ivy.set_backend("jax")
jax_model = ivy.transpile(model, source="torch", to="jax", args=[batch.nnp_input])

# convert the batch to JAX input and predict the energy for the same batch
jax_input = batch.nnp_input.as_jax_namedtuple()
variables = jax_model.init(random.key(0), jax_input)
jax_result = jax_model.apply(variables, jax_input)

# allclose reduces the element-wise comparison to a single bool
if jnp.allclose(torch_result.E.detach().numpy(), jax_result.E):
    print("the same")
```
Again, this worked out of the box and produced an NNP implemented in Flax.
NOTE: Since we want to use the neighbor list implemented in chiron, we don't want to convert the whole model. For that reason, we separate each model into an `InputPreparation` module and a `CoreModel` module, as outlined here: https://github.com/choderalab/modelforge/wiki/Neural-network-potentials#module-organization.
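A minimal, hypothetical sketch of such a split (the real modelforge modules are described on the linked page and differ): `InputPreparation` turns positions into pair distances within a cutoff, which is the step chiron's neighbor list would take over, and `CoreModel` maps those distances to an energy. Only `CoreModel` would then be handed to `ivy.transpile`. The class bodies, the `Potential` wrapper, and the toy pairwise-MLP energy are illustrative assumptions.

```python
import torch
import torch.nn as nn


class InputPreparation(nn.Module):
    """Hypothetical: builds pair distances; chiron's neighbor list would replace this."""

    def __init__(self, cutoff: float):
        super().__init__()
        self.cutoff = cutoff

    def forward(self, positions: torch.Tensor) -> torch.Tensor:
        n = positions.shape[0]
        # All unique pairs i < j (O(N^2); a neighbor list avoids this)
        i, j = torch.triu_indices(n, n, offset=1)
        d_ij = torch.linalg.vector_norm(positions[i] - positions[j], dim=-1)
        return d_ij[d_ij < self.cutoff]


class CoreModel(nn.Module):
    """Hypothetical: maps pair distances to a scalar energy; the part to convert."""

    def __init__(self, hidden: int = 16):
        super().__init__()
        self.mlp = nn.Sequential(nn.Linear(1, hidden), nn.Tanh(), nn.Linear(hidden, 1))

    def forward(self, d_ij: torch.Tensor) -> torch.Tensor:
        # Sum of per-pair contributions
        return self.mlp(d_ij.unsqueeze(-1)).sum()


class Potential(nn.Module):
    """Hypothetical wrapper chaining input preparation and the core model."""

    def __init__(self, cutoff: float = 5.0):
        super().__init__()
        self.input_preparation = InputPreparation(cutoff)
        self.core_model = CoreModel()

    def forward(self, positions: torch.Tensor) -> torch.Tensor:
        return self.core_model(self.input_preparation(positions))


# Usage: 5 atoms in a unit cube, so every pair lies within the 5.0 cutoff
torch.manual_seed(0)
model = Potential(cutoff=5.0)
positions = torch.rand(5, 3)
energy = model(positions)
d_ij = model.input_preparation(positions)
```

With this split, the distances could come from any source (for example a JAX-side neighbor list) as long as they match what `CoreModel` expects.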