MSRA weight filler #1946

Closed · wants to merge 1 commit
Added MSRAFiller, which implements an Xavier-like filler designed for use
with ReLUs instead of tanh. Based on the paper: He et al., "Delving Deep into
Rectifiers: Surpassing Human-Level Performance on ImageNet Classification,"
2015. Added a VarianceNorm option to FillerParameter which allows one to
normalize by fan_in, fan_out, or their average. Updated XavierFiller to use the
VarianceNorm option (default behavior unchanged). Added tests for MSRAFiller and
XavierFiller.
Nick Carlevaris-Bianco committed Feb 23, 2015
commit 1aac6b8598353c869c57651314313d4d8285c899
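
In rough terms, the initialization rule this commit adds can be sketched outside of Caffe as follows (a minimal illustration only; the helper name msra_std and the example shape are not part of the patch):

#include <cmath>

// Mirrors the variance_norm logic introduced in filler.hpp below.
enum VarianceNorm { FAN_IN, FAN_OUT, AVERAGE };

double msra_std(int num, int channels, int height, int width, VarianceNorm norm) {
  double fan_in = channels * height * width;  // connections into each output unit
  double fan_out = num * height * width;      // connections out of each input unit
  double n = fan_in;
  if (norm == AVERAGE) n = (fan_in + fan_out) / 2.0;
  else if (norm == FAN_OUT) n = fan_out;
  return std::sqrt(2.0 / n);  // He et al., 2015
}
// e.g. msra_std(96, 3, 11, 11, FAN_IN) = sqrt(2 / 363) ≈ 0.074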
68 changes: 59 additions & 9 deletions include/caffe/filler.hpp
@@ -127,17 +127,18 @@ class PositiveUnitballFiller : public Filler<Dtype> {
};

/**
* @brief Fills a Blob with values @f$ x \sim U(-a, +a) @f$ where @f$ a @f$
* is set inversely proportional to the number of incoming nodes.
* @brief Fills a Blob with values @f$ x \sim U(-a, +a) @f$ where @f$ a @f$ is
* set inversely proportional to number of incoming nodes, outgoing
* nodes, or their average.
*
* A Filler based on the paper [Bengio and Glorot 2010]: Understanding
* the difficulty of training deep feedforward neural networks, but does not
* use the fan_out value.
* the difficulty of training deep feedforward neural networks.
*
* It fills the incoming matrix by randomly sampling uniform data from
* [-scale, scale] where scale = sqrt(3 / fan_in) where fan_in is the number
* of input nodes. You should make sure the input blob has shape (num, a, b, c)
* where a * b * c = fan_in.
* It fills the incoming matrix by randomly sampling uniform data from [-scale,
* scale] where scale = sqrt(3 / n) where n is the fan_in, fan_out, or their
* average, depending on the variance_norm option. You should make sure the
* input blob has shape (num, a, b, c) where a * b * c = fan_in and num * b * c
* = fan_out. Note that this is currently not the case for inner product layers.
Member: #1970 is in so this filler is now right for InnerProduct layers too.
*
* TODO(dox): make notation in above comment consistent with rest & use LaTeX.
*/
@@ -149,14 +150,61 @@ class XavierFiller : public Filler<Dtype> {
virtual void Fill(Blob<Dtype>* blob) {
CHECK(blob->count());
int fan_in = blob->count() / blob->num();
Dtype scale = sqrt(Dtype(3) / fan_in);
int fan_out = blob->count() / blob->channels();
Dtype n = fan_in; // default to fan_in
if (this->filler_param_.variance_norm() ==
FillerParameter_VarianceNorm_AVERAGE) {
n = (fan_in + fan_out) / Dtype(2);
} else if (this->filler_param_.variance_norm() ==
FillerParameter_VarianceNorm_FAN_OUT) {
n = fan_out;
}
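// Uniform(-scale, scale) has variance scale^2 / 3, so scale = sqrt(3 / n)
// gives the weights variance 1 / n for whichever n was selected above.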
Dtype scale = sqrt(Dtype(3) / n);
caffe_rng_uniform<Dtype>(blob->count(), -scale, scale,
blob->mutable_cpu_data());
CHECK_EQ(this->filler_param_.sparse(), -1)
<< "Sparsity not supported by this Filler.";
}
};
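// Example: for a 3x3 convolution weight blob of shape (256, 128, 3, 3),
// fan_in = 128*3*3 = 1152 and fan_out = 256*3*3 = 2304, so FAN_IN gives
// scale = sqrt(3 / 1152) ≈ 0.051 and AVERAGE gives scale = sqrt(3 / 1728) ≈ 0.042.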

/**
* @brief Fills a Blob with values @f$ x \sim N(0, \sigma^2) @f$ where
* @f$ \sigma^2 @f$ is set inversely proportional to number of incoming
* nodes, outgoing nodes, or their average.
*
* A Filler based on the paper [He, Zhang, Ren and Sun 2015]: Specifically
* accounts for ReLU nonlinearities.
*
* It fills the incoming matrix by randomly sampling Gaussian data with std =
* sqrt(2 / n) where n is the fan_in, fan_out, or their average, depending on
* the variance_norm option. You should make sure the input blob has shape (num,
* a, b, c) where a * b * c = fan_in and num * b * c = fan_out. Note that this
* is currently not the case for inner product layers.
*/
template <typename Dtype>
class MSRAFiller : public Filler<Dtype> {
public:
explicit MSRAFiller(const FillerParameter& param)
: Filler<Dtype>(param) {}
virtual void Fill(Blob<Dtype>* blob) {
CHECK(blob->count());
int fan_in = blob->count() / blob->num();
int fan_out = blob->count() / blob->channels();
Dtype n = fan_in; // default to fan_in
if (this->filler_param_.variance_norm() ==
FillerParameter_VarianceNorm_AVERAGE) {
n = (fan_in + fan_out) / Dtype(2);
} else if (this->filler_param_.variance_norm() ==
FillerParameter_VarianceNorm_FAN_OUT) {
n = fan_out;
}
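// A ReLU zeroes roughly half of its inputs, so drawing weights from a
// Gaussian with variance 2 / n (rather than 1 / n) keeps the layer output
// variance approximately constant from layer to layer (He et al., 2015).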
Dtype std = sqrt(Dtype(2) / n);
caffe_rng_gaussian<Dtype>(blob->count(), Dtype(0), std,
blob->mutable_cpu_data());
CHECK_EQ(this->filler_param_.sparse(), -1)
<< "Sparsity not supported by this Filler.";
}
};
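// Example: for a weight blob of shape (256, 128, 3, 3), FAN_IN gives
// std = sqrt(2 / (128*3*3)) ≈ 0.042; the corresponding Xavier uniform scale
// would be sqrt(3 / 1152) ≈ 0.051.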

/**
* @brief Get a specific filler from the specification given in FillerParameter.
@@ -177,6 +225,8 @@ Filler<Dtype>* GetFiller(const FillerParameter& param) {
return new UniformFiller<Dtype>(param);
} else if (type == "xavier") {
return new XavierFiller<Dtype>(param);
} else if (type == "msra") {
return new MSRAFiller<Dtype>(param);
} else {
CHECK(false) << "Unknown filler name: " << param.type();
}
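Like the existing fillers, the new one is selected by name through GetFiller. A minimal usage sketch, assuming the Caffe headers above are included; the blob shape is only an example:

FillerParameter param;
param.set_type("msra");
param.set_variance_norm(FillerParameter_VarianceNorm_FAN_IN);
shared_ptr<Filler<float> > filler(GetFiller<float>(param));
Blob<float> weights(96, 3, 11, 11);  // example convolution weight shape
filler->Fill(&weights);              // fills with N(0, 2 / (3*11*11)) samples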
8 changes: 8 additions & 0 deletions src/caffe/proto/caffe.proto
@@ -41,6 +41,14 @@ message FillerParameter {
// The expected number of non-zero output weights for a given input in
// Gaussian filler -- the default -1 means don't perform sparsification.
optional int32 sparse = 7 [default = -1];
// Normalize the filler variance by fan_in, fan_out, or their average.
// Applies to 'xavier' and 'msra' fillers.
enum VarianceNorm {
FAN_IN = 0;
FAN_OUT = 1;
AVERAGE = 2;
}
optional VarianceNorm variance_norm = 8 [default = FAN_IN];
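// For example, a layer definition could request the new filler with a
// prototxt entry along the lines of:
//   weight_filler { type: "msra" variance_norm: AVERAGE }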
}

message NetParameter {
98 changes: 98 additions & 0 deletions src/caffe/test/test_filler.cpp
@@ -142,4 +142,102 @@ TYPED_TEST(GaussianFillerTest, TestFill) {
EXPECT_LE(var, target_var * 5.);
}

template <typename Dtype>
class XavierFillerTest : public ::testing::Test {
protected:
XavierFillerTest()
: blob_(new Blob<Dtype>(1000, 2, 4, 5)),
filler_param_() {
}
virtual void test_params(FillerParameter_VarianceNorm variance_norm,
Dtype n) {
this->filler_param_.set_variance_norm(variance_norm);
this->filler_.reset(new XavierFiller<Dtype>(this->filler_param_));
this->filler_->Fill(blob_);
EXPECT_TRUE(this->blob_);
const int count = this->blob_->count();
const Dtype* data = this->blob_->cpu_data();
Dtype mean = 0.;
Dtype ex2 = 0.;
for (int i = 0; i < count; ++i) {
mean += data[i];
ex2 += data[i] * data[i];
}
mean /= count;
ex2 /= count;
Dtype std = sqrt(ex2 - mean*mean);
// Uniform(-a, a) with a = sqrt(3 / n) has variance 1 / n, so that is the target.
Dtype target_std = sqrt(1.0 / n);
EXPECT_NEAR(mean, 0.0, 0.1);
EXPECT_NEAR(std, target_std, 0.1);
}
virtual ~XavierFillerTest() { delete blob_; }
Blob<Dtype>* const blob_;
FillerParameter filler_param_;
shared_ptr<XavierFiller<Dtype> > filler_;
};

TYPED_TEST_CASE(XavierFillerTest, TestDtypes);

TYPED_TEST(XavierFillerTest, TestFillFanIn) {
TypeParam n = 2*4*5;
this->test_params(FillerParameter_VarianceNorm_FAN_IN, n);
}
TYPED_TEST(XavierFillerTest, TestFillFanOut) {
TypeParam n = 1000*4*5;
this->test_params(FillerParameter_VarianceNorm_FAN_OUT, n);
}
TYPED_TEST(XavierFillerTest, TestFillAverage) {
TypeParam n = (2*4*5 + 1000*4*5) / 2.0;
this->test_params(FillerParameter_VarianceNorm_AVERAGE, n);
}
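// For the (1000, 2, 4, 5) test blob used above (and in MSRAFillerTest below):
// fan_in = 2*4*5 = 40, fan_out = 1000*4*5 = 20000, and their average is 10020.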

template <typename Dtype>
class MSRAFillerTest : public ::testing::Test {
protected:
MSRAFillerTest()
: blob_(new Blob<Dtype>(1000, 2, 4, 5)),
filler_param_() {
}
virtual void test_params(FillerParameter_VarianceNorm variance_norm,
Dtype n) {
this->filler_param_.set_variance_norm(variance_norm);
this->filler_.reset(new MSRAFiller<Dtype>(this->filler_param_));
this->filler_->Fill(blob_);
EXPECT_TRUE(this->blob_);
const int count = this->blob_->count();
const Dtype* data = this->blob_->cpu_data();
Dtype mean = 0.;
Dtype ex2 = 0.;
for (int i = 0; i < count; ++i) {
mean += data[i];
ex2 += data[i] * data[i];
}
mean /= count;
ex2 /= count;
Dtype std = sqrt(ex2 - mean*mean);
Dtype target_std = sqrt(2.0 / n);
EXPECT_NEAR(mean, 0.0, 0.1);
EXPECT_NEAR(std, target_std, 0.1);
}
virtual ~MSRAFillerTest() { delete blob_; }
Blob<Dtype>* const blob_;
FillerParameter filler_param_;
shared_ptr<MSRAFiller<Dtype> > filler_;
};

TYPED_TEST_CASE(MSRAFillerTest, TestDtypes);

TYPED_TEST(MSRAFillerTest, TestFillFanIn) {
TypeParam n = 2*4*5;
this->test_params(FillerParameter_VarianceNorm_FAN_IN, n);
}
TYPED_TEST(MSRAFillerTest, TestFillFanOut) {
TypeParam n = 1000*4*5;
this->test_params(FillerParameter_VarianceNorm_FAN_OUT, n);
}
TYPED_TEST(MSRAFillerTest, TestFillAverage) {
TypeParam n = (2*4*5 + 1000*4*5) / 2.0;
this->test_params(FillerParameter_VarianceNorm_AVERAGE, n);
}

} // namespace caffe