
SelfAttention_MLX #65


Open
wants to merge 5 commits into main

Conversation

bdeanhardt
Contributor

No description provided.

@bdeanhardt
Contributor Author

bug :(
========================================================================================= test session starts =========================================================================================
platform darwin -- Python 3.11.5, pytest-7.4.0, pluggy-1.0.0
rootdir: /Users/belladeanhardt/ml-mdm
configfile: pyproject.toml
plugins: cov-5.0.0, anyio-3.5.0, mock-3.14.0
collected 2 items

tests/test_mlx_unet.py .F [100%]

============================================================================================== FAILURES ===============================================================================================
___________________________________________________________________________________ test_pytorch_mlx_self_attention ___________________________________________________________________________________

def test_pytorch_mlx_self_attention():
    """
    Test for feature parity between PyTorch and MLX implementations of SelfAttention.
    We'll test both the basic self-attention and conditional attention scenarios.
    """
    # Define test parameters
    channels = 64  # Number of channels
    batch_size = 2  # Batch size
    spatial_size = 8  # Spatial dimensions (H=W=8)
    cond_dim = 32  # Conditional dimension
    num_heads = 8  # Number of attention heads

    # Create model instances
    pytorch_attn = SelfAttention(
        channels=channels,
        num_heads=num_heads,
        cond_dim=cond_dim,
        use_attention_ffn=True,
    )
    mlx_attn = SelfAttention_MLX(  # Assuming this is your MLX class name
        channels=channels,
        num_heads=num_heads,
        cond_dim=cond_dim,
        use_attention_ffn=True,
    )

    # Set models to evaluation mode
    pytorch_attn.eval()
    mlx_attn.eval()

    # Create test inputs
    # Main input: [B, C, H, W]
    pytorch_input = torch.randn(batch_size, channels, spatial_size, spatial_size)
    # Conditional input: [B, seq_len, cond_dim]
    cond_seq_len = 4
    pytorch_cond = torch.randn(batch_size, cond_seq_len, cond_dim)
    # Conditional mask: [B, seq_len]
    pytorch_cond_mask = torch.ones(batch_size, cond_seq_len)

    # Test PyTorch version
    pytorch_output = pytorch_attn(
        pytorch_input, cond=pytorch_cond, cond_mask=pytorch_cond_mask
    )

    # Convert inputs to MLX format
    mlx_input = mx.array(pytorch_input.numpy())
    mlx_cond = mx.array(pytorch_cond.numpy())
    mlx_cond_mask = mx.array(pytorch_cond_mask.numpy())

    # Test MLX version
>   mlx_output = mlx_attn.forward(mlx_input, cond=mlx_cond, cond_mask=mlx_cond_mask)

tests/test_mlx_unet.py:111:


ml_mdm/models/unet_mlx.py:126: in forward
h = self.proj_out(h)


self = Conv2d(64, 64, kernel_size=(1,), stride=(1, 1), padding=(0, 0), dilation=1, groups=1, bias=True)
x = array([[[[0.275203, 0.260433, 0.149972, ..., 0.170037, 0.294273, 0.0233037],
[0.238134, 0.305211, 0.304386, ....0327954, 0.0905073],
[0.0745733, 0.266315, 0.0139441, ..., -0.124577, -0.181556, 0.0526186]]]], dtype=float32)

    def __call__(self, x):
>       y = mx.conv2d(
            x, self.weight, self.stride, self.padding, self.dilation, self.groups
        )

E ValueError: [conv] Expect the input channels in the input and weight array to match but got shapes - input: (2,64,8,8) and weight: (64,1,1,64)

../anaconda3/lib/python3.11/site-packages/mlx/nn/layers/convolution.py:157: ValueError
---------------------------------------------------------------------------------------- Captured stdout call -----------------------------------------------------------------------------------------
input tensor for group norm: (2, 8, 8, 64)
output tensor for groupnorm / input tensor for self attention: (2, 3, 64, 64)
output tensor for self attention: (2, 3, 64, 64)

---------- coverage: platform darwin, python 3.11.5-final-0 ----------
Name Stmts Miss Cover

ml_mdm/__init__.py 1 0 100%
ml_mdm/clis/__init__.py 0 0 100%
ml_mdm/clis/download_tar_from_index.py 198 198 0%
ml_mdm/clis/generate_batch.py 139 139 0%
ml_mdm/clis/generate_sample.py 225 225 0%
ml_mdm/clis/run_torchmetrics.py 120 120 0%
ml_mdm/clis/scrape_cc12m.py 62 62 0%
ml_mdm/clis/train_parallel.py 157 157 0%
ml_mdm/config.py 127 85 33%
ml_mdm/diffusion.py 192 135 30%
ml_mdm/distributed.py 39 28 28%
ml_mdm/generate_html.py 18 18 0%
ml_mdm/helpers.py 8 8 0%
ml_mdm/language_models/__init__.py 0 0 100%
ml_mdm/language_models/factory.py 68 68 0%
ml_mdm/language_models/self_attention.py 5 5 0%
ml_mdm/language_models/tokenizer.py 118 105 11%
ml_mdm/language_models/transformer.py 5 5 0%
ml_mdm/lr_scaler.py 18 18 0%
ml_mdm/models/__init__.py 1 0 100%
ml_mdm/models/model_ema.py 41 31 24%
ml_mdm/models/nested_unet.py 115 71 38%
ml_mdm/models/unet.py 507 368 27%
ml_mdm/models/unet_mlx.py 65 7 89%
ml_mdm/reader.py 125 89 29%
ml_mdm/s3_helpers.py 56 43 23%
ml_mdm/samplers.py 354 276 22%
ml_mdm/trainer.py 52 52 0%
ml_mdm/utils/__init__.py 0 0 100%
ml_mdm/utils/fix_old_checkpoints.py 10 7 30%
ml_mdm/utils/simple_logger.py 85 85 0%

TOTAL 2911 2405 17%

======================================================================================= short test summary info =======================================================================================
FAILED tests/test_mlx_unet.py::test_pytorch_mlx_self_attention - ValueError: [conv] Expect the input channels in the input and weight array to match but got shapes - input: (2,64,8,8) and weight: (64,1,1,64)
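
For context, the mismatch reported above (input `(2,64,8,8)` vs. weight `(64,1,1,64)`) is consistent with MLX's channels-last convention: `mlx.nn.Conv2d` expects `[B, H, W, C]` input and stores weights as `[O, kH, kW, I]`, while the test builds a PyTorch-style `[B, C, H, W]` tensor and converts it with `mx.array` unchanged. A minimal sketch of the layout conversion at the test boundary (illustrative only, not necessarily the fix adopted in this PR):

```python
import numpy as np
import mlx.core as mx
import mlx.nn as nn

# PyTorch-style input layout: [B, C, H, W]
x_nchw = np.random.randn(2, 64, 8, 8).astype(np.float32)

# mlx.nn.Conv2d expects channels-last input: [B, H, W, C]
x_nhwc = mx.array(x_nchw.transpose(0, 2, 3, 1))

proj_out = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=1)
y = proj_out(x_nhwc)  # ok: output shape (2, 8, 8, 64)
# proj_out(mx.array(x_nchw)) raises the ValueError above, because the last
# axis (size 8) no longer matches the layer's 64 input channels.
```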

@gabrielfnayres
Contributor

Gonna check it, but the error is probably because of the mlx.array initialization.
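
For reference, converting with `mx.array(tensor.numpy())` keeps the original axis order, so the PyTorch `[B, C, H, W]` layout reaches the MLX conv layers unchanged; a quick illustrative check (not taken from this PR):

```python
import torch
import mlx.core as mx

t = torch.randn(2, 64, 8, 8)  # PyTorch layout: [B, C, H, W]
a = mx.array(t.numpy())       # conversion keeps the axis order
print(a.shape)                # (2, 64, 8, 8) -- still channels-first
```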

@gabrielfnayres
Contributor

@bdeanhardt I think I've got this

[Screenshot attached: 2025-02-21 at 14:59:47]

@bdeanhardt
Contributor Author

@gabrielfnayres awesome!

@gabrielfnayres
Contributor

gabrielfnayres commented Feb 21, 2025

> @gabrielfnayres awesome!

Should I open another PR? @luke-carlson
