Skip to content

ResNet MLX #67

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 123 additions & 1 deletion ml-mdm-matryoshka/ml_mdm/models/unet_mlx.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# For licensing see accompanying LICENSE file.
# Copyright (C) 2024 Apple Inc. All rights reserved.
""" U-NET architecture."""

import numpy as np

import math

Expand All @@ -8,6 +11,37 @@
import mlx.core as mx
import mlx.nn as nn

from ml_mdm.models.unet import ResNetConfig


def _fan_in(w):
return np.prod(w.shape[1:])


def _fan_out(w):
return w.shape[0]


def _fan_avg(w):
return 0.5 * (_fan_in(w) + _fan_out(w))


def init_weights(module):
"""Initialize weights of a module using PyTorch's default initialization"""
for k, v in module.parameters().items():
if 'weight' in k:
if isinstance(module, nn.GroupNorm):
# PyTorch initializes GroupNorm weights to 1
module.parameters()[k] = mx.ones_like(v)
else:
# For conv and linear layers, use Kaiming uniform initialization
fan = _fan_in(v)
bound = 1 / np.sqrt(fan)
module.parameters()[k] = mx.random.uniform(low=-bound, high=bound, shape=v.shape)
elif 'bias' in k:
module.parameters()[k] = mx.zeros_like(v)
return module


def zero_module_mlx(module):
"""
Expand Down Expand Up @@ -150,4 +184,92 @@ def __init__(self, channels, multiplier=4):
)

def forward(self, x):
return x + self.main(x)
return x + self.main(x)


class ResNet_MLX(nn.Module):
def __init__(self, time_emb_channels, config: ResNetConfig):
super(ResNet_MLX, self).__init__()
self.config = config
self.norm1 = nn.GroupNorm(
config.num_groups_norm,
config.num_channels,
pytorch_compatible=True,
eps=1e-5 #torch std
)

self.conv1 = nn.Conv2d(
config.num_channels,
config.output_channels,
kernel_size=3,
padding=1,
bias=True
)

self.time_layer = nn.Linear(
time_emb_channels,
config.output_channels * 2
)

# Initialize GroupNorm2 without special initialization
self.norm2 = nn.GroupNorm(
config.num_groups_norm,
config.output_channels,
pytorch_compatible=True,
eps=1e-5
)
self.dropout = nn.Dropout(config.dropout)

# conv2 is zero-initialized
self.conv2 = zero_module_mlx(
nn.Conv2d(
config.output_channels,
config.output_channels,
kernel_size=3,
padding=1,
bias=True
)
)

# Create a 1x1 conv for the residual connection if channels don't match
if self.config.output_channels != self.config.num_channels:
# Rename to conv3 to match PyTorch
self.conv3 = nn.Conv2d(
config.num_channels,
config.output_channels,
kernel_size=1,
bias=True
)

def forward(self, x, temb):
print("pre norm shape: ", x.shape)
h = self.norm1(x)
print("post norm shape: ", h.shape)
h = nn.silu(h)
h = self.conv1(h)

temb_out = nn.silu(temb)
temb_out = self.time_layer(temb_out)
temb_out = mx.expand_dims(mx.expand_dims(temb_out, axis=1), axis=1)
ta, tb = mx.split(temb_out, 2, axis=-1)

# Handle batch size mismatch
if h.shape[0] > ta.shape[0]:
N = h.shape[0] // ta.shape[0]
ta = mx.repeat(ta, N, axis=0)
tb = mx.repeat(tb, N, axis=0)

# Broadcast temporal embeddings
ta = mx.broadcast_to(ta, h.shape)
tb = mx.broadcast_to(tb, h.shape)

h = nn.silu(self.norm2(h) * (1 + ta) + tb)
h = self.dropout(h)
h = self.conv2(h)

# Handle residual connection
if self.config.output_channels != self.config.num_channels:
x = self.conv3(x)

return h + x

95 changes: 93 additions & 2 deletions ml-mdm-matryoshka/tests/test_unet_mlx.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,18 @@
import numpy as np
import torch

from ml_mdm.models.unet import MLP, SelfAttention1D, TemporalAttentionBlock
from ml_mdm.models.unet import MLP, SelfAttention1D, TemporalAttentionBlock, ResNet, ResNetConfig
from ml_mdm.models.unet_mlx import (
MLP_MLX,
SelfAttention1D_MLX,
ResNet_MLX,
TemporalAttentionBlock_MLX,
init_weights,
zero_module_mlx
)



def test_pytorch_mlp():
"""
Simple test for our MLP implementations
Expand Down Expand Up @@ -56,11 +60,98 @@ def test_pytorch_mlp():

