
[BUG] Hanging when executing optimization loop more than once #1711

Closed
@aristidb

Description

Describe the bug
When I use mx.compile on my loss function and run the optimization loop more than once from a Jupyter notebook cell, it usually hangs on the second run. The cell normally finishes in 1-3 seconds; when it hangs, it takes at least 10 minutes.

I'm following this "MLP for names" tutorial and transcribing everything to MLX: https://www.youtube.com/watch?v=TCH_1BHY58I. So it's an MLP implemented from scratch.

To Reproduce

I haven't managed to reduce this to a minimal reproducible example yet (I'm also an absolute beginner at machine learning), but below are the code that triggers the hang and lldb backtraces showing where it hangs (see "Additional context").

If it helps, I can also share the entire Jupyter notebook (about 100 lines of code, I think) together with the data needed to reproduce the issue.

Model parameters:

import mlx.core as mx

mx.random.seed(7674764)
C = mx.random.normal((27, 10))     # character embedding table
W1 = mx.random.normal((30, 200))   # hidden layer weights
b1 = mx.random.normal((200,))
W2 = mx.random.normal((200, 27))   # output layer weights
b2 = mx.random.normal((27,))
parameters = [C, W1, b1, W2, b2]
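For scale, here is a quick count of the parameters defined above in plain Python. The shapes are copied from the code; reading 27 as the character vocabulary and 30 as 3 context characters times 10 embedding dimensions is an inference from the shapes, not something stated in the report.

```python
import math

# Shapes copied from the parameter definitions above.
shapes = {
    "C": (27, 10),    # 27 characters x 10-dim embeddings (inferred meaning)
    "W1": (30, 200),  # 30 = 3 context characters x 10 dims (inferred)
    "b1": (200,),
    "W2": (200, 27),
    "b2": (27,),
}

counts = {name: math.prod(shape) for name, shape in shapes.items()}
total = sum(counts.values())
print(counts)
print(total, "parameters,", total * 4, "bytes as float32")
```

The model is tiny (under 50 KB of float32 weights), so the multi-minute hang is unlikely to come from the model's size alone.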

Loss function:

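import mlx.nn as nn

def loss_fn(X: mx.array, Y: mx.array, C: mx.array, W1: mx.array, b1: mx.array, W2: mx.array, b2: mx.array):
    # Embed each context character, concatenate the embeddings, apply the tanh hidden layer
    h = mx.tanh(C[X].flatten(1) @ W1 + b1)
    logits = h @ W2 + b2
    return nn.losses.cross_entropy(logits, Y, reduction='mean')

# Gradients w.r.t. positional arguments 2..6 (C, W1, b1, W2, b2), then compile
loss_value_and_grad = mx.compile(mx.value_and_grad(loss_fn, range(2, 7)))

# Xtr, Ytr: training inputs/targets built in an earlier cell (not shown)
print(loss_value_and_grad(Xtr, Ytr, C, W1, b1, W2, b2))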
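To make the shapes in loss_fn explicit, here is a NumPy sketch that mirrors its forward pass. It is only an illustration, not the MLX code path: it assumes X holds 3 character indices per row (inferred from W1 having 30 = 3 x 10 rows), and the random minibatch below is made up purely to exercise the function.

```python
import numpy as np

def loss_fn_np(X, Y, C, W1, b1, W2, b2):
    """NumPy mirror of loss_fn: embedding -> tanh hidden layer -> mean cross-entropy."""
    emb = C[X]                                        # (N, 3, 10): one embedding per context char
    h = np.tanh(emb.reshape(len(X), -1) @ W1 + b1)    # (N, 30) @ (30, 200) -> (N, 200)
    logits = h @ W2 + b2                              # (N, 27): one score per character
    # Numerically stable log-softmax over the 27 classes
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Mean negative log-likelihood of each row's target class
    return -log_probs[np.arange(len(Y)), Y].mean()

rng = np.random.default_rng(0)
shapes = [(27, 10), (30, 200), (200,), (200, 27), (27,)]  # same shapes as C, W1, b1, W2, b2
params = [rng.standard_normal(s) for s in shapes]
X = rng.integers(0, 27, size=(320, 3))  # hypothetical minibatch of 3-character contexts
Y = rng.integers(0, 27, size=320)       # hypothetical next-character targets
print(loss_fn_np(X, Y, *params))
```

A useful sanity check: with all-zero output weights every class gets the same logit, so the loss equals log(27), about 3.30.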

Optimization loop that usually runs in 1-3 seconds the first time and hangs the second time:

lossi = []
for i in range(50000):
    ix = mx.random.randint(0, Xtr.shape[0], (320,))  # minibatch indices
    loss, grad = loss_value_and_grad(Xtr[ix], Ytr[ix], C, W1, b1, W2, b2)
    # lr = lrs[i]  (lr itself is defined in a cell not shown here)
    for pi, p in enumerate(parameters):
        p -= lr * grad[pi]  # in-place SGD update of C, W1, b1, W2, b2
    lossi.append(loss)
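The update p -= lr * grad[pi] changes the model only because p is the very object stored in parameters (and bound to C, W1, ...), and -= modifies that object in place when the array type supports in-place subtraction. A small sketch of that aliasing, using NumPy as a stand-in for MLX; it illustrates the Python semantics only and does not reproduce the hang.

```python
import numpy as np

lr = 0.1
grads = [np.full(3, 0.5)]

# In-place update: p and W are the same object, so W changes.
W = np.ones(3)
for p, g in zip([W], grads):
    p -= lr * g      # ndarray.__isub__ mutates the shared array
print(W)             # each entry is now 1 - 0.1 * 0.5 = 0.95

# Rebinding instead: p = p - ... creates a new array bound only to p.
V = np.ones(3)
for p, g in zip([V], grads):
    p = p - lr * g   # V itself is untouched
print(V)             # still all ones
```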

Expected behavior
I expect each run of the optimization loop to finish in under 30 seconds.
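To put numbers on this expectation, a hypothetical helper time_runs (not part of MLX or the original notebook) can call a function that wraps the optimization loop several times and flag runs that exceed the 30-second budget. A run that truly hangs never returns, so this only reports runs that eventually finish.

```python
import time

def time_runs(fn, n_runs=3, limit_s=30.0):
    """Call fn() n_runs times; return (per-run durations, indices of runs over limit_s)."""
    durations = []
    for _ in range(n_runs):
        start = time.perf_counter()
        fn()
        durations.append(time.perf_counter() - start)
    slow = [i for i, d in enumerate(durations) if d > limit_s]
    return durations, slow

# Stand-in workload; in the notebook fn would wrap the training loop (hypothetical).
durations, slow = time_runs(lambda: sum(range(100_000)), n_runs=3)
print([f"{d:.3f}s" for d in durations], "slow runs:", slow)
```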

Desktop

  • OS Version: both macOS 15.2 (24C101) and macOS 15.1.1 (24B2091)
  • MLX Version: 0.21.1
  • Python version: 3.13 [Homebrew]
  • CPU: Apple M4 Max

Additional context
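Main-thread backtraces (lldb bt) from two separate hangs. In both, the thread is blocked in mlx::core::metal::MetalAllocator::malloc while Metal creates a new GPU buffer (IOGPUResourceCreate): the first during a Python multiplication (PyNumber_Multiply, plausibly lr * grad[pi]), the second while creating a scalar array inside mx.random.randint.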

(lldb) bt
* thread #1, queue = 'com.apple.main-thread', stop reason = signal SIGSTOP
  * frame #0: 0x0000000181aa1e34 libsystem_kernel.dylib`mach_msg2_trap + 8
    frame #1: 0x0000000181ab45d0 libsystem_kernel.dylib`mach_msg2_internal + 80
    frame #2: 0x000000018556c360 IOKit`io_connect_method + 524
    frame #3: 0x000000018556c128 IOKit`IOConnectCallMethod + 236
    frame #4: 0x00000001a25081e8 IOGPU`IOGPUResourceCreate + 248
    frame #5: 0x00000001a2501878 IOGPU`-[IOGPUMetalResource initWithDevice:remoteStorageResource:options:args:argsSize:] + 484
    frame #6: 0x00000001a24ef3bc IOGPU`-[IOGPUMetalBuffer initWithPrimaryBuffer:heapIndex:bufferIndex:bufferOffset:length:args:argsSize:gpuTag:] + 276
    frame #7: 0x000000011dd87998 AGXMetalG16X`-[AGXBuffer(Internal) initWithDevice:length:alignment:pointerTag:options:isSuballocDisabled:resourceInArgs:pinnedGPULocation:] + 520
    frame #8: 0x000000011dd879d8 AGXMetalG16X`-[AGXBuffer(Internal) initWithDevice:length:alignment:options:isSuballocDisabled:resourceInArgs:pinnedGPULocation:] + 44
    frame #9: 0x000000011dd875dc AGXMetalG16X`-[AGXBuffer initWithDevice:length:alignment:options:isSuballocDisabled:pinnedGPULocation:] + 32
    frame #10: 0x00000001194d72b4 libmlx.dylib`mlx::core::metal::MetalAllocator::malloc(unsigned long, bool) + 580
    frame #11: 0x0000000118c46514 libmlx.dylib`mlx::core::allocator::malloc(unsigned long) + 48
    frame #12: 0x000000010737a750 core.cpython-313-darwin.so`___lldb_unnamed_symbol1182 + 48
    frame #13: 0x000000010739f968 core.cpython-313-darwin.so`___lldb_unnamed_symbol1533 + 84
    frame #14: 0x000000010737b94c core.cpython-313-darwin.so`___lldb_unnamed_symbol1185 + 404
    frame #15: 0x00000001073a374c core.cpython-313-darwin.so`___lldb_unnamed_symbol1568 + 260
    frame #16: 0x000000010732d66c core.cpython-313-darwin.so`___lldb_unnamed_symbol647 + 2192
    frame #17: 0x0000000102cf30cc Python`vectorcall_maybe + 100
    frame #18: 0x0000000102cef44c Python`slot_nb_multiply + 168
    frame #19: 0x0000000102c3bd08 Python`binary_op1 + 228
    frame #20: 0x0000000102c3be88 Python`PyNumber_Multiply + 36
    frame #21: 0x0000000102d80b58 Python`_PyEval_EvalFrameDefault + 1688
    frame #22: 0x0000000102d80268 Python`PyEval_EvalCode + 200
    frame #23: 0x0000000102d7b8b4 Python`builtin_exec + 444
    frame #24: 0x0000000102d83780 Python`_PyEval_EvalFrameDefault + 12992
    frame #25: 0x0000000102c773d0 Python`gen_send_ex2 + 192
    frame #26: 0x0000000102c77210 Python`gen_send_ex + 36
    frame #27: 0x0000000102d85010 Python`_PyEval_EvalFrameDefault + 19280
    frame #28: 0x0000000102c5dc38 Python`method_vectorcall + 328
    frame #29: 0x0000000102c5b010 Python`_PyVectorcall_Call + 152
    frame #30: 0x0000000102d83a2c Python`_PyEval_EvalFrameDefault + 13676
    frame #31: 0x0000000102c773d0 Python`gen_send_ex2 + 192
    frame #32: 0x0000000103750e28 _asyncio.cpython-313-darwin.so`task_step_impl + 444
    frame #33: 0x0000000103750bf8 _asyncio.cpython-313-darwin.so`task_step + 64
    frame #34: 0x0000000103751e78 _asyncio.cpython-313-darwin.so`task_wakeup + 232
    frame #35: 0x0000000102cb7ecc Python`cfunction_vectorcall_O + 104
    frame #36: 0x0000000102dacd38 Python`_PyObject_VectorcallTstate + 88
    frame #37: 0x0000000102dacbf0 Python`context_run + 104
    frame #38: 0x0000000102cb7ca8 Python`cfunction_vectorcall_FASTCALL_KEYWORDS + 88
    frame #39: 0x0000000102d83a2c Python`_PyEval_EvalFrameDefault + 13676
    frame #40: 0x0000000102d80268 Python`PyEval_EvalCode + 200
    frame #41: 0x0000000102d7b8b4 Python`builtin_exec + 444
    frame #42: 0x0000000102cb7ca8 Python`cfunction_vectorcall_FASTCALL_KEYWORDS + 88
    frame #43: 0x0000000102c5b17c Python`PyObject_Vectorcall + 92
    frame #44: 0x0000000102d82a94 Python`_PyEval_EvalFrameDefault + 9684
    frame #45: 0x0000000102e14fcc Python`pymain_run_module + 228
    frame #46: 0x0000000102e145c8 Python`Py_RunMain + 732
    frame #47: 0x0000000102e14cb8 Python`pymain_main + 304
    frame #48: 0x0000000102e14d58 Python`Py_BytesMain + 40
    frame #49: 0x0000000181760274 dyld`start + 2840


(lldb) bt
* thread #1, queue = 'com.apple.main-thread', stop reason = signal SIGSTOP
  * frame #0: 0x0000000181aa1e34 libsystem_kernel.dylib`mach_msg2_trap + 8
    frame #1: 0x0000000181ab45d0 libsystem_kernel.dylib`mach_msg2_internal + 80
    frame #2: 0x000000018556c360 IOKit`io_connect_method + 524
    frame #3: 0x000000018556c128 IOKit`IOConnectCallMethod + 236
    frame #4: 0x00000001a25081e8 IOGPU`IOGPUResourceCreate + 248
    frame #5: 0x00000001a2501878 IOGPU`-[IOGPUMetalResource initWithDevice:remoteStorageResource:options:args:argsSize:] + 484
    frame #6: 0x00000001a24ef3bc IOGPU`-[IOGPUMetalBuffer initWithPrimaryBuffer:heapIndex:bufferIndex:bufferOffset:length:args:argsSize:gpuTag:] + 276
    frame #7: 0x000000012a1f7998 AGXMetalG16X`-[AGXBuffer(Internal) initWithDevice:length:alignment:pointerTag:options:isSuballocDisabled:resourceInArgs:pinnedGPULocation:] + 520
    frame #8: 0x000000012a1f79d8 AGXMetalG16X`-[AGXBuffer(Internal) initWithDevice:length:alignment:options:isSuballocDisabled:resourceInArgs:pinnedGPULocation:] + 44
    frame #9: 0x000000012a1f75dc AGXMetalG16X`-[AGXBuffer initWithDevice:length:alignment:options:isSuballocDisabled:pinnedGPULocation:] + 32
    frame #10: 0x000000010a21f2b4 libmlx.dylib`mlx::core::metal::MetalAllocator::malloc(unsigned long, bool) + 580
    frame #11: 0x000000010998e514 libmlx.dylib`mlx::core::allocator::malloc(unsigned long) + 48
    frame #12: 0x00000001099c0660 libmlx.dylib`void mlx::core::array::init<float*>(float*) + 64
    frame #13: 0x00000001099c05b0 libmlx.dylib`mlx::core::array::array<float>(float, mlx::core::Dtype) + 84
    frame #14: 0x0000000109a4b280 libmlx.dylib`mlx::core::random::uniform(mlx::core::array const&, mlx::core::array const&, std::__1::vector<int, std::__1::allocator<int>> const&, mlx::core::Dtype, std::__1::optional<mlx::core::array> const&, std::__1::variant<std::__1::monostate, mlx::core::Stream, mlx::core::Device>) + 384
    frame #15: 0x0000000109a4c3c4 libmlx.dylib`mlx::core::random::randint(mlx::core::array const&, mlx::core::array const&, std::__1::vector<int, std::__1::allocator<int>> const&, mlx::core::Dtype, std::__1::optional<mlx::core::array> const&, std::__1::variant<std::__1::monostate, mlx::core::Stream, mlx::core::Device>) + 120
    frame #16: 0x0000000108c97704 core.cpython-313-darwin.so`___lldb_unnamed_symbol1462 + 760
    frame #17: 0x0000000108c2d66c core.cpython-313-darwin.so`___lldb_unnamed_symbol647 + 2192
    frame #18: 0x00000001048e317c Python`PyObject_Vectorcall + 92
    frame #19: 0x0000000104a0d19c Python`_PyEval_EvalFrameDefault + 19676
    frame #20: 0x0000000104a08268 Python`PyEval_EvalCode + 200
    frame #21: 0x0000000104a038b4 Python`builtin_exec + 444
    frame #22: 0x0000000104a0b780 Python`_PyEval_EvalFrameDefault + 12992
    frame #23: 0x00000001048ff3d0 Python`gen_send_ex2 + 192
    frame #24: 0x00000001048ff210 Python`gen_send_ex + 36
    frame #25: 0x0000000104a0d010 Python`_PyEval_EvalFrameDefault + 19280
    frame #26: 0x00000001048e5c38 Python`method_vectorcall + 328
    frame #27: 0x00000001048e3010 Python`_PyVectorcall_Call + 152
    frame #28: 0x0000000104a0ba2c Python`_PyEval_EvalFrameDefault + 13676
    frame #29: 0x00000001048ff3d0 Python`gen_send_ex2 + 192
    frame #30: 0x00000001053d8e28 _asyncio.cpython-313-darwin.so`task_step_impl + 444
    frame #31: 0x00000001053d8bf8 _asyncio.cpython-313-darwin.so`task_step + 64
    frame #32: 0x00000001053d9e78 _asyncio.cpython-313-darwin.so`task_wakeup + 232
    frame #33: 0x000000010493fecc Python`cfunction_vectorcall_O + 104
    frame #34: 0x0000000104a34d38 Python`_PyObject_VectorcallTstate + 88
    frame #35: 0x0000000104a34bf0 Python`context_run + 104
    frame #36: 0x000000010493fca8 Python`cfunction_vectorcall_FASTCALL_KEYWORDS + 88
    frame #37: 0x0000000104a0ba2c Python`_PyEval_EvalFrameDefault + 13676
    frame #38: 0x0000000104a08268 Python`PyEval_EvalCode + 200
    frame #39: 0x0000000104a038b4 Python`builtin_exec + 444
    frame #40: 0x000000010493fca8 Python`cfunction_vectorcall_FASTCALL_KEYWORDS + 88
    frame #41: 0x00000001048e317c Python`PyObject_Vectorcall + 92
    frame #42: 0x0000000104a0aa94 Python`_PyEval_EvalFrameDefault + 9684
    frame #43: 0x0000000104a9cfcc Python`pymain_run_module + 228
    frame #44: 0x0000000104a9c5c8 Python`Py_RunMain + 732
    frame #45: 0x0000000104a9ccb8 Python`pymain_main + 304
    frame #46: 0x0000000104a9cd58 Python`Py_BytesMain + 40
    frame #47: 0x0000000181760274 dyld`start + 2840
