
Commit 4161734

Merge remote-tracking branch 'origin/master'
2 parents: f785ff2 + 2b981cf

4 files changed: +60 -10 lines


Makefile

+2 -2
@@ -49,7 +49,7 @@ PROTO_GEN_HEADER := ${PROTO_SRCS:.proto=.pb.h}
 PROTO_GEN_CC := ${PROTO_SRCS:.proto=.pb.cc}
 PROTO_GEN_PY := ${PROTO_SRCS:.proto=_pb2.py}
 # The objects corresponding to the source files
-# These objects will be linked into the final shared library, so we
+# These objects will be linked into the final shared library, so we
 # exclude the test and example objects.
 CXX_OBJS := ${CXX_SRCS:.cpp=.o}
 CU_OBJS := ${CU_SRCS:.cu=.cuo}

@@ -84,7 +84,7 @@ CXXFLAGS += -pthread -fPIC -O2 $(COMMON_FLAGS)
 NVCCFLAGS := -Xcompiler -fPIC -O2 $(COMMON_FLAGS)
 LDFLAGS += $(foreach librarydir,$(LIBRARY_DIRS),-L$(librarydir)) \
 		$(foreach library,$(LIBRARIES),-l$(library)) \
-		-Wl,-rpath,../libs/
+		-Wl,-rpath,../lib/
 PYTHON_LDFLAGS := $(LDFLAGS) $(foreach library,$(PYTHON_LIBRARIES),-l$(library))

python/caffe/pycaffe.cpp

+51 -5
@@ -91,6 +91,51 @@ struct CaffeNet
     }
   }
 
+  void Backward(list top_diff, list bottom_diff) {
+    vector<Blob<float>*>& output_blobs = net_->output_blobs();
+    vector<Blob<float>*>& input_blobs = net_->input_blobs();
+    CHECK_EQ(len(bottom_diff), input_blobs.size());
+    CHECK_EQ(len(top_diff), output_blobs.size());
+    // First, copy the output diff
+    for (int i = 0; i < output_blobs.size(); ++i) {
+      object elem = top_diff[i];
+      PyArrayObject* arr = reinterpret_cast<PyArrayObject*>(elem.ptr());
+      check_array_against_blob(arr, output_blobs[i]);
+      switch (Caffe::mode()) {
+      case Caffe::CPU:
+        memcpy(output_blobs[i]->mutable_cpu_diff(), PyArray_DATA(arr),
+            sizeof(float) * output_blobs[i]->count());
+        break;
+      case Caffe::GPU:
+        cudaMemcpy(output_blobs[i]->mutable_gpu_diff(), PyArray_DATA(arr),
+            sizeof(float) * output_blobs[i]->count(), cudaMemcpyHostToDevice);
+        break;
+      default:
+        LOG(FATAL) << "Unknown Caffe mode.";
+      }  // switch (Caffe::mode())
+    }
+    //LOG(INFO) << "Start";
+    net_->Backward();
+    //LOG(INFO) << "End";
+    for (int i = 0; i < input_blobs.size(); ++i) {
+      object elem = bottom_diff[i];
+      PyArrayObject* arr = reinterpret_cast<PyArrayObject*>(elem.ptr());
+      check_array_against_blob(arr, input_blobs[i]);
+      switch (Caffe::mode()) {
+      case Caffe::CPU:
+        memcpy(PyArray_DATA(arr), input_blobs[i]->cpu_diff(),
+            sizeof(float) * input_blobs[i]->count());
+        break;
+      case Caffe::GPU:
+        cudaMemcpy(PyArray_DATA(arr), input_blobs[i]->gpu_diff(),
+            sizeof(float) * input_blobs[i]->count(), cudaMemcpyDeviceToHost);
+        break;
+      default:
+        LOG(FATAL) << "Unknown Caffe mode.";
+      }  // switch (Caffe::mode())
+    }
+  }
+
   // The caffe::Caffe utility functions.
   void set_mode_cpu() { Caffe::set_mode(Caffe::CPU); }
   void set_mode_gpu() { Caffe::set_mode(Caffe::GPU); }

@@ -108,11 +108,12 @@ BOOST_PYTHON_MODULE(pycaffe)
 {
   boost::python::class_<CaffeNet>(
       "CaffeNet", boost::python::init<string, string>())
-      .def("Forward", &CaffeNet::Forward)
-      .def("set_mode_cpu", &CaffeNet::set_mode_cpu)
-      .def("set_mode_gpu", &CaffeNet::set_mode_gpu)
+      .def("Forward", &CaffeNet::Forward)
+      .def("Backward", &CaffeNet::Backward)
+      .def("set_mode_cpu", &CaffeNet::set_mode_cpu)
+      .def("set_mode_gpu", &CaffeNet::set_mode_gpu)
       .def("set_phase_train", &CaffeNet::set_phase_train)
-      .def("set_phase_test", &CaffeNet::set_phase_test)
-      .def("set_device", &CaffeNet::set_device)
+      .def("set_phase_test", &CaffeNet::set_phase_test)
+      .def("set_device", &CaffeNet::set_device)
       ;
 }
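The new Backward binding mirrors the Forward convention: it takes a Python list of numpy arrays holding the gradient with respect to each output blob and writes the gradient with respect to each input blob into a second list of preallocated arrays, each of which must match its blob as checked by check_array_against_blob. The following is a minimal sketch only: the file names and blob shapes are hypothetical, Forward is assumed to use the same list-of-float32-arrays convention, and the net is assumed to be defined with force_backward so that input-blob gradients are actually computed (see the net.cpp and caffe.proto changes below).

# Sketch only: paths and shapes below are hypothetical.
import numpy as np
from pycaffe import CaffeNet

net = CaffeNet('net.prototxt', 'net_pretrained.binaryproto')  # hypothetical files
net.set_mode_cpu()
net.set_phase_test()

# Forward: preallocated float32 arrays shaped (num, channels, height, width).
bottom = [np.zeros((1, 3, 227, 227), dtype=np.float32)]
top = [np.zeros((1, 1000, 1, 1), dtype=np.float32)]
net.Forward(bottom, top)

# Backward: pass the gradient w.r.t. each output blob in, receive the gradient
# w.r.t. each input blob back in bottom_diff.
top_diff = [np.ones_like(t) for t in top]
bottom_diff = [np.zeros_like(b) for b in bottom]
net.Backward(top_diff, bottom_diff)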

src/caffe/net.cpp

+3 -3
@@ -47,7 +47,7 @@ void Net<Dtype>::Init(const NetParameter& param) {
         param.input_dim(i * 4 + 3)));
     blobs_.push_back(blob_pointer);
     blob_names_.push_back(blob_name);
-    blob_need_backward_.push_back(false);
+    blob_need_backward_.push_back(param.force_backward());
     net_input_blob_indices_.push_back(i);
     net_input_blobs_.push_back(blob_pointer.get());
     blob_name_to_idx[blob_name] = i;

@@ -64,7 +64,7 @@ void Net<Dtype>::Init(const NetParameter& param) {
     layers_.push_back(shared_ptr<Layer<Dtype> >(GetLayer<Dtype>(layer_param)));
     layer_names_.push_back(layer_param.name());
     LOG(INFO) << "Creating Layer " << layer_param.name();
-    bool need_backward = false;
+    bool need_backward = param.force_backward();
     // Figure out this layer's input and output
     for (int j = 0; j < layer_connection.bottom_size(); ++j) {
       const string& blob_name = layer_connection.bottom(j);

@@ -102,7 +102,7 @@ void Net<Dtype>::Init(const NetParameter& param) {
       shared_ptr<Blob<Dtype> > blob_pointer(new Blob<Dtype>());
      blobs_.push_back(blob_pointer);
       blob_names_.push_back(blob_name);
-      blob_need_backward_.push_back(false);
+      blob_need_backward_.push_back(param.force_backward());
       blob_name_to_idx[blob_name] = blob_names_.size() - 1;
       available_blobs.insert(blob_name);
       top_vecs_[i].push_back(blobs_[blob_names_.size() - 1].get());

src/caffe/proto/caffe.proto

+4
@@ -108,6 +108,10 @@ message NetParameter {
   // values specifying the num, channels, height and width of the input blob.
   // Thus, there should be a total of (4 * #input) numbers.
   repeated int32 input_dim = 4;
+  // Whether the network will force every layer to carry out backward operation.
+  // If set False, then whether to carry out backward is determined
+  // automatically according to the net structure and learning rates.
+  optional bool force_backward = 5 [ default = false ];
 }
 
 message SolverParameter {
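Because the Makefile already generates Python modules for the protos (PROTO_GEN_PY, the *_pb2.py rule), the new field can also be toggled programmatically rather than by editing the prototxt by hand. A sketch, assuming the generated module is importable as caffe_pb2 and using an illustrative file path:

# Sketch: the caffe_pb2 import name follows from the _pb2.py Makefile rule;
# the prototxt paths are hypothetical.
from google.protobuf import text_format
import caffe_pb2

param = caffe_pb2.NetParameter()
with open('net.prototxt') as f:
    text_format.Merge(f.read(), param)

# With force_backward set, Net<Dtype>::Init seeds blob_need_backward_ and
# need_backward with true (see src/caffe/net.cpp above), so every layer runs
# Backward even when the net structure alone would not require it.
param.force_backward = True

with open('net_force_backward.prototxt', 'w') as f:
    f.write(text_format.MessageToString(param))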
