diff --git a/python/caffe/pycaffe.cpp b/python/caffe/pycaffe.cpp index d09ccf55b7f..2bfae9e2739 100644 --- a/python/caffe/pycaffe.cpp +++ b/python/caffe/pycaffe.cpp @@ -6,6 +6,7 @@ #define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION #include +#include #include #include "caffe/caffe.hpp" @@ -20,6 +21,78 @@ using boost::python::extract; using boost::python::len; using boost::python::list; using boost::python::object; +using boost::python::handle; +using boost::python::vector_indexing_suite; + + +// wrap shared_ptr > in a class that we construct in C++ and pass +// to Python +class CaffeBlob { + public: + + CaffeBlob(const shared_ptr > &blob) + : blob_(blob) {} + + CaffeBlob() + {} + + int num() const { return blob_->num(); } + int channels() const { return blob_->channels(); } + int height() const { return blob_->height(); } + int width() const { return blob_->width(); } + int count() const { return blob_->count(); } + + bool operator == (const CaffeBlob &other) + { + return this->blob_ == other.blob_; + } + + protected: + shared_ptr > blob_; +}; + + +// we need another wrapper (used as boost::python's HeldType) that receives a +// self PyObject * which we can use as ndarray.base, so that data/diff memory +// is not freed while still being used in Python +class CaffeBlobWrap : public CaffeBlob { + public: + CaffeBlobWrap(PyObject *p, shared_ptr > &blob) + : CaffeBlob(blob), self_(p) {} + + CaffeBlobWrap(PyObject *p, const CaffeBlob &blob) + : CaffeBlob(blob), self_(p) {} + + object get_data() + { + npy_intp dims[] = {num(), channels(), height(), width()}; + + PyObject *obj = PyArray_SimpleNewFromData(4, dims, NPY_FLOAT32, + blob_->mutable_cpu_data()); + PyArray_SetBaseObject(reinterpret_cast(obj), self_); + Py_INCREF(self_); + handle<> h(obj); + + return object(h); + } + + object get_diff() + { + npy_intp dims[] = {num(), channels(), height(), width()}; + + PyObject *obj = PyArray_SimpleNewFromData(4, dims, NPY_FLOAT32, + blob_->mutable_cpu_diff()); + PyArray_SetBaseObject(reinterpret_cast(obj), self_); + Py_INCREF(self_); + handle<> h(obj); + + return object(h); + } + + private: + PyObject *self_; +}; + // A simple wrapper over CaffeNet that runs the forward process. @@ -143,14 +216,24 @@ struct CaffeNet void set_phase_test() { Caffe::set_phase(Caffe::TEST); } void set_device(int device_id) { Caffe::SetDevice(device_id); } + vector blobs() { + return vector(net_->blobs().begin(), net_->blobs().end()); + } + + vector params() { + return vector(net_->params().begin(), net_->params().end()); + } + // The pointer to the internal caffe::Net instant. shared_ptr > net_; }; + // The boost python module definition. BOOST_PYTHON_MODULE(pycaffe) { + boost::python::class_( "CaffeNet", boost::python::init()) .def("Forward", &CaffeNet::Forward) @@ -160,5 +243,24 @@ BOOST_PYTHON_MODULE(pycaffe) .def("set_phase_train", &CaffeNet::set_phase_train) .def("set_phase_test", &CaffeNet::set_phase_test) .def("set_device", &CaffeNet::set_device) + .def("blobs", &CaffeNet::blobs) + .def("params", &CaffeNet::params) + ; + + boost::python::class_( + "CaffeBlob", boost::python::no_init) + .add_property("num", &CaffeBlob::num) + .add_property("channels", &CaffeBlob::channels) + .add_property("height", &CaffeBlob::height) + .add_property("width", &CaffeBlob::width) + .add_property("count", &CaffeBlob::count) + .add_property("data", &CaffeBlobWrap::get_data) + .add_property("diff", &CaffeBlobWrap::get_diff) ; + + boost::python::class_ >("BlobVec") + .def(vector_indexing_suite, true>()); + + import_array(); + }