Skip to content

Commit aa8e911

Browse files
Taco-Wtqchen
authored andcommitted
[caffe-plugin] Fix Tshape/TBlob. Remove setblob. Update config.mk. (apache#2941)
1 parent 4ed3f3b commit aa8e911

File tree

4 files changed

+30
-23
lines changed

4 files changed

+30
-23
lines changed

make/config.mk

+3-2
Original file line numberDiff line numberDiff line change
@@ -111,8 +111,9 @@ EXTRA_OPERATORS =
111111
# plugins
112112
#----------------------------
113113

114-
# whether to use caffe integration. This requires including caffe submodule.
115-
# CAFFE_PATH = caffe-lite
114+
# whether to use caffe integration. This requires installing caffe.
115+
# You also need to add CAFFE_PATH/build/lib to your LD_LIBRARY_PATH
116+
# CAFFE_PATH = $(HOME)/caffe
116117
# MXNET_PLUGINS += plugin/caffe/caffe.mk
117118

118119
# whether to use torch integration. This requires installing torch.

plugin/caffe/caffe_blob.cc

+10-12
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ namespace caffe {
1212
template<>
1313
void SetDataGradToBlob<mshadow::cpu, float>(caffeMemoryTypes memType,
1414
std::vector<::caffe::Blob<float>*>::iterator blob,
15-
std::vector<mshadow::TBlob>::const_iterator itr) {
15+
std::vector<TBlob>::const_iterator itr) {
1616
float *data_ptr = reinterpret_cast<float*>((*itr).dptr_);
1717
if (memType == Data)
1818
(*blob)->set_cpu_data(data_ptr);
@@ -23,7 +23,7 @@ void SetDataGradToBlob<mshadow::cpu, float>(caffeMemoryTypes memType,
2323
template<>
2424
void SetDataGradToBlob<mshadow::cpu, double>(caffeMemoryTypes memType,
2525
std::vector<::caffe::Blob<double>*>::iterator blob,
26-
std::vector<mshadow::TBlob>::const_iterator itr) {
26+
std::vector<TBlob>::const_iterator itr) {
2727
double *data_ptr = reinterpret_cast<double*>((*itr).dptr_);
2828
if (memType == Data)
2929
(*blob)->set_cpu_data(data_ptr);
@@ -34,7 +34,7 @@ void SetDataGradToBlob<mshadow::cpu, double>(caffeMemoryTypes memType,
3434
template<>
3535
void SetDataGradToBlob<mshadow::gpu, float>(caffeMemoryTypes memType,
3636
std::vector<::caffe::Blob<float>*>::iterator blob,
37-
std::vector<mshadow::TBlob>::const_iterator itr) {
37+
std::vector<TBlob>::const_iterator itr) {
3838
float *data_ptr = reinterpret_cast<float*>((*itr).dptr_);
3939
if (memType == Data)
4040
(*blob)->set_gpu_data(data_ptr);
@@ -45,27 +45,25 @@ void SetDataGradToBlob<mshadow::gpu, float>(caffeMemoryTypes memType,
4545
template<>
4646
void SetDataGradToBlob<mshadow::gpu, double>(caffeMemoryTypes memType,
4747
std::vector<::caffe::Blob<double>*>::iterator blob,
48-
std::vector<mshadow::TBlob>::const_iterator itr) {
48+
std::vector<TBlob>::const_iterator itr) {
4949
double *data_ptr = reinterpret_cast<double*>((*itr).dptr_);
5050
if (memType == Data)
5151
(*blob)->set_gpu_data(data_ptr);
5252
else
5353
(*blob)->set_gpu_diff(data_ptr);
5454
}
5555

56-
mshadow::TShape Vector2TShape(const std::vector<int> &vec_int) {
57-
mshadow::TShape res;
58-
std::vector<mshadow::index_t> vec_indx;
56+
TShape Vector2TShape(const std::vector<int> &vec_int) {
57+
std::vector<mshadow::index_t> vec;
5958
for (int i = 0; i < vec_int.size(); ++i)
60-
vec_indx.push_back(vec_int[i]);
59+
vec.push_back(vec_int[i]);
6160
// 0-dim represents scalar in caffe
6261
if (vec_int.size() == 0)
63-
vec_indx.push_back(1);
64-
res = vec_indx;
65-
return res;
62+
vec.push_back(1);
63+
return {vec.begin(), vec.end()};
6664
}
6765

68-
std::vector<int> TShape2Vector(const mshadow::TShape &tshape) {
66+
std::vector<int> TShape2Vector(const TShape &tshape) {
6967
std::vector<int> s;
7068
for (int i =0 ; i < tshape.ndim(); ++i)
7169
s.push_back(tshape[i]);

plugin/caffe/caffe_blob.h

+16-8
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@
77
#ifndef PLUGIN_CAFFE_CAFFE_BLOB_H_
88
#define PLUGIN_CAFFE_CAFFE_BLOB_H_
99

10-
#include<mshadow/tensor.h>
11-
#include<mshadow/tensor_blob.h>
12-
#include<vector>
13-
#include<caffe/blob.hpp>
10+
#include <mxnet/tensor_blob.h>
11+
#include <vector>
12+
#include <caffe/blob.hpp>
13+
#include <caffe/layer.hpp>
1414

1515
namespace mxnet {
1616
namespace op {
@@ -20,14 +20,14 @@ namespace caffe {
2020
// Declare Memory Type for Caffe blob
2121
enum caffeMemoryTypes {Data, Grad, Non};
2222

23-
mshadow::TShape Vector2TShape(const std::vector<int> &vec_int);
24-
std::vector<int> TShape2Vector(const mshadow::TShape &tshape);
23+
TShape Vector2TShape(const std::vector<int> &vec_int);
24+
std::vector<int> TShape2Vector(const TShape &tshape);
2525

2626
// implementation of tensor to blob, called by TensorToBlob
2727
template<typename Device, typename Dtype>
2828
void SetDataGradToBlob(caffeMemoryTypes memType,
2929
typename std::vector< ::caffe::Blob<Dtype>*>::iterator blob,
30-
typename std::vector<mshadow::TBlob>::const_iterator itr);
30+
typename std::vector<TBlob>::const_iterator itr);
3131

3232
/**
3333
* \brief The interface to convert mxnet's tensor to caffe's blob
@@ -36,14 +36,22 @@ void SetDataGradToBlob(caffeMemoryTypes memType,
3636
template<typename Device, typename Dtype>
3737
void TBlob2CaffeBlob(caffeMemoryTypes memType,
3838
typename std::vector< ::caffe::Blob<Dtype>*>::iterator blob,
39-
typename std::vector<mshadow::TBlob>::const_iterator tblob,
39+
typename std::vector<TBlob>::const_iterator tblob,
4040
int n = 1) {
4141
for (int i = 0; i < n; ++i, ++blob, ++tblob) {
4242
(*blob)->Reshape(TShape2Vector((*tblob).shape_));
4343
SetDataGradToBlob<Device, Dtype>(memType, blob, tblob);
4444
}
4545
}
4646

47+
template<typename Dtype>
48+
void SetOpBlobs(::caffe::Layer<Dtype> *caffeOp,
49+
const std::vector< ::caffe::Blob<Dtype>*>& weights) {
50+
CHECK_EQ(caffeOp->blobs().size(), weights.size());
51+
for (int i = 0; i < weights.size(); ++i)
52+
caffeOp->blobs()[i].reset(weights[i]);
53+
}
54+
4755
} // namespace caffe
4856
} // namespace op
4957
} // namespace mxnet

plugin/caffe/caffe_op-inl.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ class CaffeOp : public Operator {
110110
wei_.begin(),
111111
in_data.begin() + param_.num_data,
112112
param_.num_weight);
113-
caffeOp_->SetBlobs(wei_);
113+
caffe::SetOpBlobs(caffeOp_, wei_);
114114
}
115115

116116
caffeOp_->Forward(bot_, top_);

0 commit comments

Comments
 (0)