Closed
Describe the bug
I'm seeing a crash in mlx.core.fast.scaled_dot_product_attention after the change in #1610, which is included in the latest version. Reverting that specific change fixes the crash, but the change is large, so I can't pinpoint the specific cause.
I tried to write a reduced test case (below), but a single invocation doesn't seem to trigger the crash. I can consistently reproduce it in the context of a larger codebase when running a training loop.
import mlx.core as mx
import numpy as np

# Load the captured attention inputs.
attn_inputs = np.load("attn_inputs.npz")
query = mx.array(attn_inputs["query"])
key = mx.array(attn_inputs["key"])
value = mx.array(attn_inputs["value"])
scale_factor = attn_inputs["scale_factor"]

# Run the fused attention kernel and force evaluation.
output = mx.fast.scaled_dot_product_attention(q=query, k=key, v=value, scale=scale_factor.item())
mx.eval(output)
The captured inputs are here in case that's helpful.
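As a sanity check on the inputs themselves, here is a minimal NumPy sketch of the same unmasked attention computation. It is only a reference for comparison, not how MLX implements the kernel. The function name reference_sdpa and the (batch, heads, seq_len, head_dim) layout are assumptions for illustration, and the random arrays stand in for the captured file.

```python
import numpy as np


def reference_sdpa(q, k, v, scale):
    """Unmasked scaled dot-product attention in plain NumPy.

    Assumes q, k, v are shaped (batch, heads, seq_len, head_dim) with the
    same number of heads. That layout is an assumption for this sketch.
    """
    q = np.asarray(q, dtype=np.float32)
    k = np.asarray(k, dtype=np.float32)
    v = np.asarray(v, dtype=np.float32)
    # Attention scores: (batch, heads, q_len, k_len).
    scores = (q * scale) @ np.swapaxes(k, -1, -2)
    # Subtract the row max before exponentiating for numerical stability.
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v


# Random stand-ins for the captured arrays (hypothetical shapes).
rng = np.random.default_rng(0)
q = rng.standard_normal((1, 2, 4, 8)).astype(np.float32)
k = rng.standard_normal((1, 2, 4, 8)).astype(np.float32)
v = rng.standard_normal((1, 2, 4, 8)).astype(np.float32)

out = reference_sdpa(q, k, v, scale=8 ** -0.5)
print(out.shape, bool(np.isfinite(out).all()))
```

If this produces finite values on the captured arrays, the inputs are probably not the problem, which would be consistent with the crash only showing up inside the training loop.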
Here's a backtrace from lldb when this happens:
* thread #1, queue = 'com.apple.main-thread', stop reason = EXC_BAD_ACCESS (code=1, address=0x0)
frame #0: 0x0000000196c298e0 libsystem_pthread.dylib`pthread_mutex_lock + 12
libsystem_pthread.dylib`pthread_mutex_lock:
-> 0x196c298e0 <+12>: ldr x8, [x0]
0x196c298e4 <+16>: mov w9, #0x545a ; =21594
0x196c298e8 <+20>: movk w9, #0x4d55, lsl #16
0x196c298ec <+24>: cmp x8, x9
(lldb) thread backtrace all
* thread #1, queue = 'com.apple.main-thread', stop reason = EXC_BAD_ACCESS (code=1, address=0x0)
* frame #0: 0x0000000196c298e0 libsystem_pthread.dylib`pthread_mutex_lock + 12
frame #1: 0x0000000196b66758 libc++.1.dylib`std::__1::mutex::lock() + 16
frame #2: 0x0000000101779eac libmlx.dylib`mlx::core::eval_impl(std::__1::vector<mlx::core::array, std::__1::allocator<mlx::core::array>>, bool) + 4856
frame #3: 0x000000010177a664 libmlx.dylib`mlx::core::eval(std::__1::vector<mlx::core::array, std::__1::allocator<mlx::core::array>>) + 132
frame #4: 0x00000001009df2e0 core.cpython-39-darwin.so`___lldb_unnamed_symbol1699 + 160
frame #5: 0x000000010095f664 core.cpython-39-darwin.so`___lldb_unnamed_symbol649 + 2176
frame #6: 0x000000010016f614 python`call_function + 572
frame #7: 0x000000010016bd44 python`_PyEval_EvalFrameDefault + 29268
frame #8: 0x00000001001644a8 python`_PyEval_EvalCode + 2968
frame #9: 0x000000010005fe64 python`_PyFunction_Vectorcall + 240
frame #10: 0x0000000100062cf0 python`method_vectorcall + 164
frame #11: 0x000000010016f614 python`call_function + 572
frame #12: 0x000000010016be40 python`_PyEval_EvalFrameDefault + 29520
frame #13: 0x00000001001644a8 python`_PyEval_EvalCode + 2968
frame #14: 0x00000001001c7834 python`pyrun_file + 376
frame #15: 0x00000001001c6d48 python`PyRun_SimpleFileExFlags + 816
frame #16: 0x00000001001e9e84 python`Py_RunMain + 2916
frame #17: 0x00000001001eb018 python`pymain_main + 1272
frame #18: 0x0000000100005ddc python`main + 56
frame #19: 0x00000001968ac274 dyld`start + 2840
thread #2
frame #0: 0x0000000196bf15cc libsystem_kernel.dylib`__psynch_cvwait + 8
frame #1: 0x0000000196c2f894 libsystem_pthread.dylib`_pthread_cond_wait + 1204
frame #2: 0x0000000196b65578 libc++.1.dylib`std::__1::condition_variable::wait(std::__1::unique_lock<std::__1::mutex>&) + 28
frame #3: 0x0000000101f020a8 libmlx.dylib`ThreadPool::ThreadPool(unsigned long)::'lambda'()::operator()() const + 144
frame #4: 0x0000000101f01f8c libmlx.dylib`void* std::__1::__thread_proxy[abi:v160006]<std::__1::tuple<std::__1::unique_ptr<std::__1::__thread_struct, std::__1::default_delete<std::__1::__thread_struct>>, ThreadPool::ThreadPool(unsigned long)::'lambda'()>>(void*) + 52
frame #5: 0x0000000196c2f2e4 libsystem_pthread.dylib`_pthread_start + 136
thread #3
frame #0: 0x0000000196bf15cc libsystem_kernel.dylib`__psynch_cvwait + 8
frame #1: 0x0000000196c2f894 libsystem_pthread.dylib`_pthread_cond_wait + 1204
frame #2: 0x0000000196b65578 libc++.1.dylib`std::__1::condition_variable::wait(std::__1::unique_lock<std::__1::mutex>&) + 28
frame #3: 0x0000000101f020a8 libmlx.dylib`ThreadPool::ThreadPool(unsigned long)::'lambda'()::operator()() const + 144
frame #4: 0x0000000101f01f8c libmlx.dylib`void* std::__1::__thread_proxy[abi:v160006]<std::__1::tuple<std::__1::unique_ptr<std::__1::__thread_struct, std::__1::default_delete<std::__1::__thread_struct>>, ThreadPool::ThreadPool(unsigned long)::'lambda'()>>(void*) + 52
frame #5: 0x0000000196c2f2e4 libsystem_pthread.dylib`_pthread_start + 136
thread #4
frame #0: 0x0000000196bf15cc libsystem_kernel.dylib`__psynch_cvwait + 8
frame #1: 0x0000000196c2f894 libsystem_pthread.dylib`_pthread_cond_wait + 1204
frame #2: 0x0000000196b65578 libc++.1.dylib`std::__1::condition_variable::wait(std::__1::unique_lock<std::__1::mutex>&) + 28
frame #3: 0x0000000101f020a8 libmlx.dylib`ThreadPool::ThreadPool(unsigned long)::'lambda'()::operator()() const + 144
frame #4: 0x0000000101f01f8c libmlx.dylib`void* std::__1::__thread_proxy[abi:v160006]<std::__1::tuple<std::__1::unique_ptr<std::__1::__thread_struct, std::__1::default_delete<std::__1::__thread_struct>>, ThreadPool::ThreadPool(unsigned long)::'lambda'()>>(void*) + 52
frame #5: 0x0000000196c2f2e4 libsystem_pthread.dylib`_pthread_start + 136
thread #5
frame #0: 0x0000000196bf15cc libsystem_kernel.dylib`__psynch_cvwait + 8
frame #1: 0x0000000196c2f894 libsystem_pthread.dylib`_pthread_cond_wait + 1204
frame #2: 0x0000000196b65578 libc++.1.dylib`std::__1::condition_variable::wait(std::__1::unique_lock<std::__1::mutex>&) + 28
frame #3: 0x0000000101f020a8 libmlx.dylib`ThreadPool::ThreadPool(unsigned long)::'lambda'()::operator()() const + 144
frame #4: 0x0000000101f01f8c libmlx.dylib`void* std::__1::__thread_proxy[abi:v160006]<std::__1::tuple<std::__1::unique_ptr<std::__1::__thread_struct, std::__1::default_delete<std::__1::__thread_struct>>, ThreadPool::ThreadPool(unsigned long)::'lambda'()>>(void*) + 52
frame #5: 0x0000000196c2f2e4 libsystem_pthread.dylib`_pthread_start + 136
thread #6
frame #0: 0x0000000196befa84 libsystem_kernel.dylib`__workq_kernreturn + 8
thread #7
frame #0: 0x0000000196c2a0e8 libsystem_pthread.dylib`start_wqthread
thread #8, name = 'AwsEventLoop 1'
frame #0: 0x0000000196bf3efc libsystem_kernel.dylib`kevent + 8
frame #1: 0x000000012a23dc44 _c.cpython-39-darwin.so`___lldb_unnamed_symbol29409 + 448
frame #2: 0x000000012a276a2c _c.cpython-39-darwin.so`___lldb_unnamed_symbol29716 + 340
frame #3: 0x0000000196c2f2e4 libsystem_pthread.dylib`_pthread_start + 136
thread #9, name = 'AwsEventLoop 2'
frame #0: 0x0000000196bf3efc libsystem_kernel.dylib`kevent + 8
frame #1: 0x000000012a23dc44 _c.cpython-39-darwin.so`___lldb_unnamed_symbol29409 + 448
frame #2: 0x000000012a276a2c _c.cpython-39-darwin.so`___lldb_unnamed_symbol29716 + 340
frame #3: 0x0000000196c2f2e4 libsystem_pthread.dylib`_pthread_start + 136
thread #10, name = 'AwsEventLoop 3'
frame #0: 0x0000000196bf3efc libsystem_kernel.dylib`kevent + 8
frame #1: 0x000000012a23dc44 _c.cpython-39-darwin.so`___lldb_unnamed_symbol29409 + 448
frame #2: 0x000000012a276a2c _c.cpython-39-darwin.so`___lldb_unnamed_symbol29716 + 340
frame #3: 0x0000000196c2f2e4 libsystem_pthread.dylib`_pthread_start + 136
thread #11, name = 'AwsEventLoop 4'
frame #0: 0x0000000196bf3efc libsystem_kernel.dylib`kevent + 8
frame #1: 0x000000012a23dc44 _c.cpython-39-darwin.so`___lldb_unnamed_symbol29409 + 448
frame #2: 0x000000012a276a2c _c.cpython-39-darwin.so`___lldb_unnamed_symbol29716 + 340
frame #3: 0x0000000196c2f2e4 libsystem_pthread.dylib`_pthread_start + 136
thread #12, name = 'AwsEventLoop 5'
frame #0: 0x0000000196bf3efc libsystem_kernel.dylib`kevent + 8
frame #1: 0x000000012a23dc44 _c.cpython-39-darwin.so`___lldb_unnamed_symbol29409 + 448
frame #2: 0x000000012a276a2c _c.cpython-39-darwin.so`___lldb_unnamed_symbol29716 + 340
frame #3: 0x0000000196c2f2e4 libsystem_pthread.dylib`_pthread_start + 136
thread #13, name = 'AwsEventLoop 6'
frame #0: 0x0000000196bf3efc libsystem_kernel.dylib`kevent + 8
frame #1: 0x000000012a23dc44 _c.cpython-39-darwin.so`___lldb_unnamed_symbol29409 + 448
frame #2: 0x000000012a276a2c _c.cpython-39-darwin.so`___lldb_unnamed_symbol29716 + 340
frame #3: 0x0000000196c2f2e4 libsystem_pthread.dylib`_pthread_start + 136
thread #14, name = 'AwsEventLoop 7'
frame #0: 0x0000000196bf3efc libsystem_kernel.dylib`kevent + 8
frame #1: 0x000000012a23dc44 _c.cpython-39-darwin.so`___lldb_unnamed_symbol29409 + 448
frame #2: 0x000000012a276a2c _c.cpython-39-darwin.so`___lldb_unnamed_symbol29716 + 340
frame #3: 0x0000000196c2f2e4 libsystem_pthread.dylib`_pthread_start + 136
thread #15, name = 'AwsEventLoop 8'
frame #0: 0x0000000196bf3efc libsystem_kernel.dylib`kevent + 8
frame #1: 0x000000012a23dc44 _c.cpython-39-darwin.so`___lldb_unnamed_symbol29409 + 448
frame #2: 0x000000012a276a2c _c.cpython-39-darwin.so`___lldb_unnamed_symbol29716 + 340
frame #3: 0x0000000196c2f2e4 libsystem_pthread.dylib`_pthread_start + 136
thread #16
frame #0: 0x0000000196bf15cc libsystem_kernel.dylib`__psynch_cvwait + 8
frame #1: 0x0000000196c2f894 libsystem_pthread.dylib`_pthread_cond_wait + 1204
frame #2: 0x00000001001d9598 python`PyThread_acquire_lock_timed + 340
frame #3: 0x00000001002478e4 python`acquire_timed + 236
frame #4: 0x0000000100247aa8 python`lock_PyThread_acquire_lock + 72
frame #5: 0x000000010006d284 python`method_vectorcall_VARARGS_KEYWORDS + 292
frame #6: 0x000000010016f614 python`call_function + 572
frame #7: 0x000000010016bd28 python`_PyEval_EvalFrameDefault + 29240
frame #8: 0x00000001001644a8 python`_PyEval_EvalCode + 2968
frame #9: 0x000000010005fe64 python`_PyFunction_Vectorcall + 240
frame #10: 0x000000010016f614 python`call_function + 572
frame #11: 0x000000010016bd28 python`_PyEval_EvalFrameDefault + 29240
frame #12: 0x00000001001644a8 python`_PyEval_EvalCode + 2968
frame #13: 0x000000010005fe64 python`_PyFunction_Vectorcall + 240
frame #14: 0x000000010016f614 python`call_function + 572
frame #15: 0x000000010016bd28 python`_PyEval_EvalFrameDefault + 29240
frame #16: 0x000000010005fee4 python`function_code_fastcall + 116
frame #17: 0x000000010016f614 python`call_function + 572
frame #18: 0x000000010016bd28 python`_PyEval_EvalFrameDefault + 29240
frame #19: 0x000000010005fee4 python`function_code_fastcall + 116
frame #20: 0x000000010016f614 python`call_function + 572
frame #21: 0x000000010016bd28 python`_PyEval_EvalFrameDefault + 29240
frame #22: 0x000000010005fee4 python`function_code_fastcall + 116
frame #23: 0x0000000100062d9c python`method_vectorcall + 336
frame #24: 0x0000000100246d50 python`t_bootstrap + 180
frame #25: 0x00000001001d91c8 python`pythread_wrapper + 48
frame #26: 0x0000000196c2f2e4 libsystem_pthread.dylib`_pthread_start + 136
thread #17
frame #0: 0x0000000196befb08 libsystem_kernel.dylib`kevent_id + 8
frame #1: 0x0000000196a9f5a4 libdispatch.dylib`_dispatch_kq_poll + 228
frame #2: 0x0000000196a9eae4 libdispatch.dylib`_dispatch_event_loop_poke + 340
frame #3: 0x00000001a1a00454 Metal`-[_MTLCommandQueue commitCommandBuffer:wake:] + 256
frame #4: 0x00000001b76216bc IOGPU`-[IOGPUMetalCommandBuffer commit] + 224
frame #5: 0x0000000100e023b0 AGXMetalG15X_M1`-[AGXG15XFamilyCommandBuffer commit] + 872
frame #6: 0x0000000101f71164 libmlx.dylib`mlx::core::metal::Device::commit_command_buffer(int) + 240
frame #7: 0x0000000101fa1a68 libmlx.dylib`std::__1::__function::__func<mlx::core::metal::make_task(mlx::core::array, bool)::$_0, std::__1::allocator<mlx::core::metal::make_task(mlx::core::array, bool)::$_0>, void ()>::operator()() + 1736
frame #8: 0x0000000101777344 libmlx.dylib`mlx::core::scheduler::StreamThread::thread_fn() + 488
frame #9: 0x0000000101777500 libmlx.dylib`void* std::__1::__thread_proxy[abi:v160006]<std::__1::tuple<std::__1::unique_ptr<std::__1::__thread_struct, std::__1::default_delete<std::__1::__thread_struct>>, void (mlx::core::scheduler::StreamThread::*)(), mlx::core::scheduler::StreamThread*>>(void*) + 72
frame #10: 0x0000000196c2f2e4 libsystem_pthread.dylib`_pthread_start + 136
thread #18
frame #0: 0x0000000196bf15cc libsystem_kernel.dylib`__psynch_cvwait + 8
frame #1: 0x0000000196c2f894 libsystem_pthread.dylib`_pthread_cond_wait + 1204
frame #2: 0x0000000196b65578 libc++.1.dylib`std::__1::condition_variable::wait(std::__1::unique_lock<std::__1::mutex>&) + 28
frame #3: 0x00000001017771e8 libmlx.dylib`mlx::core::scheduler::StreamThread::thread_fn() + 140
frame #4: 0x0000000101777500 libmlx.dylib`void* std::__1::__thread_proxy[abi:v160006]<std::__1::tuple<std::__1::unique_ptr<std::__1::__thread_struct, std::__1::default_delete<std::__1::__thread_struct>>, void (mlx::core::scheduler::StreamThread::*)(), mlx::core::scheduler::StreamThread*>>(void*) + 72
frame #5: 0x0000000196c2f2e4 libsystem_pthread.dylib`_pthread_start + 136
thread #19
frame #0: 0x0000000196bf15cc libsystem_kernel.dylib`__psynch_cvwait + 8
frame #1: 0x0000000196c2f894 libsystem_pthread.dylib`_pthread_cond_wait + 1204
frame #2: 0x0000000196b65578 libc++.1.dylib`std::__1::condition_variable::wait(std::__1::unique_lock<std::__1::mutex>&) + 28
frame #3: 0x00000001017771e8 libmlx.dylib`mlx::core::scheduler::StreamThread::thread_fn() + 140
frame #4: 0x0000000101777500 libmlx.dylib`void* std::__1::__thread_proxy[abi:v160006]<std::__1::tuple<std::__1::unique_ptr<std::__1::__thread_struct, std::__1::default_delete<std::__1::__thread_struct>>, void (mlx::core::scheduler::StreamThread::*)(), mlx::core::scheduler::StreamThread*>>(void*) + 72
frame #5: 0x0000000196c2f2e4 libsystem_pthread.dylib`_pthread_start + 136
thread #20
frame #0: 0x0000000196bf15cc libsystem_kernel.dylib`__psynch_cvwait + 8
frame #1: 0x0000000196c2f894 libsystem_pthread.dylib`_pthread_cond_wait + 1204
frame #2: 0x0000000196b65578 libc++.1.dylib`std::__1::condition_variable::wait(std::__1::unique_lock<std::__1::mutex>&) + 28
frame #3: 0x0000000101f020a8 libmlx.dylib`ThreadPool::ThreadPool(unsigned long)::'lambda'()::operator()() const + 144
frame #4: 0x0000000101f01f8c libmlx.dylib`void* std::__1::__thread_proxy[abi:v160006]<std::__1::tuple<std::__1::unique_ptr<std::__1::__thread_struct, std::__1::default_delete<std::__1::__thread_struct>>, ThreadPool::ThreadPool(unsigned long)::'lambda'()>>(void*) + 52
frame #5: 0x0000000196c2f2e4 libsystem_pthread.dylib`_pthread_start + 136
thread #21
frame #0: 0x0000000196bf15cc libsystem_kernel.dylib`__psynch_cvwait + 8
frame #1: 0x0000000196c2f894 libsystem_pthread.dylib`_pthread_cond_wait + 1204
frame #2: 0x0000000196b65578 libc++.1.dylib`std::__1::condition_variable::wait(std::__1::unique_lock<std::__1::mutex>&) + 28
frame #3: 0x0000000101f020a8 libmlx.dylib`ThreadPool::ThreadPool(unsigned long)::'lambda'()::operator()() const + 144
frame #4: 0x0000000101f01f8c libmlx.dylib`void* std::__1::__thread_proxy[abi:v160006]<std::__1::tuple<std::__1::unique_ptr<std::__1::__thread_struct, std::__1::default_delete<std::__1::__thread_struct>>, ThreadPool::ThreadPool(unsigned long)::'lambda'()>>(void*) + 52
frame #5: 0x0000000196c2f2e4 libsystem_pthread.dylib`_pthread_start + 136
thread #22
frame #0: 0x0000000196bf15cc libsystem_kernel.dylib`__psynch_cvwait + 8
frame #1: 0x0000000196c2f894 libsystem_pthread.dylib`_pthread_cond_wait + 1204
frame #2: 0x0000000196b65578 libc++.1.dylib`std::__1::condition_variable::wait(std::__1::unique_lock<std::__1::mutex>&) + 28
frame #3: 0x0000000101f020a8 libmlx.dylib`ThreadPool::ThreadPool(unsigned long)::'lambda'()::operator()() const + 144
frame #4: 0x0000000101f01f8c libmlx.dylib`void* std::__1::__thread_proxy[abi:v160006]<std::__1::tuple<std::__1::unique_ptr<std::__1::__thread_struct, std::__1::default_delete<std::__1::__thread_struct>>, ThreadPool::ThreadPool(unsigned long)::'lambda'()>>(void*) + 52
frame #5: 0x0000000196c2f2e4 libsystem_pthread.dylib`_pthread_start + 136
thread #23
frame #0: 0x0000000196bf15cc libsystem_kernel.dylib`__psynch_cvwait + 8
frame #1: 0x0000000196c2f894 libsystem_pthread.dylib`_pthread_cond_wait + 1204
frame #2: 0x0000000196b65578 libc++.1.dylib`std::__1::condition_variable::wait(std::__1::unique_lock<std::__1::mutex>&) + 28
frame #3: 0x0000000101f020a8 libmlx.dylib`ThreadPool::ThreadPool(unsigned long)::'lambda'()::operator()() const + 144
frame #4: 0x0000000101f01f8c libmlx.dylib`void* std::__1::__thread_proxy[abi:v160006]<std::__1::tuple<std::__1::unique_ptr<std::__1::__thread_struct, std::__1::default_delete<std::__1::__thread_struct>>, ThreadPool::ThreadPool(unsigned long)::'lambda'()>>(void*) + 52
frame #5: 0x0000000196c2f2e4 libsystem_pthread.dylib`_pthread_start + 136
thread #24
frame #0: 0x0000000196befa84 libsystem_kernel.dylib`__workq_kernreturn + 8
thread #25
frame #0: 0x0000000196bf15cc libsystem_kernel.dylib`__psynch_cvwait + 8
frame #1: 0x0000000196c2f894 libsystem_pthread.dylib`_pthread_cond_wait + 1204
frame #2: 0x0000000196b65578 libc++.1.dylib`std::__1::condition_variable::wait(std::__1::unique_lock<std::__1::mutex>&) + 28
frame #3: 0x0000000129d2ce54 _c.cpython-39-darwin.so`___lldb_unnamed_symbol22431 + 172
frame #4: 0x0000000196c2f2e4 libsystem_pthread.dylib`_pthread_start + 136
thread #26
frame #0: 0x0000000196bf15cc libsystem_kernel.dylib`__psynch_cvwait + 8
frame #1: 0x0000000196c2f894 libsystem_pthread.dylib`_pthread_cond_wait + 1204
frame #2: 0x0000000196b65578 libc++.1.dylib`std::__1::condition_variable::wait(std::__1::unique_lock<std::__1::mutex>&) + 28
frame #3: 0x0000000129d2ce54 _c.cpython-39-darwin.so`___lldb_unnamed_symbol22431 + 172
frame #4: 0x0000000196c2f2e4 libsystem_pthread.dylib`_pthread_start + 136
thread #27
frame #0: 0x0000000196bf15cc libsystem_kernel.dylib`__psynch_cvwait + 8
frame #1: 0x0000000196c2f894 libsystem_pthread.dylib`_pthread_cond_wait + 1204
frame #2: 0x0000000196b65578 libc++.1.dylib`std::__1::condition_variable::wait(std::__1::unique_lock<std::__1::mutex>&) + 28
frame #3: 0x0000000129d2ce54 _c.cpython-39-darwin.so`___lldb_unnamed_symbol22431 + 172
frame #4: 0x0000000196c2f2e4 libsystem_pthread.dylib`_pthread_start + 136
Desktop (please complete the following information):
- OS Version: macOS 15.1.1
- Version: 0.21.0