Add groups to Conv1d #948

Merged
merged 10 commits on Apr 27, 2024
Changes from 1 commit
Remove copy and refactor
Rifur13 committed Apr 27, 2024
commit 13499e9163e4aaa638519aae270baade614957a3
76 changes: 38 additions & 38 deletions mlx/backend/common/conv.cpp
@@ -403,45 +403,54 @@ void explicit_gemm_conv_1D_cpu(
   copy_inplace(in, in_padded_slice, CopyType::GeneralGeneral);
 
   // Make strided view
-  std::vector<int> strided_shape = {N, oH, C, wH};
+  std::vector<int> strided_shape = {N, oH, wH, C};
 
   std::vector<size_t> strided_strides = {
       in_padded.strides()[0],
       in_padded.strides()[1] * wt_strides[0],
-      in_padded.strides()[2],
-      in_padded.strides()[1]};
+      in_padded.strides()[1],
+      in_padded.strides()[2]};
   auto flags = in_padded.flags();
+  if (groups > 1) {
+    // Transpose the last two dimensions for grouped convolutions
+    std::swap(strided_shape[2], strided_shape[3]);
+    std::swap(strided_strides[2], strided_strides[3]);
+  }
 
   array in_strided_view(strided_shape, in_padded.dtype(), nullptr, {});
   in_strided_view.copy_shared_buffer(
       in_padded, strided_strides, flags, in_strided_view.size(), 0);
 
   // Materialize strided view
-  std::vector<int> strided_reshape = {N * oH, C * wH};
+  std::vector<int> strided_reshape = {N * oH, wH * C};
   array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
   copy(in_strided_view, in_strided, CopyType::General);
 
-  // Transpose kernels weights (O, wH, C_per_group) -> (O, C_per_group, wH) to
-  // align with the input.
-  array wt_transpose(
-      {wt.shape(0), wt.shape(2), wt.shape(1)}, wt.dtype(), nullptr, {});
-  wt_transpose.copy_shared_buffer(
-      wt,
-      {wt.strides(0), wt.strides(2), wt.strides(1)},
-      wt.flags(),
-      wt.size(),
-      0);
-  auto gemm_wt = array(wt_transpose.shape(), float32, nullptr, {});
-
-  // Ensure contiguity
-  copy(wt_transpose, gemm_wt, CopyType::General);
+  // Check wt dtype and prepare
+  auto gemm_wt = wt;
+  auto gemm_out = out;
 
-  const float* in_ptr = in_strided.data<float>();
-  const float* wt_ptr = gemm_wt.data<float>();
+  if (groups > 1) {
+    // Transpose the last two dimensions for grouped convolutions
+    array wt_transpose(
+        {wt.shape(0), wt.shape(2), wt.shape(1)}, wt.dtype(), nullptr, {});
+    wt_transpose.copy_shared_buffer(
+        wt,
+        {wt.strides(0), wt.strides(2), wt.strides(1)},
+        wt.flags(),
+        wt.size(),
+        0);
+    gemm_wt = array(wt_transpose.shape(), float32, nullptr, {});
+    copy(wt_transpose, gemm_wt, CopyType::General);
+  } else if (wt.dtype() != float32 || !wt.flags().row_contiguous) {
+    auto ctype =
+        wt.flags().row_contiguous ? CopyType::Vector : CopyType::General;
+    gemm_wt = array(wt.shape(), float32, nullptr, {});
+    copy(wt, gemm_wt, ctype);
+  }
 
-  auto gemm_out = out;
-  if (out.dtype() != float32 || groups > 1) {
-    gemm_out = array({N, oH, O_per_group}, float32, nullptr, {});
+  if (out.dtype() != float32) {
+    gemm_out = array(out.shape(), float32, nullptr, {});
     gemm_out.set_data(allocator::malloc_or_wait(gemm_out.nbytes()));
   }
 
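A note on what the hunk above is doing, since it is the heart of the change: the strided view is an im2col trick. It reinterprets the padded input (N, iH, C) as (N, oH, wH, C) so that each (n, oh) row of the materialized (N * oH, wH * C) matrix holds one receptive field, and for groups > 1 it swaps the last two axes to (N, oH, C, wH) so each row becomes channel-major and the K = C_per_group * wH columns of group g sit in one contiguous block at column offset g * C_per_group * wH. Below is a minimal standalone sketch of that grouped layout using plain loops. It is illustrative only: the helper name is invented, and it assumes unit stride, unit dilation, and an already padded input.

```cpp
#include <cstddef>
#include <vector>

// Build the grouped im2col matrix that the diff creates via a strided view.
// in:  (N, iH, C) row-major, already zero-padded, unit stride.
// out: (N * oH, C * wH), channel-major within each row, so the
//      C_per_group * wH columns belonging to one group are contiguous.
std::vector<float> im2col_1d_grouped(
    const std::vector<float>& in, int N, int iH, int C, int wH) {
  const int oH = iH - wH + 1;
  std::vector<float> col(static_cast<size_t>(N) * oH * C * wH);
  for (int n = 0; n < N; ++n) {
    for (int oh = 0; oh < oH; ++oh) {
      float* row = col.data() + (static_cast<size_t>(n) * oH + oh) * C * wH;
      for (int c = 0; c < C; ++c) { // channel-major: group blocks stay together
        for (int w = 0; w < wH; ++w) {
          row[c * wH + w] = in[(static_cast<size_t>(n) * iH + oh + w) * C + c];
        }
      }
    }
  }
  return col;
}
```

The row stride of this matrix is wH * C, which matches the lda passed to the GEMM in the next hunk; for groups == 1 the diff keeps the original (N, oH, wH, C) order (element (w, c) at w * C + c), so the ungrouped path is unchanged.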
@@ -455,27 +464,18 @@ void explicit_gemm_conv_1D_cpu(
           O_per_group, // N
           C_per_group * wH, // K
           1.0f, // alpha
-          in_ptr + (g * C_per_group * wH), // input group
+          in_strided.data<float>() + g * C_per_group * wH, // A
           wH * C, // lda
-          wt_ptr + (g * O_per_group * C_per_group * wH), // filters group
+          gemm_wt.data<float>() + g * O_per_group * C_per_group * wH, // B
           wH * C_per_group, // ldb
           0.0f, // beta
-          gemm_out.data<float>(), // output group
-          O_per_group // ldc
+          gemm_out.data<float>() + g * O_per_group, // C
+          O // ldc
       );
-
-      // Copy results if needed
-      if (out.dtype() != float32 || groups > 1) {
-        array out_slice(gemm_out.shape(), out.dtype(), nullptr, {});
-        out_slice.copy_shared_buffer(
-            out,
-            out.strides(),
-            out.flags(),
-            gemm_out.size(),
-            g * O_per_group /* offset */);
-
-        // Copy the result of one grouped convolution into the slice.
-        copy_inplace(gemm_out, out_slice, CopyType::GeneralGeneral);
-      }
   }
+
+  // Copy results if needed
+  if (out.dtype() != float32) {
+    copy(gemm_out, out, CopyType::Vector);
+  }
 }
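With that layout, each iteration of the loop above is a plain GEMM computing A · Bᵀ for one group: M = N * oH, N = O_per_group, K = C_per_group * wH, where A is read at column offset g * C_per_group * wH inside a row of width wH * C, B is the g-th block of O_per_group transposed filters, and the result lands directly in the g-th block of O_per_group output columns with ldc = O. Writing through that column offset is what lets this commit drop the per-group staging buffer and the trailing copy_inplace. A naive-loop reference that spells out the same leading dimensions (hypothetical code, not part of the diff):

```cpp
#include <cstddef>

// Reference for the grouped GEMM above: A is the (M, wH * C) im2col matrix,
// B the transposed weights flattened to (O, C_per_group * wH), and
// C_out the (M, O) output, written one column block per group.
void grouped_gemm_reference(
    const float* A, const float* B, float* C_out,
    int M, int O, int groups, int C_per_group, int wH) {
  const int O_per_group = O / groups;
  const int C = groups * C_per_group;
  const int K = C_per_group * wH; // reduction size per group
  for (int g = 0; g < groups; ++g) {
    const float* Ag = A + g * K;                 // offset g * C_per_group * wH
    const float* Bg = B + static_cast<size_t>(g) * O_per_group * K;
    float* Cg = C_out + g * O_per_group;         // group's output column block
    for (int m = 0; m < M; ++m) {
      for (int n = 0; n < O_per_group; ++n) {
        float acc = 0.0f;
        for (int k = 0; k < K; ++k) {
          acc += Ag[static_cast<size_t>(m) * wH * C + k] // lda = wH * C
              * Bg[static_cast<size_t>(n) * K + k];      // ldb = wH * C_per_group
        }
        Cg[static_cast<size_t>(m) * O + n] = acc;        // ldc = O
      }
    }
  }
}
```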
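For completeness, a tiny driver that chains the two sketches and checks them against a direct grouped 1-D convolution (again hypothetical test code with arbitrary sizes; the weights are assumed pre-transposed to (O, C_per_group, wH)):

```cpp
#include <cassert>
#include <cmath>
#include <cstdlib>
#include <vector>

int main() {
  const int N = 2, iH = 9, wH = 3, groups = 2;
  const int C_per_group = 4, O_per_group = 3;
  const int C = groups * C_per_group, O = groups * O_per_group;
  const int oH = iH - wH + 1;

  // Random input (N, iH, C) and weights in (O, C_per_group, wH) layout,
  // matching what the diff feeds the GEMM.
  std::vector<float> in(N * iH * C), wt(O * C_per_group * wH);
  for (auto& v : in) v = static_cast<float>(std::rand()) / RAND_MAX;
  for (auto& v : wt) v = static_cast<float>(std::rand()) / RAND_MAX;

  // im2col + grouped GEMM path (the two sketches above).
  auto col = im2col_1d_grouped(in, N, iH, C, wH);
  std::vector<float> out(static_cast<size_t>(N) * oH * O, 0.0f);
  grouped_gemm_reference(
      col.data(), wt.data(), out.data(), N * oH, O, groups, C_per_group, wH);

  // Direct grouped 1-D convolution for comparison.
  for (int n = 0; n < N; ++n) {
    for (int oh = 0; oh < oH; ++oh) {
      for (int o = 0; o < O; ++o) {
        const int g = o / O_per_group;
        float acc = 0.0f;
        for (int c = 0; c < C_per_group; ++c) {
          for (int w = 0; w < wH; ++w) {
            acc += in[(static_cast<size_t>(n) * iH + oh + w) * C +
                      g * C_per_group + c] *
                wt[(static_cast<size_t>(o) * C_per_group + c) * wH + w];
          }
        }
        assert(
            std::fabs(acc - out[(static_cast<size_t>(n) * oH + oh) * O + o]) <
            1e-4f);
      }
    }
  }
  return 0;
}
```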