Skip to content

Commit

Permalink
Add conv1d grouped convs on CPU
Browse files Browse the repository at this point in the history
  • Loading branch information
Rifur13 committed Apr 1, 2024
1 parent 02fedbf commit da3c5fd
Show file tree
Hide file tree
Showing 5 changed files with 237 additions and 67 deletions.
9 changes: 9 additions & 0 deletions mlx/array.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,15 @@ class array {
return array_desc_->strides;
};

/**
 * Return the stride of the requested dimension.
 *
 * Supports Python-style negative indexing (a negative `dim`
 * counts back from the last dimension), and the `.at()` lookup
 * performs bounds checking on the resolved index. */
size_t strides(int dim) const {
  if (dim < 0) {
    dim += ndim();
  }
  return strides().at(dim);
};

/** Get the arrays data type. */
Dtype dtype() const {
return array_desc_->dtype;
Expand Down
142 changes: 85 additions & 57 deletions mlx/backend/common/conv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,15 @@ void slow_conv_1D(

const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
const int iH = 1 + in_dilation[0] * (in.shape(1) - 1); // Input spatial dim
const int C = in.shape(2); // Input channels
const int oH = out.shape(1); // Output spatial dim
const int O = wt.shape(0); // Out channels
const int C = wt.shape(2); // In channels
const int wH = wt.shape(1); // Weight spatial dim

const int groups = C / wt.shape(2);
const int C_per_group = wt.shape(2);
const int O_per_group = O / groups;

const size_t in_stride_N = in.strides()[0];
const size_t in_stride_H = in.strides()[1];
const size_t in_stride_C = in.strides()[2];
Expand All @@ -57,35 +61,36 @@ void slow_conv_1D(

for (int n = 0; n < N; ++n) {
for (int oh = 0; oh < oH; ++oh) {
for (int o = 0; o < O; ++o) {
const T* filter_wt_ptr = start_wt_ptr + o * wt_stride_O;
float r = 0.;
for (int g = 0; g < groups; ++g) {
for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) {
const T* filter_wt_ptr = start_wt_ptr + o * wt_stride_O;
float r = 0.;

for (int wh = 0; wh < wH; ++wh) {
const T* wt_ptr = filter_wt_ptr + wh * wt_stride_H;
for (int wh = 0; wh < wH; ++wh) {
const T* wt_ptr = filter_wt_ptr + wh * wt_stride_H;

int wh_flip = flip ? (wH - wh - 1) : wh;
int ih = oh * wt_strides[0] - padding[0] + wh_flip * wt_dilation[0];
int wh_flip = flip ? (wH - wh - 1) : wh;
int ih = oh * wt_strides[0] - padding[0] + wh_flip * wt_dilation[0];

auto ih_div = std::div(ih, in_dilation[0]);
auto ih_div = std::div(ih, in_dilation[0]);

if (ih >= 0 && ih < iH && ih_div.rem == 0) {
for (int c = 0; c < C; ++c) {
r += static_cast<float>(
in_ptr[ih_div.quot * in_stride_H + c * in_stride_C]) *
static_cast<float>(wt_ptr[c * wt_stride_C]);
} // c
if (ih >= 0 && ih < iH && ih_div.rem == 0) {
for (int c = g * C_per_group; c < (g + 1) * C_per_group; ++c) {
r += static_cast<float>(
in_ptr[ih_div.quot * in_stride_H + c * in_stride_C]) *
static_cast<float>(wt_ptr[(c % C_per_group) * wt_stride_C]);
} // c

} // ih check
} // wh
} // ih check
} // wh

out_ptr[oh * out_stride_H + o * out_stride_O] = static_cast<T>(r);
} // o
out_ptr[oh * out_stride_H + o * out_stride_O] = static_cast<T>(r);
} // o
} // g
} // oh

in_ptr += in_stride_N;
out_ptr += out_stride_N;

} // n
}

Expand Down Expand Up @@ -366,11 +371,15 @@ void explicit_gemm_conv_1D_cpu(
const std::vector<int>& wt_dilation) {
const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
const int iH = in.shape(1); // Input spatial dim
const int C = in.shape(2); // Input channels
const int oH = out.shape(1); // Output spatial dim
const int O = wt.shape(0); // Out channels
const int C = wt.shape(2); // In channels
const int wH = wt.shape(1); // Weight spatial dim

const int groups = C / wt.shape(2);
const int C_per_group = wt.shape(2);
const int O_per_group = O / groups;

auto conv_dtype = float32;

// Pad input
Expand All @@ -394,61 +403,80 @@ void explicit_gemm_conv_1D_cpu(
copy_inplace(in, in_padded_slice, CopyType::GeneralGeneral);

// Make strided view
std::vector<int> strided_shape = {N, oH, wH, C};
std::vector<int> strided_shape = {N, oH, C, wH};

std::vector<size_t> strided_strides = {
in_padded.strides()[0],
in_padded.strides()[1] * wt_strides[0],
in_padded.strides()[1],
in_padded.strides()[2]};
in_padded.strides()[2],
in_padded.strides()[1]};
auto flags = in_padded.flags();

array in_strided_view(strided_shape, in_padded.dtype(), nullptr, {});
in_strided_view.copy_shared_buffer(
in_padded, strided_strides, flags, in_strided_view.size(), 0);

// Materialize strided view
std::vector<int> strided_reshape = {N * oH, wH * C};
std::vector<int> strided_reshape = {N * oH, C * wH};
array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
copy(in_strided_view, in_strided, CopyType::General);

// Check wt dtype and prepare
auto gemm_wt = wt;
auto gemm_out = out;
// Transpose kernels weights (O, wH, C_per_group) -> (O, C_per_group, wH) to
// align with the input.
array wt_transpose(
{wt.shape(0), wt.shape(2), wt.shape(1)}, wt.dtype(), nullptr, {});
wt_transpose.copy_shared_buffer(
wt,
{wt.strides(0), wt.strides(2), wt.strides(1)},
{0, 0, 0}, // Flags
wt.size(),
0);
auto gemm_wt = array(wt_transpose.shape(), float32, nullptr, {});

if (wt.dtype() != float32 || !wt.flags().row_contiguous) {
auto ctype =
wt.flags().row_contiguous ? CopyType::Vector : CopyType::General;
gemm_wt = array(wt.shape(), float32, nullptr, {});
copy(wt, gemm_wt, ctype);
}
// Ensure contiguity
copy(wt_transpose, gemm_wt, CopyType::General);

if (out.dtype() != float32) {
gemm_out = array(out.shape(), float32, nullptr, {});
const float* in_ptr = in_strided.data<float>();
const float* wt_ptr = gemm_wt.data<float>();

auto gemm_out = out;
if (out.dtype() != float32 || groups > 1) {
gemm_out = array({N, oH, O_per_group}, float32, nullptr, {});
gemm_out.set_data(allocator::malloc_or_wait(gemm_out.nbytes()));
}

// Perform gemm
cblas_sgemm(
CblasRowMajor,
CblasNoTrans, // no trans A
CblasTrans, // transB
strided_reshape[0], // M
O, // N
strided_reshape[1], // K
1.0f, // alpha
in_strided.data<float>(),
strided_reshape[1], // lda
gemm_wt.data<float>(),
strided_reshape[1], // ldb
0.0f, // beta
gemm_out.data<float>(),
O // ldc
);

// Copy results if needed
if (out.dtype() != float32) {
copy(gemm_out, out, CopyType::Vector);
for (int g = 0; g < groups; ++g) {
// Perform gemm
cblas_sgemm(
CblasRowMajor,
CblasNoTrans, // no trans A
CblasTrans, // transB
strided_reshape[0], // M
O_per_group, // N
C_per_group * wH, // K
1.0f, // alpha
in_ptr + (g * C_per_group * wH), // input group
wH * C, // lda
wt_ptr + (g * O_per_group * C_per_group * wH), // filters group
wH * C_per_group, // ldb
0.0f, // beta
gemm_out.data<float>(), // output group
O_per_group // ldc
);

// Copy results if needed
if (out.dtype() != float32 || groups > 1) {
array out_slice(gemm_out.shape(), out.dtype(), nullptr, {});
out_slice.copy_shared_buffer(
out,
out.strides(),
out.flags(),
gemm_out.size(),
g * O_per_group /* offset */);

// Copy the result of one grouped convolution into the slice.
copy_inplace(gemm_out, out_slice, CopyType::GeneralGeneral);
}
}
}

Expand Down
44 changes: 35 additions & 9 deletions mlx/ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ array reshape(
"[reshape] Cannot infer the shape of an empty array");
}

// Check the the reshaping is valid
// Check that the reshaping is valid
if (a.size() != size) {
std::ostringstream msg;
msg << "[reshape] Cannot reshape array of size " << a.size()
Expand Down Expand Up @@ -2847,7 +2847,8 @@ inline std::vector<int> conv_out_shape(
return out_shape;
}

inline void run_conv_checks(const array& in, const array& wt, int n_dim) {
inline void
run_conv_checks(const array& in, const array& wt, int n_dim, int groups) {
if (!issubdtype(in.dtype(), floating)) {
std::ostringstream msg;
msg << "[conv] Invalid input array with type " << in.dtype() << "."
Expand All @@ -2873,11 +2874,35 @@ inline void run_conv_checks(const array& in, const array& wt, int n_dim) {
throw std::invalid_argument(msg.str());
}

if (in.shape(n_dim + 1) != wt.shape(n_dim + 1)) {
if (in.shape(n_dim + 1) % groups != 0) {
std::ostringstream msg;
msg << "[conv] Expect the input channels in the input"
<< " and weight array to match but got shapes -"
<< " input: " << in.shape() << " and weight: " << wt.shape();
msg << "[conv] The input channels must be divisible by the number"
<< " of groups. Got input with shape " << in.shape() << " and " << groups
<< " groups.";
throw std::invalid_argument(msg.str());
}

if (groups > 1 && wt.shape(0) % groups != 0) {
std::ostringstream msg;
msg << "[conv] If groups > 1, the output channels must be divisible by the number"
<< " of groups. Got " << wt.shape(0) << " output channels and "
<< groups << " groups.";
throw std::invalid_argument(msg.str());
}

if (in.shape(n_dim + 1) != (groups * wt.shape(n_dim + 1))) {
std::ostringstream msg;
if (groups == 1) {
msg << "[conv] Expect the input channels in the input"
<< " and weight array to match but got shapes -"
<< " input: " << in.shape() << " and weight: " << wt.shape();

} else {
msg << "Given groups=" << groups << " and weights of shape " << wt.shape()
<< ", expected to have " << (groups * wt.shape(n_dim + 1))
<< " input channels but got " << in.shape(n_dim + 1)
<< " input channels instead.";
}
throw std::invalid_argument(msg.str());
}
}
Expand Down Expand Up @@ -2940,8 +2965,9 @@ array conv_general(
bool flip /* = false */,
StreamOrDevice s /* = {} */) {
// Run checks
if (groups != 1) {
throw std::invalid_argument("[conv] Cannot handle groups != 1 yet");
if (groups != 1 && in.ndim() != 3) {
throw std::invalid_argument(
"[conv] Can only handle groups != 1 in 1D convolutions.");
}

int spatial_dims = in.ndim() - 2;
Expand All @@ -2953,7 +2979,7 @@ array conv_general(
}

// Run checks
run_conv_checks(in, wt, spatial_dims);
run_conv_checks(in, wt, spatial_dims, groups);

// Type promotion
auto out_type = promote_types(in.dtype(), wt.dtype());
Expand Down
9 changes: 8 additions & 1 deletion python/tests/test_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,9 @@ def run_conv1D(
np_dtype = getattr(np, dtype)
np.random.seed(0)
in_np = np.random.normal(0, 1.0 / C, (N, iH, C)).astype(np_dtype)
wt_np = np.random.normal(0, 1.0 / C, (O, kH, C)).astype(np_dtype)
wt_np = np.random.normal(0, 1.0 / C, (O, kH, int(C / groups))).astype(
np_dtype
)

in_mx, wt_mx = map(mx.array, (in_np, wt_np))
in_pt, wt_pt = map(
Expand Down Expand Up @@ -119,6 +121,11 @@ def run_conv1D(
):
run_conv1D(N, C, O, iH, kH, stride, padding, dtype=dtype)

# groups tests
N, C, O = (4, 32, 64)
iH, kH, stride, padding = (31, 5, 1, 2)
for group in (1, 2, 4, 8, 16, 32):
run_conv1D(N, C, O, iH, kH, stride=1, padding=1, groups=group, dtype=dtype)
# Strided inputs tests
for tpose_in, tpose_wt in (
((0, 2, 1), (0, 1, 2)),
Expand Down
Loading

0 comments on commit da3c5fd

Please sign in to comment.