Skip to content

Commit

Permalink
Update vjp and reuse steel gemm
Browse files (browse the repository at this point in the history)
  • Loading branch information
Rifur13 committed Apr 24, 2024
1 parent 137a282 commit 0bf2696
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 97 deletions.

This file was deleted.

14 changes: 7 additions & 7 deletions mlx/backend/metal/matmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ void steel_matmul_conv_groups(

// Prepare kernel name
std::ostringstream kname;
- kname << "steel_gemm_conv_groups_" << (transpose_a ? 't' : 'n')
+ kname << "steel_gemm_" << (transpose_a ? 't' : 'n')
<< (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
<< type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk
<< "_wm" << wm << "_wn" << wn << "_MN_"
Expand Down Expand Up @@ -327,12 +327,12 @@ void steel_matmul_conv_groups(
/* const int ldd = */ ldd,
/* const int tiles_n = */ tn,
/* const int tiles_m = */ tm,
- /* const int batch_stride_a = */ 1,
- /* const int batch_stride_b = */ 1,
- /* const int batch_stride_d = */ 1,
+ /* const int batch_stride_a = */ K,
+ /* const int batch_stride_b = */ N * K,
+ /* const int batch_stride_d = */ N,
/* const int swizzle_log = */ swizzle_log,
/* const int gemm_k_iterations_aligned = */ (K / bk),
- /* const int batch_ndim = */ 0};
+ /* const int batch_ndim = */ 1};

// Prepare launch grid params
int tile = 1 << swizzle_log;
Expand All @@ -345,8 +345,8 @@ void steel_matmul_conv_groups(
// Launch kernel
compute_encoder.set_input_array(a, 0);
compute_encoder.set_input_array(b, 1);
- compute_encoder.set_output_array(out, 2);
- compute_encoder->setBytes(&params, sizeof(GEMMParams), 3);
+ compute_encoder.set_output_array(out, 3);
+ compute_encoder->setBytes(&params, sizeof(GEMMParams), 4);

compute_encoder->dispatchThreadgroups(grid_dims, group_dims);

Expand Down
5 changes: 5 additions & 0 deletions mlx/primitives.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -793,6 +793,11 @@ std::vector<array> Convolution::vjp(
assert(primals.size() == 2);
std::vector<array> grads;

+ if (groups_ != 1) {
+ throw std::invalid_argument(
+ "[Convolution] Backward pass not implemented for groups > 1.");
+ }
+
// Collect info
auto& in = primals[0];
auto& wt = primals[1];
Expand Down

0 comments on commit 0bf2696

Please sign in to comment.