Skip to content

Commit

Permalink
Add Im2ColMobileFunctor.
Browse files Browse the repository at this point in the history
  • Loading branch information
hedaoyuan committed Dec 26, 2017
1 parent dbf1d75 commit d775895
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 28 deletions.
56 changes: 28 additions & 28 deletions paddle/function/GemmConvOp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -206,8 +206,7 @@ class GemmConvMobileFunction : public ConvFunctionBase {
colData = reinterpret_cast<real*>(memory_->getBuf());
}

Im2ColFunctor<kCFO, Device, real> im2col;
GemmFunctor<Device, real> gemm;
Im2ColMobileFunctor<real> im2col;
size_t inputOffset = imShape.getElements();
size_t outputOffset =
(outputChannels / groups_) * outputHeight * outputWidth;
Expand Down Expand Up @@ -241,39 +240,40 @@ class GemmConvMobileFunction : public ConvFunctionBase {

// gemm
int M = outputChannels / groups_;
gemm(CblasNoTrans,
CblasNoTrans,
M,
N,
K,
1.0f,
filterData + g * filterOffset + colHeightStart,
kStride,
colData,
N,
beta_,
outputData + g * outputOffset + colWidthStart,
nStride);
BlasGemm<Device, real>::compute(
false,
false,
M,
N,
K,
1.0f,
filterData + g * filterOffset + colHeightStart,
kStride,
colData,
N,
beta_,
outputData + g * outputOffset + colWidthStart,
nStride);
}
beta_ = 1.0;
}
} else {
int M = outputChannels / groups_;
int N = outputHeight * outputWidth;
int K = inputChannels / groups_ * filterHeight * filterWidth;
gemm(CblasNoTrans,
CblasNoTrans,
M,
N,
K,
1.0f,
filterData + g * filterOffset,
K,
inputData + g * inputOffset,
N,
beta,
outputData + g * outputOffset,
N);
BlasGemm<Device, real>::compute(false,
false,
M,
N,
K,
1.0f,
filterData + g * filterOffset,
K,
inputData + g * inputOffset,
N,
beta,
outputData + g * outputOffset,
N);
}
}
inputData += inputChannels * inputHeight * inputWidth;
Expand Down
48 changes: 48 additions & 0 deletions paddle/function/Im2Col.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,4 +98,52 @@ class Col2ImFunctor {
int dilationWidth = 1);
};

template <class T>
class Im2ColMobileFunctor {
public:
void operator()(const T* imData,
const TensorShape& imShape,
T* colData,
const TensorShape& colShape,
int strideHeight,
int strideWidth,
int paddingHeight,
int paddingWidth,
int colHeightStart,
int colHeightSize,
int colWidthStart,
int colWidthSize) {
int inputHeight = imShape[1];
int inputWidth = imShape[2];
int filterHeight = colShape[1];
int filterWidth = colShape[2];
int outputWidth = colShape[4];

for (int colh = 0; colh < colHeightSize; colh++) {
int wOffset = (colHeightStart + colh) % filterWidth;
int hOffset = ((colHeightStart + colh) / filterWidth) % filterHeight;
int c_im = (colHeightStart + colh) / filterWidth / filterHeight;

for (int colw = 0; colw < colWidthSize; colw++) {
int h = (colWidthStart + colw) / outputWidth;
int w = (colWidthStart + colw) % outputWidth;

int imRowIdx = h * strideHeight + hOffset;
int imColIdx = w * strideWidth + wOffset;
if ((imRowIdx - paddingHeight) < 0 ||
(imRowIdx - paddingHeight) >= inputHeight ||
(imColIdx - paddingWidth) < 0 ||
(imColIdx - paddingWidth) >= inputWidth) {
colData[colh * colWidthSize + colw] = T(0);
} else {
imRowIdx += c_im * inputHeight - paddingHeight;
imColIdx -= paddingWidth;
colData[colh * colWidthSize + colw] =
imData[imRowIdx * inputWidth + imColIdx];
}
}
}
}
};

} // namespace paddle

0 comments on commit d775895

Please sign in to comment.