
[BUG] Crash in mlx.core.fast.scaled_dot_product_attention with matrix attention kernel #1643

Closed
@lucasnewman

Description


Describe the bug

I'm seeing a crash in mlx.core.fast.scaled_dot_product_attention after the change in #1610, which is included in the latest release. Reverting that specific change fixes the crash, but the change is large, so I can't pinpoint the specific issue.
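As a hedged stopgap (not something tested in this report), attention can be computed from composed ops instead of the fused kernel. The sketch below is a plain softmax(scale · Q Kᵀ) · V reference. `reference_attention` and its `xp` parameter are hypothetical names. It assumes no attention mask and equal head counts for queries and keys/values; mx.fast.scaled_dot_product_attention also accepts fewer key/value heads (grouped-query attention), which this sketch does not handle.

```python
import numpy as np


def reference_attention(q, k, v, scale, xp=np):
    """Unfused softmax(scale * q @ k^T) @ v over the last two axes.

    Assumed shapes: q (..., L_q, D), k (..., L_k, D), v (..., L_k, D_v), with
    identical leading (batch, head) dims. No mask, no grouped-query broadcasting.
    `xp` is the array module providing matmul/swapaxes/max/exp/sum (NumPy by default).
    """
    scores = xp.matmul(q * scale, xp.swapaxes(k, -1, -2))
    # Subtract the per-row max before exponentiating so large scores don't overflow.
    scores = scores - xp.max(scores, axis=-1, keepdims=True)
    weights = xp.exp(scores)
    weights = weights / xp.sum(weights, axis=-1, keepdims=True)
    return xp.matmul(weights, v)
```

With MLX arrays, passing `xp=mx` (where `mx` is `mlx.core`) should build the same computation from regular MLX ops rather than the fused kernel; whether that avoids the crash seen here has not been verified. With NumPy it also serves as a reference for checking the fused kernel's output on the captured inputs, keeping in mind that float16 inputs will differ slightly in precision.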

I tried to write a reduced test case, but a single invocation doesn't seem to trigger the issue. I can consistently reproduce it in the context of a larger codebase when running a training loop.

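import mlx.core as mx
import numpy as np

# Load the query/key/value tensors and the scale captured from the training run.
attn_inputs = np.load("attn_inputs.npz")
query = mx.array(attn_inputs["query"])
key = mx.array(attn_inputs["key"])
value = mx.array(attn_inputs["value"])
scale_factor = attn_inputs["scale_factor"].item()  # 0-d array -> Python float

# On its own, this single call does not crash; the crash only appears inside the training loop.
output = mx.fast.scaled_dot_product_attention(q=query, k=key, v=value, scale=scale_factor)
mx.eval(output)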

The captured inputs are here in case that's helpful.
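Since a single call does not crash, one hedged way to approximate the training-loop conditions is to run the same call many times in a row. This is a sketch: `stress` and `make_sdpa_step` are hypothetical helper names, and repeating the forward pass alone may not reproduce the crash if it depends on other work the training loop schedules.

```python
import time


def stress(step, n_iters=1000, report_every=100):
    """Call `step` repeatedly, the way a training loop would; return the number of completed calls."""
    start = time.perf_counter()
    completed = 0
    for completed in range(1, n_iters + 1):
        step()
        if report_every and completed % report_every == 0:
            elapsed = time.perf_counter() - start
            print(f"{completed} iterations in {elapsed:.1f}s", flush=True)
    return completed


def make_sdpa_step(path="attn_inputs.npz"):
    """Build a step that re-runs the fused attention call from the reduced test case."""
    # Imported lazily so `stress` itself can be used without MLX installed.
    import mlx.core as mx
    import numpy as np

    inputs = np.load(path)
    q = mx.array(inputs["query"])
    k = mx.array(inputs["key"])
    v = mx.array(inputs["value"])
    scale = inputs["scale_factor"].item()

    def step():
        out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale)
        mx.eval(out)  # MLX is lazy; force evaluation on every iteration

    return step
```

For example, `stress(make_sdpa_step(), n_iters=10_000)` would run the captured call ten thousand times and print progress every hundred iterations, so the last printed count brackets where a crash occurred.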

Here's a backtrace from lldb when this happens:

* thread #1, queue = 'com.apple.main-thread', stop reason = EXC_BAD_ACCESS (code=1, address=0x0)
    frame #0: 0x0000000196c298e0 libsystem_pthread.dylib`pthread_mutex_lock + 12
libsystem_pthread.dylib`pthread_mutex_lock:
->  0x196c298e0 <+12>: ldr    x8, [x0]
    0x196c298e4 <+16>: mov    w9, #0x545a ; =21594 
    0x196c298e8 <+20>: movk   w9, #0x4d55, lsl #16
    0x196c298ec <+24>: cmp    x8, x9
(lldb) thread backtrace all
* thread #1, queue = 'com.apple.main-thread', stop reason = EXC_BAD_ACCESS (code=1, address=0x0)
  * frame #0: 0x0000000196c298e0 libsystem_pthread.dylib`pthread_mutex_lock + 12
    frame #1: 0x0000000196b66758 libc++.1.dylib`std::__1::mutex::lock() + 16
    frame #2: 0x0000000101779eac libmlx.dylib`mlx::core::eval_impl(std::__1::vector<mlx::core::array, std::__1::allocator<mlx::core::array>>, bool) + 4856
    frame #3: 0x000000010177a664 libmlx.dylib`mlx::core::eval(std::__1::vector<mlx::core::array, std::__1::allocator<mlx::core::array>>) + 132
    frame #4: 0x00000001009df2e0 core.cpython-39-darwin.so`___lldb_unnamed_symbol1699 + 160
    frame #5: 0x000000010095f664 core.cpython-39-darwin.so`___lldb_unnamed_symbol649 + 2176
    frame #6: 0x000000010016f614 python`call_function + 572
    frame #7: 0x000000010016bd44 python`_PyEval_EvalFrameDefault + 29268
    frame #8: 0x00000001001644a8 python`_PyEval_EvalCode + 2968
    frame #9: 0x000000010005fe64 python`_PyFunction_Vectorcall + 240
    frame #10: 0x0000000100062cf0 python`method_vectorcall + 164
    frame #11: 0x000000010016f614 python`call_function + 572
    frame #12: 0x000000010016be40 python`_PyEval_EvalFrameDefault + 29520
    frame #13: 0x00000001001644a8 python`_PyEval_EvalCode + 2968
    frame #14: 0x00000001001c7834 python`pyrun_file + 376
    frame #15: 0x00000001001c6d48 python`PyRun_SimpleFileExFlags + 816
    frame #16: 0x00000001001e9e84 python`Py_RunMain + 2916
    frame #17: 0x00000001001eb018 python`pymain_main + 1272
    frame #18: 0x0000000100005ddc python`main + 56
    frame #19: 0x00000001968ac274 dyld`start + 2840
  thread #2
    frame #0: 0x0000000196bf15cc libsystem_kernel.dylib`__psynch_cvwait + 8
    frame #1: 0x0000000196c2f894 libsystem_pthread.dylib`_pthread_cond_wait + 1204
    frame #2: 0x0000000196b65578 libc++.1.dylib`std::__1::condition_variable::wait(std::__1::unique_lock<std::__1::mutex>&) + 28
    frame #3: 0x0000000101f020a8 libmlx.dylib`ThreadPool::ThreadPool(unsigned long)::'lambda'()::operator()() const + 144
    frame #4: 0x0000000101f01f8c libmlx.dylib`void* std::__1::__thread_proxy[abi:v160006]<std::__1::tuple<std::__1::unique_ptr<std::__1::__thread_struct, std::__1::default_delete<std::__1::__thread_struct>>, ThreadPool::ThreadPool(unsigned long)::'lambda'()>>(void*) + 52
    frame #5: 0x0000000196c2f2e4 libsystem_pthread.dylib`_pthread_start + 136
  thread #3
    frame #0: 0x0000000196bf15cc libsystem_kernel.dylib`__psynch_cvwait + 8
    frame #1: 0x0000000196c2f894 libsystem_pthread.dylib`_pthread_cond_wait + 1204
    frame #2: 0x0000000196b65578 libc++.1.dylib`std::__1::condition_variable::wait(std::__1::unique_lock<std::__1::mutex>&) + 28
    frame #3: 0x0000000101f020a8 libmlx.dylib`ThreadPool::ThreadPool(unsigned long)::'lambda'()::operator()() const + 144
    frame #4: 0x0000000101f01f8c libmlx.dylib`void* std::__1::__thread_proxy[abi:v160006]<std::__1::tuple<std::__1::unique_ptr<std::__1::__thread_struct, std::__1::default_delete<std::__1::__thread_struct>>, ThreadPool::ThreadPool(unsigned long)::'lambda'()>>(void*) + 52
    frame #5: 0x0000000196c2f2e4 libsystem_pthread.dylib`_pthread_start + 136
  thread #4
    frame #0: 0x0000000196bf15cc libsystem_kernel.dylib`__psynch_cvwait + 8
    frame #1: 0x0000000196c2f894 libsystem_pthread.dylib`_pthread_cond_wait + 1204
    frame #2: 0x0000000196b65578 libc++.1.dylib`std::__1::condition_variable::wait(std::__1::unique_lock<std::__1::mutex>&) + 28
    frame #3: 0x0000000101f020a8 libmlx.dylib`ThreadPool::ThreadPool(unsigned long)::'lambda'()::operator()() const + 144
    frame #4: 0x0000000101f01f8c libmlx.dylib`void* std::__1::__thread_proxy[abi:v160006]<std::__1::tuple<std::__1::unique_ptr<std::__1::__thread_struct, std::__1::default_delete<std::__1::__thread_struct>>, ThreadPool::ThreadPool(unsigned long)::'lambda'()>>(void*) + 52
    frame #5: 0x0000000196c2f2e4 libsystem_pthread.dylib`_pthread_start + 136
  thread #5
    frame #0: 0x0000000196bf15cc libsystem_kernel.dylib`__psynch_cvwait + 8
    frame #1: 0x0000000196c2f894 libsystem_pthread.dylib`_pthread_cond_wait + 1204
    frame #2: 0x0000000196b65578 libc++.1.dylib`std::__1::condition_variable::wait(std::__1::unique_lock<std::__1::mutex>&) + 28
    frame #3: 0x0000000101f020a8 libmlx.dylib`ThreadPool::ThreadPool(unsigned long)::'lambda'()::operator()() const + 144
    frame #4: 0x0000000101f01f8c libmlx.dylib`void* std::__1::__thread_proxy[abi:v160006]<std::__1::tuple<std::__1::unique_ptr<std::__1::__thread_struct, std::__1::default_delete<std::__1::__thread_struct>>, ThreadPool::ThreadPool(unsigned long)::'lambda'()>>(void*) + 52
    frame #5: 0x0000000196c2f2e4 libsystem_pthread.dylib`_pthread_start + 136
  thread #6
    frame #0: 0x0000000196befa84 libsystem_kernel.dylib`__workq_kernreturn + 8
  thread #7
    frame #0: 0x0000000196c2a0e8 libsystem_pthread.dylib`start_wqthread
  thread #8, name = 'AwsEventLoop 1'
    frame #0: 0x0000000196bf3efc libsystem_kernel.dylib`kevent + 8
    frame #1: 0x000000012a23dc44 _c.cpython-39-darwin.so`___lldb_unnamed_symbol29409 + 448
    frame #2: 0x000000012a276a2c _c.cpython-39-darwin.so`___lldb_unnamed_symbol29716 + 340
    frame #3: 0x0000000196c2f2e4 libsystem_pthread.dylib`_pthread_start + 136
  thread #9, name = 'AwsEventLoop 2'
    frame #0: 0x0000000196bf3efc libsystem_kernel.dylib`kevent + 8
    frame #1: 0x000000012a23dc44 _c.cpython-39-darwin.so`___lldb_unnamed_symbol29409 + 448
    frame #2: 0x000000012a276a2c _c.cpython-39-darwin.so`___lldb_unnamed_symbol29716 + 340
    frame #3: 0x0000000196c2f2e4 libsystem_pthread.dylib`_pthread_start + 136
  thread #10, name = 'AwsEventLoop 3'
    frame #0: 0x0000000196bf3efc libsystem_kernel.dylib`kevent + 8
    frame #1: 0x000000012a23dc44 _c.cpython-39-darwin.so`___lldb_unnamed_symbol29409 + 448
    frame #2: 0x000000012a276a2c _c.cpython-39-darwin.so`___lldb_unnamed_symbol29716 + 340
    frame #3: 0x0000000196c2f2e4 libsystem_pthread.dylib`_pthread_start + 136
  thread #11, name = 'AwsEventLoop 4'
    frame #0: 0x0000000196bf3efc libsystem_kernel.dylib`kevent + 8
    frame #1: 0x000000012a23dc44 _c.cpython-39-darwin.so`___lldb_unnamed_symbol29409 + 448
    frame #2: 0x000000012a276a2c _c.cpython-39-darwin.so`___lldb_unnamed_symbol29716 + 340
    frame #3: 0x0000000196c2f2e4 libsystem_pthread.dylib`_pthread_start + 136
  thread #12, name = 'AwsEventLoop 5'
    frame #0: 0x0000000196bf3efc libsystem_kernel.dylib`kevent + 8
    frame #1: 0x000000012a23dc44 _c.cpython-39-darwin.so`___lldb_unnamed_symbol29409 + 448
    frame #2: 0x000000012a276a2c _c.cpython-39-darwin.so`___lldb_unnamed_symbol29716 + 340
    frame #3: 0x0000000196c2f2e4 libsystem_pthread.dylib`_pthread_start + 136
  thread #13, name = 'AwsEventLoop 6'
    frame #0: 0x0000000196bf3efc libsystem_kernel.dylib`kevent + 8
    frame #1: 0x000000012a23dc44 _c.cpython-39-darwin.so`___lldb_unnamed_symbol29409 + 448
    frame #2: 0x000000012a276a2c _c.cpython-39-darwin.so`___lldb_unnamed_symbol29716 + 340
    frame #3: 0x0000000196c2f2e4 libsystem_pthread.dylib`_pthread_start + 136
  thread #14, name = 'AwsEventLoop 7'
    frame #0: 0x0000000196bf3efc libsystem_kernel.dylib`kevent + 8
    frame #1: 0x000000012a23dc44 _c.cpython-39-darwin.so`___lldb_unnamed_symbol29409 + 448
    frame #2: 0x000000012a276a2c _c.cpython-39-darwin.so`___lldb_unnamed_symbol29716 + 340
    frame #3: 0x0000000196c2f2e4 libsystem_pthread.dylib`_pthread_start + 136
  thread #15, name = 'AwsEventLoop 8'
    frame #0: 0x0000000196bf3efc libsystem_kernel.dylib`kevent + 8
    frame #1: 0x000000012a23dc44 _c.cpython-39-darwin.so`___lldb_unnamed_symbol29409 + 448
    frame #2: 0x000000012a276a2c _c.cpython-39-darwin.so`___lldb_unnamed_symbol29716 + 340
    frame #3: 0x0000000196c2f2e4 libsystem_pthread.dylib`_pthread_start + 136
  thread #16
    frame #0: 0x0000000196bf15cc libsystem_kernel.dylib`__psynch_cvwait + 8
    frame #1: 0x0000000196c2f894 libsystem_pthread.dylib`_pthread_cond_wait + 1204
    frame #2: 0x00000001001d9598 python`PyThread_acquire_lock_timed + 340
    frame #3: 0x00000001002478e4 python`acquire_timed + 236
    frame #4: 0x0000000100247aa8 python`lock_PyThread_acquire_lock + 72
    frame #5: 0x000000010006d284 python`method_vectorcall_VARARGS_KEYWORDS + 292
    frame #6: 0x000000010016f614 python`call_function + 572
    frame #7: 0x000000010016bd28 python`_PyEval_EvalFrameDefault + 29240
    frame #8: 0x00000001001644a8 python`_PyEval_EvalCode + 2968
    frame #9: 0x000000010005fe64 python`_PyFunction_Vectorcall + 240
    frame #10: 0x000000010016f614 python`call_function + 572
    frame #11: 0x000000010016bd28 python`_PyEval_EvalFrameDefault + 29240
    frame #12: 0x00000001001644a8 python`_PyEval_EvalCode + 2968
    frame #13: 0x000000010005fe64 python`_PyFunction_Vectorcall + 240
    frame #14: 0x000000010016f614 python`call_function + 572
    frame #15: 0x000000010016bd28 python`_PyEval_EvalFrameDefault + 29240
    frame #16: 0x000000010005fee4 python`function_code_fastcall + 116
    frame #17: 0x000000010016f614 python`call_function + 572
    frame #18: 0x000000010016bd28 python`_PyEval_EvalFrameDefault + 29240
    frame #19: 0x000000010005fee4 python`function_code_fastcall + 116
    frame #20: 0x000000010016f614 python`call_function + 572
    frame #21: 0x000000010016bd28 python`_PyEval_EvalFrameDefault + 29240
    frame #22: 0x000000010005fee4 python`function_code_fastcall + 116
    frame #23: 0x0000000100062d9c python`method_vectorcall + 336
    frame #24: 0x0000000100246d50 python`t_bootstrap + 180
    frame #25: 0x00000001001d91c8 python`pythread_wrapper + 48
    frame #26: 0x0000000196c2f2e4 libsystem_pthread.dylib`_pthread_start + 136
  thread #17
    frame #0: 0x0000000196befb08 libsystem_kernel.dylib`kevent_id + 8
    frame #1: 0x0000000196a9f5a4 libdispatch.dylib`_dispatch_kq_poll + 228
    frame #2: 0x0000000196a9eae4 libdispatch.dylib`_dispatch_event_loop_poke + 340
    frame #3: 0x00000001a1a00454 Metal`-[_MTLCommandQueue commitCommandBuffer:wake:] + 256
    frame #4: 0x00000001b76216bc IOGPU`-[IOGPUMetalCommandBuffer commit] + 224
    frame #5: 0x0000000100e023b0 AGXMetalG15X_M1`-[AGXG15XFamilyCommandBuffer commit] + 872
    frame #6: 0x0000000101f71164 libmlx.dylib`mlx::core::metal::Device::commit_command_buffer(int) + 240
    frame #7: 0x0000000101fa1a68 libmlx.dylib`std::__1::__function::__func<mlx::core::metal::make_task(mlx::core::array, bool)::$_0, std::__1::allocator<mlx::core::metal::make_task(mlx::core::array, bool)::$_0>, void ()>::operator()() + 1736
    frame #8: 0x0000000101777344 libmlx.dylib`mlx::core::scheduler::StreamThread::thread_fn() + 488
    frame #9: 0x0000000101777500 libmlx.dylib`void* std::__1::__thread_proxy[abi:v160006]<std::__1::tuple<std::__1::unique_ptr<std::__1::__thread_struct, std::__1::default_delete<std::__1::__thread_struct>>, void (mlx::core::scheduler::StreamThread::*)(), mlx::core::scheduler::StreamThread*>>(void*) + 72
    frame #10: 0x0000000196c2f2e4 libsystem_pthread.dylib`_pthread_start + 136
  thread #18
    frame #0: 0x0000000196bf15cc libsystem_kernel.dylib`__psynch_cvwait + 8
    frame #1: 0x0000000196c2f894 libsystem_pthread.dylib`_pthread_cond_wait + 1204
    frame #2: 0x0000000196b65578 libc++.1.dylib`std::__1::condition_variable::wait(std::__1::unique_lock<std::__1::mutex>&) + 28
    frame #3: 0x00000001017771e8 libmlx.dylib`mlx::core::scheduler::StreamThread::thread_fn() + 140
    frame #4: 0x0000000101777500 libmlx.dylib`void* std::__1::__thread_proxy[abi:v160006]<std::__1::tuple<std::__1::unique_ptr<std::__1::__thread_struct, std::__1::default_delete<std::__1::__thread_struct>>, void (mlx::core::scheduler::StreamThread::*)(), mlx::core::scheduler::StreamThread*>>(void*) + 72
    frame #5: 0x0000000196c2f2e4 libsystem_pthread.dylib`_pthread_start + 136
  thread #19
    frame #0: 0x0000000196bf15cc libsystem_kernel.dylib`__psynch_cvwait + 8
    frame #1: 0x0000000196c2f894 libsystem_pthread.dylib`_pthread_cond_wait + 1204
    frame #2: 0x0000000196b65578 libc++.1.dylib`std::__1::condition_variable::wait(std::__1::unique_lock<std::__1::mutex>&) + 28
    frame #3: 0x00000001017771e8 libmlx.dylib`mlx::core::scheduler::StreamThread::thread_fn() + 140
    frame #4: 0x0000000101777500 libmlx.dylib`void* std::__1::__thread_proxy[abi:v160006]<std::__1::tuple<std::__1::unique_ptr<std::__1::__thread_struct, std::__1::default_delete<std::__1::__thread_struct>>, void (mlx::core::scheduler::StreamThread::*)(), mlx::core::scheduler::StreamThread*>>(void*) + 72
    frame #5: 0x0000000196c2f2e4 libsystem_pthread.dylib`_pthread_start + 136
  thread #20
    frame #0: 0x0000000196bf15cc libsystem_kernel.dylib`__psynch_cvwait + 8
    frame #1: 0x0000000196c2f894 libsystem_pthread.dylib`_pthread_cond_wait + 1204
    frame #2: 0x0000000196b65578 libc++.1.dylib`std::__1::condition_variable::wait(std::__1::unique_lock<std::__1::mutex>&) + 28
    frame #3: 0x0000000101f020a8 libmlx.dylib`ThreadPool::ThreadPool(unsigned long)::'lambda'()::operator()() const + 144
    frame #4: 0x0000000101f01f8c libmlx.dylib`void* std::__1::__thread_proxy[abi:v160006]<std::__1::tuple<std::__1::unique_ptr<std::__1::__thread_struct, std::__1::default_delete<std::__1::__thread_struct>>, ThreadPool::ThreadPool(unsigned long)::'lambda'()>>(void*) + 52
    frame #5: 0x0000000196c2f2e4 libsystem_pthread.dylib`_pthread_start + 136
  thread #21
    frame #0: 0x0000000196bf15cc libsystem_kernel.dylib`__psynch_cvwait + 8
    frame #1: 0x0000000196c2f894 libsystem_pthread.dylib`_pthread_cond_wait + 1204
    frame #2: 0x0000000196b65578 libc++.1.dylib`std::__1::condition_variable::wait(std::__1::unique_lock<std::__1::mutex>&) + 28
    frame #3: 0x0000000101f020a8 libmlx.dylib`ThreadPool::ThreadPool(unsigned long)::'lambda'()::operator()() const + 144
    frame #4: 0x0000000101f01f8c libmlx.dylib`void* std::__1::__thread_proxy[abi:v160006]<std::__1::tuple<std::__1::unique_ptr<std::__1::__thread_struct, std::__1::default_delete<std::__1::__thread_struct>>, ThreadPool::ThreadPool(unsigned long)::'lambda'()>>(void*) + 52
    frame #5: 0x0000000196c2f2e4 libsystem_pthread.dylib`_pthread_start + 136
  thread #22
    frame #0: 0x0000000196bf15cc libsystem_kernel.dylib`__psynch_cvwait + 8
    frame #1: 0x0000000196c2f894 libsystem_pthread.dylib`_pthread_cond_wait + 1204
    frame #2: 0x0000000196b65578 libc++.1.dylib`std::__1::condition_variable::wait(std::__1::unique_lock<std::__1::mutex>&) + 28
    frame #3: 0x0000000101f020a8 libmlx.dylib`ThreadPool::ThreadPool(unsigned long)::'lambda'()::operator()() const + 144
    frame #4: 0x0000000101f01f8c libmlx.dylib`void* std::__1::__thread_proxy[abi:v160006]<std::__1::tuple<std::__1::unique_ptr<std::__1::__thread_struct, std::__1::default_delete<std::__1::__thread_struct>>, ThreadPool::ThreadPool(unsigned long)::'lambda'()>>(void*) + 52
    frame #5: 0x0000000196c2f2e4 libsystem_pthread.dylib`_pthread_start + 136
  thread #23
    frame #0: 0x0000000196bf15cc libsystem_kernel.dylib`__psynch_cvwait + 8
    frame #1: 0x0000000196c2f894 libsystem_pthread.dylib`_pthread_cond_wait + 1204
    frame #2: 0x0000000196b65578 libc++.1.dylib`std::__1::condition_variable::wait(std::__1::unique_lock<std::__1::mutex>&) + 28
    frame #3: 0x0000000101f020a8 libmlx.dylib`ThreadPool::ThreadPool(unsigned long)::'lambda'()::operator()() const + 144
    frame #4: 0x0000000101f01f8c libmlx.dylib`void* std::__1::__thread_proxy[abi:v160006]<std::__1::tuple<std::__1::unique_ptr<std::__1::__thread_struct, std::__1::default_delete<std::__1::__thread_struct>>, ThreadPool::ThreadPool(unsigned long)::'lambda'()>>(void*) + 52
    frame #5: 0x0000000196c2f2e4 libsystem_pthread.dylib`_pthread_start + 136
  thread #24
    frame #0: 0x0000000196befa84 libsystem_kernel.dylib`__workq_kernreturn + 8
  thread #25
    frame #0: 0x0000000196bf15cc libsystem_kernel.dylib`__psynch_cvwait + 8
    frame #1: 0x0000000196c2f894 libsystem_pthread.dylib`_pthread_cond_wait + 1204
    frame #2: 0x0000000196b65578 libc++.1.dylib`std::__1::condition_variable::wait(std::__1::unique_lock<std::__1::mutex>&) + 28
    frame #3: 0x0000000129d2ce54 _c.cpython-39-darwin.so`___lldb_unnamed_symbol22431 + 172
    frame #4: 0x0000000196c2f2e4 libsystem_pthread.dylib`_pthread_start + 136
  thread #26
    frame #0: 0x0000000196bf15cc libsystem_kernel.dylib`__psynch_cvwait + 8
    frame #1: 0x0000000196c2f894 libsystem_pthread.dylib`_pthread_cond_wait + 1204
    frame #2: 0x0000000196b65578 libc++.1.dylib`std::__1::condition_variable::wait(std::__1::unique_lock<std::__1::mutex>&) + 28
    frame #3: 0x0000000129d2ce54 _c.cpython-39-darwin.so`___lldb_unnamed_symbol22431 + 172
    frame #4: 0x0000000196c2f2e4 libsystem_pthread.dylib`_pthread_start + 136
  thread #27
    frame #0: 0x0000000196bf15cc libsystem_kernel.dylib`__psynch_cvwait + 8
    frame #1: 0x0000000196c2f894 libsystem_pthread.dylib`_pthread_cond_wait + 1204
    frame #2: 0x0000000196b65578 libc++.1.dylib`std::__1::condition_variable::wait(std::__1::unique_lock<std::__1::mutex>&) + 28
    frame #3: 0x0000000129d2ce54 _c.cpython-39-darwin.so`___lldb_unnamed_symbol22431 + 172
    frame #4: 0x0000000196c2f2e4 libsystem_pthread.dylib`_pthread_start + 136

Desktop:

  • OS Version: macOS 15.1.1
  • MLX version: 0.21.0
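The environment details above can be gathered with the standard library. This is a sketch: `environment_report` is a hypothetical helper, and it assumes MLX is installed under the distribution name `mlx`. It also records the Python version, which the backtrace shows as CPython 3.9.

```python
import platform
from importlib import metadata


def environment_report(packages=("mlx",)):
    """Return the lines the bug template asks for, plus the Python version."""
    mac_version = platform.mac_ver()[0]  # empty string when not running on macOS
    os_line = f"OS Version: macOS {mac_version}" if mac_version else f"OS: {platform.platform()}"
    lines = [os_line, f"Python: {platform.python_version()}"]
    for name in packages:
        try:
            lines.append(f"{name} version: {metadata.version(name)}")
        except metadata.PackageNotFoundError:
            lines.append(f"{name}: not installed")
    return lines
```

Printing `"\n".join(environment_report())` gives lines that can be pasted directly under the Desktop heading.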


Metadata


Assignees

No one assigned

    Labels

    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests
