@@ -91,6 +91,51 @@ struct CaffeNet
91
91
}
92
92
}
93
93
94
+ void Backward (list top_diff, list bottom_diff) {
95
+ vector<Blob<float >*>& output_blobs = net_->output_blobs ();
96
+ vector<Blob<float >*>& input_blobs = net_->input_blobs ();
97
+ CHECK_EQ (len (bottom_diff), input_blobs.size ());
98
+ CHECK_EQ (len (top_diff), output_blobs.size ());
99
+ // First, copy the output diff
100
+ for (int i = 0 ; i < output_blobs.size (); ++i) {
101
+ object elem = top_diff[i];
102
+ PyArrayObject* arr = reinterpret_cast <PyArrayObject*>(elem.ptr ());
103
+ check_array_against_blob (arr, output_blobs[i]);
104
+ switch (Caffe::mode ()) {
105
+ case Caffe::CPU:
106
+ memcpy (output_blobs[i]->mutable_cpu_diff (), PyArray_DATA (arr),
107
+ sizeof (float ) * output_blobs[i]->count ());
108
+ break ;
109
+ case Caffe::GPU:
110
+ cudaMemcpy (output_blobs[i]->mutable_gpu_diff (), PyArray_DATA (arr),
111
+ sizeof (float ) * output_blobs[i]->count (), cudaMemcpyHostToDevice);
112
+ break ;
113
+ default :
114
+ LOG (FATAL) << " Unknown Caffe mode." ;
115
+ } // switch (Caffe::mode())
116
+ }
117
+ // LOG(INFO) << "Start";
118
+ net_->Backward ();
119
+ // LOG(INFO) << "End";
120
+ for (int i = 0 ; i < input_blobs.size (); ++i) {
121
+ object elem = bottom_diff[i];
122
+ PyArrayObject* arr = reinterpret_cast <PyArrayObject*>(elem.ptr ());
123
+ check_array_against_blob (arr, input_blobs[i]);
124
+ switch (Caffe::mode ()) {
125
+ case Caffe::CPU:
126
+ memcpy (PyArray_DATA (arr), input_blobs[i]->cpu_diff (),
127
+ sizeof (float ) * input_blobs[i]->count ());
128
+ break ;
129
+ case Caffe::GPU:
130
+ cudaMemcpy (PyArray_DATA (arr), input_blobs[i]->gpu_diff (),
131
+ sizeof (float ) * input_blobs[i]->count (), cudaMemcpyDeviceToHost);
132
+ break ;
133
+ default :
134
+ LOG (FATAL) << " Unknown Caffe mode." ;
135
+ } // switch (Caffe::mode())
136
+ }
137
+ }
138
+
94
139
// The caffe::Caffe utility functions.
95
140
void set_mode_cpu () { Caffe::set_mode (Caffe::CPU); }
96
141
void set_mode_gpu () { Caffe::set_mode (Caffe::GPU); }
@@ -108,11 +153,12 @@ BOOST_PYTHON_MODULE(pycaffe)
108
153
{
109
154
boost::python::class_<CaffeNet>(
110
155
" CaffeNet" , boost::python::init<string, string>())
111
- .def (" Forward" , &CaffeNet::Forward)
112
- .def (" set_mode_cpu" , &CaffeNet::set_mode_cpu)
113
- .def (" set_mode_gpu" , &CaffeNet::set_mode_gpu)
156
+ .def (" Forward" , &CaffeNet::Forward)
157
+ .def (" Backward" , &CaffeNet::Backward)
158
+ .def (" set_mode_cpu" , &CaffeNet::set_mode_cpu)
159
+ .def (" set_mode_gpu" , &CaffeNet::set_mode_gpu)
114
160
.def (" set_phase_train" , &CaffeNet::set_phase_train)
115
- .def (" set_phase_test" , &CaffeNet::set_phase_test)
116
- .def (" set_device" , &CaffeNet::set_device)
161
+ .def (" set_phase_test" , &CaffeNet::set_phase_test)
162
+ .def (" set_device" , &CaffeNet::set_device)
117
163
;
118
164
}
0 commit comments