
Significant performance difference of NNX relative to equinox #4045

@jlperla

Description


I decided to compare nnx vs. equinox for performance and am seeing significant differences (roughly 3x slower for nnx). It could be that I wrote a poor MLP implementation or made a colossal profiling mistake.

My apologies if the benchmarking itself is flawed or the MLP implementation is incorrect in some way. But if it is the latter, it shows that a documented MLP implementation for NNX to copy/paste might help.

System information

  • latest released NNX
  • Tried on a laptop (macOS with CPU) as well as Colab with an accelerator

Problem you have encountered:

The performance of my test suite on my CPU is

Time taken NNX: 0.00055 seconds
Time taken EQX: 0.00019 seconds

And on the Colab T4 GPU runtime

Time taken NNX: 0.00220 seconds
Time taken EQX: 0.00066 seconds
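(Aside: these numbers come from a single `time.time()` call, which is noisy. A more robust harness, a sketch assuming a generic callable, would time many calls and report the median, blocking on JAX's async dispatch each time. `bench` and its arguments are stand-ins for the `my_test` / `my_test_eqx` functions below.)

```python
import time
import statistics

def bench(fn, *args, n_iters=100):
    """Median wall-clock time of fn(*args) over n_iters calls."""
    fn(*args)  # warm-up (triggers compilation for jitted functions)
    times = []
    for _ in range(n_iters):
        start = time.perf_counter()
        out = fn(*args)
        # JAX dispatches asynchronously; block before stopping the clock.
        if hasattr(out, "block_until_ready"):
            out.block_until_ready()
        times.append(time.perf_counter() - start)
    return statistics.median(times)
```

Used as e.g. `print(f"NNX: {bench(my_test, my_batch, model):.5f}s")`.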

Steps to reproduce:

Test Suite:

import typing as tp
import jax
import jax.numpy as jnp
from flax import nnx
from flax.typing import Dtype, PrecisionLike
import equinox as eqx
import time

class MLP(nnx.Module):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        *,
        width: int,
        depth: int,
        activation: tp.Callable,
        rngs: nnx.Rngs,
        use_bias: bool = True,
        use_final_bias: bool = True,
        final_activation: tp.Optional[tp.Callable] = None,
        dtype: tp.Optional[Dtype] = None,
        param_dtype: Dtype = jnp.float32,
        precision: PrecisionLike = None,
    ):
        self.in_features = in_features
        self.out_features = out_features
        self.width = width
        self.depth = depth
        self.use_bias = use_bias
        self.use_final_bias = use_final_bias
        self.activation = activation
        self.final_activation = final_activation
        assert depth > 0  # skipping specialization of no hidden layers

        self.layers = []
        self.layers.append(
            nnx.Linear(
                in_features,
                width,
                use_bias=use_bias,
                dtype=dtype,
                param_dtype=param_dtype,
                precision=precision,
                rngs=rngs,
            )
        )
        self.layers.append(self.activation)  # activation after the first layer, matching eqx.nn.MLP
        for _ in range(self.depth - 1):
            self.layers.append(
                nnx.Linear(
                    width,
                    width,
                    use_bias=self.use_bias,
                    dtype=dtype,
                    param_dtype=param_dtype,
                    precision=precision,
                    rngs=rngs,
                )
            )
            self.layers.append(self.activation)
        self.layers.append(
            nnx.Linear(
                width,
                out_features,
                use_bias=self.use_final_bias,
                dtype=dtype,
                param_dtype=param_dtype,
                precision=precision,
                rngs=rngs,
            )
        )
        if self.final_activation is not None:
            self.layers.append(self.final_activation)

    def __call__(self, x: jax.Array) -> jax.Array:
        for layer in self.layers:
            x = layer(x)
        return x

if __name__ == "__main__":
    rngs = nnx.Rngs(0)

    @nnx.jit
    def my_test(batch, model):
        @nnx.jit
        def loss_closure(f):
            return jnp.mean(jax.vmap(f)(batch))

        loss_val, loss_grad = nnx.value_and_grad(loss_closure)(model)
        return loss_val
    n_in = 64
    n_out = 1
    depth = 3
    width = 128
    activation = nnx.relu
    model = MLP(n_in, n_out, width=width, depth=depth, activation=activation, rngs=rngs)
    my_batch = jax.random.normal(rngs(), (20, n_in))

    # Time NNX
    out = my_test(my_batch, model).block_until_ready()
    start_time = time.time()
    out = my_test(my_batch, model).block_until_ready()
    end_time = time.time()
    print(f"Time taken NNX: {end_time - start_time:.5f} seconds")

    @eqx.filter_jit
    def my_test_eqx(batch, model):
        @eqx.filter_jit
        def loss_closure(f):
            return jnp.mean(jax.vmap(f)(batch))

        loss_val, loss_grad = eqx.filter_value_and_grad(loss_closure)(model)
        return loss_val    
    equinox_model = eqx.nn.MLP(n_in, n_out, width_size=width, depth=depth, activation=activation, key=rngs())

    # Time Equinox
    out = my_test_eqx(my_batch, equinox_model).block_until_ready()
    start_time = time.time()
    out = my_test_eqx(my_batch, equinox_model).block_until_ready()
    end_time = time.time()
    print(f"Time taken EQX: {end_time - start_time:.5f} seconds")    

On Colab you first need to run !pip install equinox
