SelfAttention_MLX working #66

Open
wants to merge 4 commits into base: main
98 changes: 97 additions & 1 deletion ml-mdm-matryoshka/ml_mdm/models/unet_mlx.py
@@ -23,6 +23,7 @@ def zero_module_mlx(module):
return module



class SelfAttention1D_MLX(nn.Module):
def __init__(
self,
@@ -150,4 +151,99 @@ def __init__(self, channels, multiplier=4):
)

def forward(self, x):
return x + self.main(x)


class SelfAttention_MLX(nn.Module):
def __init__(
self,
channels,
num_heads=8,
num_head_channels=-1,
cond_dim=None,
use_attention_ffn=False,
):
super().__init__()
self.channels = channels
if num_head_channels == -1:
self.num_heads = num_heads
else:
assert (
channels % num_head_channels == 0
), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
self.num_heads = channels // num_head_channels
self.norm = nn.GroupNorm(32, channels, pytorch_compatible=True)
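        # pytorch_compatible=True keeps MLX GroupNorm numerically aligned with PyTorch's,
        # which the parity tests below rely on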
self.qkv = nn.Conv2d(channels, channels * 3, 1)
self.cond_dim = cond_dim
if cond_dim is not None and cond_dim > 0:
self.norm_cond = nn.LayerNorm(cond_dim)
self.kv_cond = nn.Linear(cond_dim, channels * 2)
self.proj_out = zero_module_mlx(nn.Conv2d(channels, channels, 1))
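        # proj_out is zero-initialized, so the residual attention branch starts as an identity mapping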
if use_attention_ffn:
self.ffn = nn.Sequential(
nn.GroupNorm(32, channels, pytorch_compatible=True),
nn.Conv2d(channels, 4 * channels, 1),
nn.GELU(),
zero_module_mlx(nn.Conv2d(4 * channels, channels, 1)),
)
else:
self.ffn = None

def attention(self, q, k, v, mask=None):
bs, width, length = q.shape
ch = width // self.num_heads
scale = 1 / math.sqrt(math.sqrt(ch))
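        # scale = ch ** -0.25 is applied to both q and k, so the logits carry the usual 1/sqrt(ch) factor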
weight = mx.einsum(
"bct,bcs->bts",
(q * scale).reshape(bs * self.num_heads, ch, length),
(k * scale).reshape(bs * self.num_heads, ch, -1),
) # More stable with f16 than dividing afterwards
if mask is not None:
# Reshape mask to match attention shape
# From [bs, seq_len] to [bs * num_heads, 1, seq_len]
expanded_mask = einops.array_api.repeat(
mask[:, None, :], # Add dimension for broadcasting
"b 1 s -> (b h) 1 s",
h=self.num_heads,
)
# Apply mask
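            # False entries become -inf and contribute zero weight after the softmax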
weight = mx.where(expanded_mask, weight, float("-inf"))

weight = mx.softmax(weight, axis=-1)

return mx.einsum(
"bts,bcs->bct", weight, v.reshape(bs * self.num_heads, ch, -1)
).reshape(bs, width, length)

def forward(self, x, cond=None, cond_mask=None):

x = einops.array_api.rearrange(x, "b c h w -> b h w c")
b, h, w, c = x.shape

qkv = self.qkv(self.norm(x))
qkv = einops.array_api.rearrange(qkv, "b h w (three c) -> three b (h w) c", three=3)
q, k, v = qkv[0], qkv[1], qkv[2]

attn_output = self.attention(q, k, v)

if self.cond_dim is not None and cond is not None:
kv_cond = self.kv_cond(self.norm_cond(cond))
kv_cond = einops.array_api.rearrange(kv_cond, "b s (two c) -> two b s c", two=2)
k_cond, v_cond = kv_cond[0], kv_cond[1]
attn_cond = self.attention(q, k_cond, v_cond, cond_mask)
attn_output += attn_cond

attn_output = einops.array_api.rearrange(attn_output, "b (h w) c -> b h w c", h=h, w=w)
h = self.proj_out(attn_output)

x = einops.array_api.rearrange(x, "b h w c -> b c h w")
h = einops.array_api.rearrange(h, "b h w c -> b c h w")
x = x + h

if self.ffn is not None:
x = einops.array_api.rearrange(x, "b c h w -> b h w c")
x = self.ffn(x) + x
x = einops.array_api.rearrange(x, "b h w c -> b c h w")

return x
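For anyone reviewing locally, here is a minimal usage sketch of the new layer (shapes mirror the test below; it assumes mlx is installed and the ml_mdm package is importable):

import mlx.core as mx
from ml_mdm.models.unet_mlx import SelfAttention_MLX

# channels-first input, same layout the PyTorch layer expects: (batch, channels, height, width)
attn = SelfAttention_MLX(channels=64, num_heads=8, use_attention_ffn=True)
x = mx.random.normal((2, 64, 8, 8))
y = attn.forward(x)   # forward() is called explicitly, as in the tests
print(y.shape)        # expected: (2, 64, 8, 8)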

94 changes: 92 additions & 2 deletions ml-mdm-matryoshka/tests/test_unet_mlx.py
@@ -5,11 +5,12 @@
import numpy as np
import torch

from ml_mdm.models.unet import MLP, SelfAttention1D, TemporalAttentionBlock, SelfAttention
from ml_mdm.models.unet_mlx import (
MLP_MLX,
SelfAttention1D_MLX,
TemporalAttentionBlock_MLX,
SelfAttention_MLX
)


@@ -56,12 +57,101 @@ def test_pytorch_mlp():

# Validate numerical equivalence using numpy
assert np.allclose(
output.detach().numpy(), np.array(mx.stop_gradient(mlx_output)), atol=1e-5
), "Outputs of PyTorch MLP and MLX MLP should match"

print("Test passed for both PyTorch and MLX MLP!")



def test_pytorch_mlx_self_attention():
"""
Test feature parity between the PyTorch and MLX implementations of SelfAttention,
covering both plain self-attention and the conditional (cross-attention) path.
"""
# Define test parameters
channels = 64
batch_size = 2
spatial_size = 8
cond_dim = 32
num_heads = 8

# ===== 1. Test WITH CONDITIONAL INPUT =====
# Create models WITH conditional support
pytorch_attn_with_cond = SelfAttention(
channels=channels,
num_heads=num_heads,
cond_dim=cond_dim, # Enable conditioning
use_attention_ffn=True,
)
mlx_attn_with_cond = SelfAttention_MLX(
channels=channels,
num_heads=num_heads,
cond_dim=cond_dim,
use_attention_ffn=True,
)

# Create conditional inputs
cond_seq_len = 4
pytorch_cond = torch.randn(batch_size, cond_seq_len, cond_dim)
pytorch_cond_mask = torch.ones(batch_size, cond_seq_len)
mlx_cond = mx.array(pytorch_cond.numpy())
mlx_cond_mask = mx.array(pytorch_cond_mask.numpy())

# Run conditional tests
pytorch_input = torch.randn(batch_size, channels, spatial_size, spatial_size)
mlx_input = mx.array(pytorch_input.numpy())

# PyTorch conditional forward
pytorch_output_with_cond = pytorch_attn_with_cond(
pytorch_input, cond=pytorch_cond, cond_mask=pytorch_cond_mask
)
# MLX conditional forward
mlx_output_with_cond = mlx_attn_with_cond.forward(
mlx_input, cond=mlx_cond, cond_mask=mlx_cond_mask
)

# ===== 2. Test WITHOUT CONDITIONAL INPUT =====
# Create NEW models WITHOUT conditional support
pytorch_attn_no_cond = SelfAttention(
channels=channels,
num_heads=num_heads,
cond_dim=None,
use_attention_ffn=True,
)
mlx_attn_no_cond = SelfAttention_MLX(
channels=channels,
num_heads=num_heads,
cond_dim=None,
use_attention_ffn=True,
)

# Run non-conditional tests
pytorch_output_no_cond = pytorch_attn_no_cond(pytorch_input)
mlx_output_no_cond = mlx_attn_no_cond.forward(mlx_input)

# ===== Assertions =====
# Check conditional outputs
assert pytorch_output_with_cond.shape == pytorch_input.shape
assert mlx_output_with_cond.shape == mlx_input.shape
assert np.allclose(
pytorch_output_with_cond.detach().numpy(),
np.array(mlx_output_with_cond),
atol=1e-5, rtol=1e-5
), "Outputs of PyTorch and MLX attention should match"

# Check non-conditional outputs
assert pytorch_output_no_cond.shape == pytorch_input.shape
assert mlx_output_no_cond.shape == mlx_input.shape
assert np.allclose(
pytorch_output_no_cond.detach().numpy(),
np.array(mlx_output_no_cond),
atol=1e-5, rtol=1e-5
), "Outputs without conditioning should match"

print("Self-attention test passed for both PyTorch and MLX!")

def test_self_attention_1d():
# Define parameters
channels = 8
@@ -156,4 +246,4 @@ def test_pytorch_mlx_temporal_attention_block():
atol=1e-1, # Significantly increased tolerance
), "Outputs of PyTorch and MLX TemporalAttentionBlock should match"

print("Test passed for both PyTorch and MLX TemporalAttentionBlock!")
print("Test passed for both PyTorch and MLX TemporalAttentionBlock!")