Skip to content

Commit

Permalink
test net reshaping
Browse files Browse the repository at this point in the history
  • Loading branch information
longjon authored and shelhamer committed Sep 18, 2014
1 parent 4f1b668 commit db5bb15
Showing 1 changed file with 121 additions and 0 deletions.
121 changes: 121 additions & 0 deletions src/caffe/test/test_net.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "gtest/gtest.h"

#include "caffe/common.hpp"
#include "caffe/filler.hpp"
#include "caffe/net.hpp"
#include "caffe/util/math_functions.hpp"

Expand Down Expand Up @@ -533,6 +534,68 @@ class NetTest : public MultiDeviceTest<TypeParam> {
InitNetFromProtoString(proto);
}

virtual void InitReshapableNet() {
const string& proto =
"name: 'ReshapableNetwork' "
"input: 'data' "
"input_dim: 1 "
"input_dim: 3 "
"input_dim: 100 "
"input_dim: 100 "
"layers: { "
" name: 'conv1' "
" type: CONVOLUTION "
" bottom: 'data' "
" top: 'conv1' "
" convolution_param { "
" num_output: 5 "
" kernel_size: 3 "
" stride: 2 "
" weight_filler { "
" type: 'gaussian' "
" std: 0.01 "
" } "
" bias_filler { "
" type: 'constant' "
" value: 0.2 "
" } "
" } "
"} "
"layers: { "
" name: 'relu1' "
" type: RELU "
" bottom: 'conv1' "
" top: 'conv1' "
"} "
"layers: { "
" name: 'pool1' "
" type: POOLING "
" bottom: 'conv1' "
" top: 'pool1' "
" pooling_param { "
" pool: MAX "
" kernel_size: 2 "
" stride: 2 "
" } "
"} "
"layers: { "
" name: 'norm1' "
" type: LRN "
" bottom: 'pool1' "
" top: 'norm1' "
" lrn_param { "
" local_size: 3 "
" } "
"} "
"layers: { "
" name: 'softmax' "
" type: SOFTMAX "
" bottom: 'norm1' "
" top: 'softmax' "
"} ";
InitNetFromProtoString(proto);
}

int seed_;
shared_ptr<Net<Dtype> > net_;
};
Expand Down Expand Up @@ -2028,4 +2091,62 @@ TEST_F(FilterNetTest, TestFilterInOutByExcludeMultiRule) {
this->RunFilterNetTest(input_proto_test, output_proto_test);
}

TYPED_TEST(NetTest, TestReshape) {
typedef typename TypeParam::Dtype Dtype;
// We set up bottom blobs of two different sizes, switch between
// them, and check that forward and backward both run and the results
// are the same.
Caffe::set_random_seed(this->seed_);
Caffe::set_mode(Caffe::CPU);
FillerParameter filler_param;
filler_param.set_std(1);
GaussianFiller<Dtype> filler(filler_param);
Blob<Dtype> blob1(4, 3, 9, 11);
Blob<Dtype> blob2(2, 3, 12, 10);
filler.Fill(&blob1);
filler.Fill(&blob2);

this->InitReshapableNet();
Blob<Dtype>* input_blob = this->net_->input_blobs()[0];
Blob<Dtype>* output_blob = this->net_->output_blobs()[0];
input_blob->Reshape(blob1.num(), blob1.channels(), blob1.height(),
blob1.width());
caffe_copy(blob1.count(), blob1.cpu_data(), input_blob->mutable_cpu_data());
this->net_->ForwardPrefilled();
// call backward just to make sure it runs
this->net_->Backward();
Blob<Dtype> output1(output_blob->num(), output_blob->channels(),
output_blob->height(), output_blob->width());
caffe_copy(output1.count(), output_blob->cpu_data(),
output1.mutable_cpu_data());

input_blob->Reshape(blob2.num(), blob2.channels(), blob2.height(),
blob2.width());
caffe_copy(blob2.count(), blob2.cpu_data(), input_blob->mutable_cpu_data());
this->net_->ForwardPrefilled();
this->net_->Backward();
Blob<Dtype> output2(output_blob->num(), output_blob->channels(),
output_blob->height(), output_blob->width());
caffe_copy(output2.count(), output_blob->cpu_data(),
output2.mutable_cpu_data());

input_blob->Reshape(blob1.num(), blob1.channels(), blob1.height(),
blob1.width());
caffe_copy(blob1.count(), blob1.cpu_data(), input_blob->mutable_cpu_data());
this->net_->ForwardPrefilled();
this->net_->Backward();
for (int i = 0; i < output1.count(); ++i) {
CHECK_EQ(*(output1.cpu_data() + i), *(output_blob->cpu_data() + i));
}

input_blob->Reshape(blob2.num(), blob2.channels(), blob2.height(),
blob2.width());
caffe_copy(blob2.count(), blob2.cpu_data(), input_blob->mutable_cpu_data());
this->net_->ForwardPrefilled();
this->net_->Backward();
for (int i = 0; i < output2.count(); ++i) {
CHECK_EQ(*(output2.cpu_data() + i), *(output_blob->cpu_data() + i));
}
}

} // namespace caffe

0 comments on commit db5bb15

Please sign in to comment.