Skip to content

Commit

Permalink
Fix metal validation
Browse files Browse the repository at this point in the history
  • Loading branch information
Rifur13 committed Apr 27, 2024
1 parent 861ec7e commit 48efb11
Showing 1 changed file with 8 additions and 0 deletions.
8 changes: 8 additions & 0 deletions mlx/backend/metal/matmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -342,12 +342,20 @@ void steel_matmul_conv_groups(
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size(tn, tm, groups);

std::vector<int> batch_shape = {1};
std::vector<size_t> batch_strides = {0};

// Launch kernel
compute_encoder.set_input_array(a, 0);
compute_encoder.set_input_array(b, 1);
compute_encoder.set_output_array(out, 3);
compute_encoder->setBytes(&params, sizeof(GEMMParams), 4);

compute_encoder->setBytes(
batch_shape.data(), sizeof(int) * batch_shape.size(), 6);
compute_encoder->setBytes(
batch_strides.data(), sizeof(size_t) * batch_strides.size(), 7);

compute_encoder->dispatchThreadgroups(grid_dims, group_dims);

// Clear copies
Expand Down

0 comments on commit 48efb11

Please sign in to comment.