diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp index 911984ef4..079a0baff 100644 --- a/mlx/backend/metal/scaled_dot_product_attention.cpp +++ b/mlx/backend/metal/scaled_dot_product_attention.cpp @@ -330,6 +330,7 @@ void ScaledDotProductAttention::eval_gpu( size_t str_oH = o.shape(3); size_t str_oL = o.shape(1) * str_oH; size_t str_oB = o.shape(2) * str_oL; + size_t data_size = o.shape(0) * str_oB; array::Flags flags{ /* bool contiguous = */ 1, @@ -339,7 +340,7 @@ void ScaledDotProductAttention::eval_gpu( o.set_data( allocator::malloc_or_wait(o.nbytes()), - o.data_size(), + data_size, {str_oB, str_oH, str_oL, str_oD}, flags);