I decided to compare NNX against Equinox for performance and am seeing significant differences (roughly 3x slower for NNX). It could be that I wrote a poor MLP implementation or made a colossal profiling mistake.
My apologies if the benchmarking itself is flawed or the MLP implementation is incorrect in some way. But if it is the latter, it suggests that a documented MLP implementation for NNX to copy/paste might help.
System information
- latest released NNX
- Tried on a laptop (macOS with CPU) as well as Colab with an accelerator
Problem you have encountered:
The performance of my test suite on my CPU is
Time taken NNX: 0.00055 seconds
Time taken EQX: 0.00019 seconds
And on the colab T4 GPU runtime
Time taken NNX: 0.00220 seconds
Time taken EQX: 0.00066 seconds
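As an aside, a single `time.time()` pair is noisy at sub-millisecond scales; `timeit.repeat` with many calls per sample, taking the best sample, is more robust. A minimal sketch of that measurement style, using plain `jax.jit` and a hypothetical stand-in function `f` rather than the benchmark below:

```python
import timeit
import jax
import jax.numpy as jnp

@jax.jit
def f(x):  # hypothetical stand-in for the benchmarked step
    return jnp.sum(x * x)

x = jnp.ones((1024,))
f(x).block_until_ready()  # warm up: compile outside the timed region

# run 100 calls per sample, 5 samples; report the best per-call time
runs = timeit.repeat(lambda: f(x).block_until_ready(), number=100, repeat=5)
print(f"best per-call time: {min(runs) / 100:.2e} s")
```

`block_until_ready()` inside the timed lambda keeps JAX's async dispatch from hiding the actual device time.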
Steps to reproduce:
Test Suite:
import time
import typing as tp

import equinox as eqx
import jax
import jax.numpy as jnp
from flax import nnx
from flax.typing import Dtype, PrecisionLike
class MLP(nnx.Module):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        *,
        width: int,
        depth: int,
        activation: tp.Callable,
        rngs: nnx.Rngs,
        use_bias: bool = True,
        use_final_bias: bool = True,
        final_activation: tp.Optional[tp.Callable] = None,
        dtype: tp.Optional[Dtype] = None,
        param_dtype: Dtype = jnp.float32,
        precision: PrecisionLike = None,
    ):
        self.in_features = in_features
        self.out_features = out_features
        self.width = width
        self.depth = depth
        self.use_bias = use_bias
        self.use_final_bias = use_final_bias
        self.activation = activation
        self.final_activation = final_activation
        assert depth > 0  # skipping specialization of the no-hidden-layers case
        self.layers = []
        self.layers.append(
            nnx.Linear(
                in_features,
                width,
                use_bias=use_bias,
                dtype=dtype,
                param_dtype=param_dtype,
                precision=precision,
                rngs=rngs,
            )
        )
        for _ in range(self.depth - 1):
            # apply the activation after every hidden linear layer,
            # matching eqx.nn.MLP's layout
            self.layers.append(self.activation)
            self.layers.append(
                nnx.Linear(
                    width,
                    width,
                    use_bias=self.use_bias,
                    dtype=dtype,
                    param_dtype=param_dtype,
                    precision=precision,
                    rngs=rngs,
                )
            )
        self.layers.append(self.activation)
        self.layers.append(
            nnx.Linear(
                width,
                out_features,
                use_bias=self.use_final_bias,
                dtype=dtype,
                param_dtype=param_dtype,
                precision=precision,
                rngs=rngs,
            )
        )
        if self.final_activation is not None:
            self.layers.append(self.final_activation)

    def __call__(self, x: jax.Array) -> jax.Array:
        for layer in self.layers:
            x = layer(x)
        return x
if __name__ == "__main__":
    rngs = nnx.Rngs(0)

    @nnx.jit
    def my_test(batch, model):
        @nnx.jit
        def loss_closure(f):
            return jnp.mean(jax.vmap(f)(batch))

        loss_val, loss_grad = nnx.value_and_grad(loss_closure)(model)
        return loss_val

    n_in = 64
    n_out = 1
    depth = 3
    width = 128
    activation = nnx.relu
    model = MLP(n_in, n_out, width=width, depth=depth, activation=activation, rngs=rngs)
    my_batch = jax.random.normal(rngs(), (20, n_in))

    # Time NNX: the first call compiles; only the second call is timed
    out = my_test(my_batch, model).block_until_ready()
    start_time = time.time()
    out = my_test(my_batch, model).block_until_ready()
    end_time = time.time()
    print(f"Time taken NNX: {end_time - start_time:.5f} seconds")

    @eqx.filter_jit
    def my_test_eqx(batch, model):
        @eqx.filter_jit
        def loss_closure(f):
            return jnp.mean(jax.vmap(f)(batch))

        loss_val, loss_grad = eqx.filter_value_and_grad(loss_closure)(model)
        return loss_val

    equinox_model = eqx.nn.MLP(n_in, n_out, width_size=width, depth=depth, activation=activation, key=rngs())

    # Time Equinox: block on the warm-up call so compilation finishes before timing
    out = my_test_eqx(my_batch, equinox_model).block_until_ready()
    start_time = time.time()
    out = my_test_eqx(my_batch, equinox_model).block_until_ready()
    end_time = time.time()
    print(f"Time taken EQX: {end_time - start_time:.5f} seconds")

On Colab you need to run !pip install equinox first.