Add MLPMixer test (#154)
sgligorijevicTT authored Jan 21, 2025
1 parent da1c355 commit 5f8859b
Showing 4 changed files with 188 additions and 0 deletions.
2 changes: 2 additions & 0 deletions requirements.txt
@@ -13,3 +13,5 @@ lit
 pybind11
 pytest
 transformers
+fsspec
+einops
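The two new dependencies serve the MLPMixer test added below: einops flattens the convolutional patch grid into a token sequence, and fsspec downloads and caches the pretrained checkpoint. As a minimal sketch of the rearrange pattern the model relies on (illustrative shapes, assuming a 14x14 patch grid):

import einops
import jax.numpy as jnp

# (batch, grid_h, grid_w, channels) -> (batch, tokens, channels)
x = jnp.ones((1, 14, 14, 768))
tokens = einops.rearrange(x, "n h w c -> n (h w) c")
print(tokens.shape)  # (1, 196, 768)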
Empty file.
82 changes: 82 additions & 0 deletions tests/jax/models/mlpmixer/model_implementation.py
@@ -0,0 +1,82 @@
# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0

# This file incorporates work covered by the following copyright and permission
# notice:
# SPDX-FileCopyrightText: Copyright 2024 Google LLC.
# SPDX-License-Identifier: Apache-2.0

# This code is based on google-research/vision_transformer

from typing import Any, Optional

import einops
import flax.linen as nn
import jax
import jax.numpy as jnp


class MlpBlock(nn.Module):
    """Feed-forward block: Dense -> GELU -> Dense, back to the input width."""

    mlp_dim: int

    @nn.compact
    def __call__(self, x: jax.Array) -> jax.Array:
        y = nn.Dense(self.mlp_dim)(x)
        y = nn.gelu(y)
        return nn.Dense(x.shape[-1])(y)


class MixerBlock(nn.Module):
    """Mixer block layer: token mixing followed by channel mixing."""

    tokens_mlp_dim: int
    channels_mlp_dim: int

    @nn.compact
    def __call__(self, x: jax.Array) -> jax.Array:
        # Token mixing: swap tokens and channels so the MLP mixes across tokens.
        y = nn.LayerNorm()(x)
        y = jnp.swapaxes(y, 1, 2)
        y = MlpBlock(self.tokens_mlp_dim, name="token_mixing")(y)
        y = jnp.swapaxes(y, 1, 2)
        x = x + y  # Residual connection.

        # Channel mixing: the MLP mixes across the channel dimension.
        y = nn.LayerNorm()(x)
        y = MlpBlock(self.channels_mlp_dim, name="channel_mixing")(y)
        return x + y  # Residual connection.


class MlpMixer(nn.Module):
    """Mixer architecture."""

    patches: Any
    num_classes: int
    num_blocks: int
    hidden_dim: int
    tokens_mlp_dim: int
    channels_mlp_dim: int
    model_name: Optional[str] = None

    @nn.compact
    def __call__(self, inputs: jax.Array) -> jax.Array:
        # Patch embedding: non-overlapping convolution (stride == kernel size).
        x = nn.Conv(
            self.hidden_dim, self.patches.size, strides=self.patches.size, name="stem"
        )(inputs)
        # Flatten the spatial grid of patches into a token sequence.
        x = einops.rearrange(x, "n h w c -> n (h w) c")

        for _ in range(self.num_blocks):
            x = MixerBlock(self.tokens_mlp_dim, self.channels_mlp_dim)(x)

        x = nn.LayerNorm(name="pre_head_layer_norm")(x)
        # Global average pooling over tokens.
        x = jnp.mean(x, axis=1)

        if self.num_classes:
            x = nn.Dense(
                self.num_classes, kernel_init=nn.initializers.zeros, name="head"
            )(x)

        return x
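For reference, a minimal standalone sketch of driving this model with randomly initialized weights, mirroring the Mixer-B/16 hyperparameters and input shape used by the test below (the import path is hypothetical and assumes the script sits next to model_implementation.py):

import jax
import jax.numpy as jnp

from model_implementation import MlpMixer

# Mixer-B/16 hyperparameters, as in the test.
model = MlpMixer(
    patches=jnp.ones((16, 16)),
    num_classes=21843,
    num_blocks=12,
    hidden_dim=768,
    tokens_mlp_dim=384,
    channels_mlp_dim=3072,
)
images = jax.random.normal(jax.random.PRNGKey(0), (1, 196, 196, 3))
variables = model.init(jax.random.PRNGKey(42), images)  # random weights
logits = model.apply(variables, images)  # shape (1, 21843)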
104 changes: 104 additions & 0 deletions tests/jax/models/mlpmixer/test_mlpmixer.py
@@ -0,0 +1,104 @@
# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Dict, Sequence

import flax.traverse_util
import fsspec
import jax
import jax.numpy as jnp
import numpy
import pytest
from flax import linen as nn
from infra import ModelTester, RunMode

from .model_implementation import MlpMixer

# Hyperparameters for Mixer-B/16
patch_size = 16
num_classes = 21843
num_blocks = 12
hidden_dim = 768
token_mlp_dim = 384
channel_mlp_dim = 3072


class MlpMixerTester(ModelTester):
    """Tester for MlpMixer model."""

    # @override
    def _get_model(self) -> nn.Module:
        patch = jnp.ones((patch_size, patch_size))
        return MlpMixer(
            patches=patch,
            num_classes=num_classes,
            num_blocks=num_blocks,
            hidden_dim=hidden_dim,
            tokens_mlp_dim=token_mlp_dim,
            channels_mlp_dim=channel_mlp_dim,
        )

    @staticmethod
    def _retrieve_pretrained_weights() -> Dict:
        # TODO(stefan): Discuss how weights should be handled org wide
        link = "https://storage.googleapis.com/mixer_models/imagenet21k/Mixer-B_16.npz"
        with fsspec.open("filecache::" + link, cache_storage="/tmp/files/") as f:
            weights = numpy.load(f, encoding="bytes")
            state_dict = {k: v for k, v in weights.items()}
            pytree = flax.traverse_util.unflatten_dict(state_dict, sep="/")
        return {"params": pytree}

    # @override
    def _get_forward_method_name(self) -> str:
        return "apply"

    # @override
    def _get_input_activations(self) -> jax.Array:
        key = jax.random.PRNGKey(42)
        random_image = jax.random.normal(key, (1, 196, 196, 3))
        return random_image

    # @override
    def _get_forward_method_args(self) -> Sequence[Any]:
        ins = self._get_input_activations()
        weights = self._retrieve_pretrained_weights()

        # Required to bypass "Initializer expected to generate shape
        # (16, 16, 3, 768) but got shape (256, 3, 768)"
        kernel = weights["params"]["stem"]["kernel"]
        kernel = kernel.reshape(-1, 3, hidden_dim)
        weights["params"]["stem"]["kernel"] = kernel

        # Alternatively, weights could be randomly initialized like this:
        # weights = self._model.init(jax.random.PRNGKey(42), ins)

        # JAX frameworks have a convention of passing weights as the first argument
        return [weights, ins]


# ----- Fixtures -----


@pytest.fixture
def inference_tester() -> MlpMixerTester:
    return MlpMixerTester()


@pytest.fixture
def training_tester() -> MlpMixerTester:
    return MlpMixerTester(RunMode.TRAINING)


# ----- Tests -----


@pytest.mark.skip(
    reason="error: failed to legalize operation 'ttir.convolution' that was explicitly marked illegal"
)
def test_mlpmixer(inference_tester: MlpMixerTester):
    inference_tester.test()


@pytest.mark.skip(reason="Support for training not implemented")
def test_mlpmixer_training(training_tester: MlpMixerTester):
    training_tester.test()
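As context for the weight-loading step in _retrieve_pretrained_weights above: flax.traverse_util.unflatten_dict with sep="/" turns the flat key/value pairs stored in the .npz checkpoint into the nested parameter pytree that Flax's apply expects. A minimal sketch with made-up keys and shapes:

import flax.traverse_util
import numpy

flat = {
    "stem/kernel": numpy.zeros((16, 16, 3, 768)),
    "stem/bias": numpy.zeros((768,)),
}
nested = flax.traverse_util.unflatten_dict(flat, sep="/")
# nested == {"stem": {"kernel": ..., "bias": ...}}
assert nested["stem"]["kernel"].shape == (16, 16, 3, 768)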