# Validate numerical equivalence using numpy
assert np.allclose(

output.detach().numpy(), np.array(mx.stop_gradient(mlx_output)), atol=1e-5
), "Outputs of PyTorch MLP and MLX MLP should match"

print("Test passed for both PyTorch and MLX MLP!")

def test_pytorch_mlx_ResNet():
"""Test that PyTorch and MLX ResNet implementations produce matching outputs."""
# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)
mx.random.seed(42)

# Define parameters
batch_size = 2
time_emb_channels = 32
height = 16
width = 16

# Create config
config = ResNetConfig(
num_channels=64,
output_channels=64, # Match input channels for testing
num_groups_norm=32,
dropout=0.0, # Set to 0 for deterministic comparison
use_attention_ffn=False,
)

# Create model instances
pytorch_resnet = ResNet(time_emb_channels=time_emb_channels, config=config)
mlx_resnet = ResNet_MLX(time_emb_channels=time_emb_channels, config=config)

# Initialize weights for MLX model
init_weights(mlx_resnet.norm1)
init_weights(mlx_resnet.conv1)
init_weights(mlx_resnet.time_layer)
init_weights(mlx_resnet.norm2)
mlx_resnet.conv2 = zero_module_mlx(mlx_resnet.conv2)
if hasattr(mlx_resnet, 'conv3'):
init_weights(mlx_resnet.conv3)

# Ensure weights have correct shapes for GroupNorm
if hasattr(mlx_resnet.norm1, 'weight'):
mlx_resnet.norm1.weight = mx.array(np.ones(config.num_channels))
mlx_resnet.norm1.bias = mx.array(np.zeros(config.num_channels))
if hasattr(mlx_resnet.norm2, 'weight'):
mlx_resnet.norm2.weight = mx.array(np.ones(config.output_channels))
mlx_resnet.norm2.bias = mx.array(np.zeros(config.output_channels))

# Set both models to evaluation mode
pytorch_resnet.eval()
mlx_resnet.eval()

# Create input tensors with same random seed for reproducibility
torch.manual_seed(42)
# Create input tensor with num_channels (64) channels
x_torch = torch.randn(batch_size, config.num_channels, height, width) # [2, 64, 16, 16]
temb_torch = torch.randn(batch_size, time_emb_channels) # [2, 32]

# Get PyTorch output first
with torch.no_grad():
output_torch = pytorch_resnet(x_torch, temb_torch)

# Convert inputs to MLX format (NCHW -> NHWC)
x_numpy = x_torch.detach().numpy()
x_numpy = np.transpose(x_numpy, (0, 2, 3, 1)) # NCHW -> NHWC
x_mlx = mx.array(x_numpy)
temb_mlx = mx.array(temb_torch.detach().numpy())

# Debug shapes and intermediate values
print("\nInput shapes:")
print("PyTorch x (NCHW):", x_torch.shape)
print("MLX x (NHWC):", x_mlx.shape)
print("PyTorch temb:", temb_torch.shape)
print("MLX temb:", temb_mlx.shape)

# Debug intermediate values in MLX
# Convert input to NCHW for MLX processing
x_nchw = mx.transpose(x_mlx, [0, 3, 1, 2])
output_mlx = mlx_resnet.forward(x_mlx, temb_mlx)

# Convert MLX output to NCHW format for comparison
output_mlx_numpy = np.array(output_mlx)
output_mlx_numpy = np.transpose(output_mlx_numpy, (0, 3, 1, 2)) # NHWC -> NCHW

# Compare outputs
np.testing.assert_allclose(
output_torch.detach().numpy(),
output_mlx_numpy,
rtol=1e-4,
atol=1e-4,
)

def test_self_attention_1d():
# Define parameters
Expand Down Expand Up @@ -156,4 +247,4 @@ def test_pytorch_mlx_temporal_attention_block():
atol=1e-1, # Significantly increased tolerance
), "Outputs of PyTorch and MLX TemporalAttentionBlock should match"

print("Test passed for both PyTorch and MLX TemporalAttentionBlock!")
print("Test passed for both PyTorch and MLX TemporalAttentionBlock!")