Skip to content

Commit

Permalink
unpack through unified convolution interface (#105)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #105

Support for calling unpack using unified interface for packing convolution weights

Reviewed By: jianyuh

Differential Revision: D16190534

fbshipit-source-id: 28e1b95c7642c1cf9ed3d8935f56c740f9b44bcd
  • Loading branch information
dskhudia authored and facebook-github-bot committed Jul 15, 2019
1 parent bee229d commit 2346ccb
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 3 deletions.
6 changes: 6 additions & 0 deletions include/fbgemm/Fbgemm.h
Original file line number Diff line number Diff line change
Expand Up @@ -597,6 +597,12 @@ class FBGEMM_API PackWeightsForConv {
return W_gconv_packed_;
}

/**
* @brief Unpack packed matric into origin_buf (Used for the serialization to
* recover weight matrix).
*/
void unpack(T* origin_buf);

private:
// Packed weights if we use im2col based convolution implementation
std::shared_ptr<PackBMatrix<T, accT>> W_im2col_packed_;
Expand Down
15 changes: 15 additions & 0 deletions src/PackWeightsForConv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,21 @@ PackWeightsForConv<SPATIAL_DIM, T, accT>::PackWeightsForConv(
} // switch
}

template <int SPATIAL_DIM, typename T, typename accT>
void PackWeightsForConv<SPATIAL_DIM, T, accT>::unpack(T* origin_buf) {
if (W_dw_2D_packed_) {
W_dw_2D_packed_->unpack(origin_buf);
} else if (W_dw_3D_packed_) {
W_dw_3D_packed_->unpack(origin_buf);
} else if (W_gconv_packed_) {
W_gconv_packed_->unpack(origin_buf);
} else if (W_im2col_packed_) {
W_im2col_packed_->unpack(origin_buf);
} else {
assert(false && "At least one packed weights object should exist");
}
}

template class PackWeightsForConv<2, int8_t, int32_t>;
template class PackWeightsForConv<3, int8_t, int32_t>;

Expand Down
63 changes: 60 additions & 3 deletions test/UniConvPackingTest.cc → test/UniConvTest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,15 @@ using namespace fbgemm;
namespace {

// tuple represents MB, IC, OC, IT, IH, IW, KH/KW, stride, pad
class convPackingTest
class uniConvTest
: public testing::TestWithParam<
tuple<int, int, int, int, int, int, int, int, int, int>> {};

}; // namespace

INSTANTIATE_TEST_CASE_P(
InstantiationName,
convPackingTest,
uniConvTest,
::testing::Combine(
::testing::ValuesIn({1, 2}), // MB
::testing::ValuesIn({16, 32}), // IC
Expand All @@ -47,7 +47,7 @@ INSTANTIATE_TEST_CASE_P(
/**
* Test for conv packing
*/
TEST_P(convPackingTest, packingTest) {
TEST_P(uniConvTest, packingTest) {
int MB, IC, OC, IT, IH, IW, G, kernel, stride, pad;
tie(MB, IC, OC, IT, IH, IW, G, kernel, stride, pad) = GetParam();

Expand Down Expand Up @@ -146,3 +146,60 @@ TEST_P(convPackingTest, packingTest) {
}
}
}

/**
* Test for packing/unpacking
*/
TEST_P(uniConvTest, packUnpackTest) {
int MB, IC, OC, IT, IH, IW, G, kernel, stride, pad;
tie(MB, IC, OC, IT, IH, IW, G, kernel, stride, pad) = GetParam();

conv_param_t<2> conv_p_2d(
MB,
IC,
OC,
{IH, IW},
G,
{kernel, kernel},
{stride, stride},
{pad, pad, pad, pad});

int kernel_dim_2d = kernel * kernel;

aligned_vector<int8_t> Bint8_2d(
kernel_dim_2d * conv_p_2d.IC * (conv_p_2d.OC / conv_p_2d.G));
aligned_vector<int8_t> Bint8_2d_unpacked(
kernel_dim_2d * conv_p_2d.IC * (conv_p_2d.OC / conv_p_2d.G));

PackWeightsForConv<2> packedB_2D(conv_p_2d, Bint8_2d.data());

packedB_2D.unpack(Bint8_2d_unpacked.data());

ASSERT_EQ(Bint8_2d, Bint8_2d_unpacked)
<< "Original and unpacked data elements are not the same [2D]";

conv_param_t<3> conv_p_3d(
MB,
IC,
OC,
{IT, IH, IW},
G,
{kernel, kernel, kernel},
{stride, stride, stride},
{pad, pad, pad, pad, pad, pad});

int kernel_dim_3d = kernel * kernel * kernel;

aligned_vector<int8_t> Bint8_3d(
kernel_dim_3d * conv_p_3d.IC * (conv_p_3d.OC / conv_p_3d.G));

aligned_vector<int8_t> Bint8_3d_unpacked(
kernel_dim_3d * conv_p_3d.IC * (conv_p_3d.OC / conv_p_3d.G));

PackWeightsForConv<3> packedB_3D(conv_p_3d, Bint8_3d.data());

packedB_3D.unpack(Bint8_3d_unpacked.data());

ASSERT_EQ(Bint8_3d, Bint8_3d_unpacked)
<< "Original and unpacked data elements are not the same [3D]";
}

0 comments on commit 2346ccb

Please sign in to comment.