Commit

Fix data size bug
jagrit06 committed Nov 22, 2024
1 parent 4640f86 commit ed4fb26
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion mlx/backend/metal/scaled_dot_product_attention.cpp
@@ -330,6 +330,7 @@ void ScaledDotProductAttention::eval_gpu(
size_t str_oH = o.shape(3);
size_t str_oL = o.shape(1) * str_oH;
size_t str_oB = o.shape(2) * str_oL;
+ size_t data_size = o.shape(0) * str_oB;

array::Flags flags{
/* bool contiguous = */ 1,
@@ -339,7 +340,7 @@ void ScaledDotProductAttention::eval_gpu(

o.set_data(
allocator::malloc_or_wait(o.nbytes()),
- o.data_size(),
+ data_size,
{str_oB, str_oH, str_oL, str_oD},
flags);
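
Note on the change: the strides in this hunk appear to describe an output whose logical shape is (B, H, L, D) but whose memory order is B, L, H, D, so the element count can be derived from the shape alone. This is likely needed because `o.data_size()` is not yet meaningful before `set_data` has been called on `o`. The standalone C++ sketch below mirrors that arithmetic outside of MLX. `OutLayout` and `transposed_layout` are hypothetical names used only for illustration, and the innermost stride `str_oD = 1` is an assumption, since its definition sits above the visible hunk.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

// Hypothetical helper type, not part of MLX.
struct OutLayout {
  std::array<std::size_t, 4> strides; // ordered to match shape dims {B, H, L, D}
  std::size_t data_size;              // number of elements in the buffer
};

// `shape` is the logical output shape {B, H, L, D}.
// The buffer itself is assumed to be stored in B, L, H, D order.
inline OutLayout transposed_layout(const std::array<std::size_t, 4>& shape) {
  std::size_t str_oD = 1;                    // assumption: unit innermost stride
  std::size_t str_oH = shape[3];             // step over D elements per head
  std::size_t str_oL = shape[1] * str_oH;    // step over H * D per sequence position
  std::size_t str_oB = shape[2] * str_oL;    // step over L * H * D per batch entry
  std::size_t data_size = shape[0] * str_oB; // B * L * H * D elements in total

  // The layout is dense, so the element count equals the product of all dims.
  assert(data_size == shape[0] * shape[1] * shape[2] * shape[3]);

  OutLayout out;
  out.strides = {str_oB, str_oH, str_oL, str_oD};
  out.data_size = data_size;
  return out;
}
```

For a shape of {2, 8, 16, 64} this gives strides {8192, 64, 512, 1} and a data size of 16384 elements, which matches the product of the dimensions.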
