Make more C++ unit tests work for batch norm #28

Merged
merged 33 commits on Feb 12, 2018
Changes from 1 commit
some backward items initialized
Olivier committed Feb 7, 2018
commit af393cc3feb008e507a5b847cf2f4a6612f4e4aa
10 changes: 10 additions & 0 deletions tests/cpp/include/test_util.h
@@ -160,6 +160,16 @@ inline void fill(const TBlob& blob, const DType val) {
   }
 }
 
+template<typename DType>
+inline void try_fill(const TBlob *blob, const DType val) {
+  if (blob) {
+    DType *p1 = blob->dptr<DType>();
+    for (size_t i = 0, n = blob->Size(); i < n; ++i) {
+      *p1++ = val;
+    }
+  }
+}
+
 template<typename DType>
 inline void fill(const TBlob& blob, const DType *valArray) {
   DType *p1 = blob.dptr<DType>();
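For context, a minimal usage sketch of the new try_fill helper (not part of the diff; the variable names are hypothetical): unlike fill, which takes a reference and assumes the blob exists, try_fill takes a pointer and silently does nothing when it is null, so callers can fill optional blobs without extra branching.

    // Hypothetical caller (assumed names); try_fill is a no-op when the blob is absent.
    const TBlob *optional_grad = nullptr;     // e.g. a mean/var gradient that was not requested
    test::try_fill(optional_grad, 0.1f);      // safe: does nothing
    const TBlob *present_grad = &some_blob;   // some_blob: an existing, allocated TBlob
    test::try_fill(present_grad, 0.1f);       // fills every element with 0.1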
52 changes: 37 additions & 15 deletions tests/cpp/operator/batchnorm_test.cc
@@ -65,9 +65,11 @@ class BNOperatorExecutor : public test::op::CoreOpExecutor<DType, AccReal> {
   using Super = typename test::op::CoreOpExecutor<DType, AccReal>;
  public:
   BNOperatorExecutor(const bool isGPU, const TShape& inputShape,
+                     const test::op::kwargs_t& kwargs,
                      const bool hasWeightAndBias = false)
     : test::op::CoreOpExecutor<DType, AccReal>(isGPU, { inputShape })
       , hasWeightAndBias_(hasWeightAndBias) {
+    param_.Init(kwargs);
   }
 
   //using BlobVectorType = typename test::op::CoreOpExecutor<DType, AccReal>::BlobVectorType;
@@ -104,30 +106,43 @@ class BNOperatorExecutor : public test::op::CoreOpExecutor<DType, AccReal> {
     return &arrs[idx];
   }
 
-  const NDArray *GetBackwardInArray(const op::BatchNormParam& param, const int idx) const {
+  const NDArray *GetBackwardInArray(const int idx) const {
     const std::vector<NDArray> &arrs = Super::bwd_inputs();
     switch (idx) {
       case kBackOutGrad:
         CHECK_LT(kBackOutGrad, arrs.size());
         return &arrs[kBackOutGrad];
       case kBackOutGradMean:
-        if (param.output_mean_var) {
+        if (param_.output_mean_var) {
           CHECK_LT(kBackOutGradMean, arrs.size());
           return &arrs[kBackOutGradMean];
         } else {
           CHECK(false);
           return nullptr;
         }
       case kBackOutGradVar:
-        if (param.output_mean_var) {
+        if (param_.output_mean_var) {
           return &arrs[kBackOutGradVar];
         } else {
           CHECK(false);
           return nullptr;
         }
-      default:
-        return &arrs[param.output_mean_var ? idx : idx - 2];
+      default: {
+        const size_t index = param_.output_mean_var ? idx : idx - 2;
+        if (index < arrs.size()) {
+          return &arrs[index];
+        }
+        return nullptr;
+      }
     }
   }
+
+  const TBlob *GetBackwardInBlob(const int idx) const {
+    const NDArray *arr = GetBackwardInArray(idx);
+    if (arr) {
+      return &arr->data();
+    }
+    return nullptr;
+  }
 
   const NDArray *GetArray(const WhichArray wa, const int idx) const {
@@ -234,31 +249,37 @@ class BNOperatorExecutor : public test::op::CoreOpExecutor<DType, AccReal> {
       const int dtype = out.type_flag_;
       MSHADOW_TYPE_SWITCH(dtype, DTypeX, { test::fill(out, DTypeX(0.5678)); });
     }
-    /*
     DType val = -.001;
     MSHADOW_TYPE_SWITCH(
-      this->c_.blob_out_grad_[mxnet::op::batchnorm::kOut].type_flag_,
+      GetBlob(kBackwardIn, kBackOutGrad).type_flag_,
+      //this->c_.blob_out_grad_[mxnet::op::batchnorm::kOut].type_flag_,
       DTypeX, {
-        test::patternFill<DTypeX>(&this->c_.blob_out_grad_[mxnet::op::batchnorm::kOut],
+        test::patternFill<DTypeX>(
+          &GetBlob(kBackwardIn, kBackOutGrad),
+          //&this->c_.blob_out_grad_[mxnet::op::batchnorm::kOut],
           [&val]{ return val += 1; });
       });
 
     // out-grad weights
-    if (mxnet::op::batchnorm::kGamma < this->c_.blob_out_grad_.size()) {
+    //if (mxnet::op::batchnorm::kGamma < this->c_.blob_out_grad_.size()) {
+    if (GetBackwardInBlob(kBackGamma)) {
       MSHADOW_TYPE_SWITCH(
-        this->c_.blob_out_grad_[mxnet::op::batchnorm::kGamma].type_flag_,
+        GetBackwardInBlob(kBackGamma)->type_flag_,
+        //this->c_.blob_out_grad_[mxnet::op::batchnorm::kGamma].type_flag_,
        DTypeX,
-        { test::try_fill(this->c_.blob_out_grad_, mxnet::op::batchnorm::kGamma, DTypeX(0.1)); });
+        { test::try_fill(GetBackwardInBlob(kBackGamma), DTypeX(0.1)); });
     }
 
     // out-grad biases
-    if (mxnet::op::batchnorm::kBeta < this->c_.blob_out_grad_.size()) {
+    if (GetBackwardInBlob(kBackBeta)) {
       MSHADOW_TYPE_SWITCH(
-        this->c_.blob_out_grad_[mxnet::op::batchnorm::kBeta].type_flag_,
+        GetBackwardInBlob(kBackBeta)->type_flag_,
+        //this->c_.blob_out_grad_[mxnet::op::batchnorm::kGamma].type_flag_,
        DTypeX,
-        { test::try_fill(this->c_.blob_out_grad_, mxnet::op::batchnorm::kBeta, DTypeX(0.1)); });
+        { test::try_fill(GetBackwardInBlob(kBackBeta), DTypeX(0.1)); });
     }
 
+    /*
     // in-grad
     MSHADOW_TYPE_SWITCH(
       this->c_.blob_in_grad_[mxnet::op::batchnorm::kData].type_flag_,
@@ -284,6 +305,7 @@ class BNOperatorExecutor : public test::op::CoreOpExecutor<DType, AccReal> {
   }
 
   const bool hasWeightAndBias_;  // This will cause forward pass validation to fail
+  op::BatchNormParam param_;
 };
 
 /*! \brief Validate batch norm test outputs */
@@ -661,7 +683,7 @@ static test::op::OpInfo<OperatorProp, OperatorExecutor> TestBatchNormOperatorFor
   test::op::OpInfo<OperatorProp, OperatorExecutor> info = test::op::createOpAndInfoF<
     OperatorProp, OperatorExecutor>(
       OperatorExecutor::ArgsWithOpName(kwargs, "BatchNorm", "_backward_BatchNorm"),
-      isGPU, inputShape);
+      isGPU, inputShape, kwargs);
 
   info.executor_->initForward(*info.prop_, &info.in_type_);
 
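Taken together, the changes in this commit let the backward-input blobs be initialized defensively. A minimal sketch of the resulting pattern (assumed call-site context inside the executor; kBackGamma and DTypeX are the names used in the diff above, everything else is illustrative):

    // Sketch only: GetBackwardInBlob() may return nullptr (e.g. when
    // output_mean_var is off and the gradient blob is absent), and
    // try_fill() is a no-op on nullptr, so optional gradients need no
    // special-casing at the call site.
    const TBlob *gamma_grad = GetBackwardInBlob(kBackGamma);
    if (gamma_grad) {
      MSHADOW_TYPE_SWITCH(gamma_grad->type_flag_, DTypeX, {
        test::try_fill(gamma_grad, DTypeX(0.1));
      });
    }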