-
Notifications
You must be signed in to change notification settings - Fork 5.7k
Integrate caffe #1226
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Integrate caffe #1226
Changes from all commits
d9a087d
931a87a
bd9e974
d121248
bb4bd02
f88e689
3183845
fbd40a1
557fcda
904474d
00c8046
959c099
601bafa
01321c9
a102b82
20259ae
6b643a3
f5f1ba8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,7 +9,6 @@ build/ | |
.pydevproject | ||
Makefile | ||
.test_env/ | ||
third_party/ | ||
|
||
*~ | ||
bazel-* | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -278,6 +278,18 @@ class Layer { | |
*/ | ||
const LayerPtr& getPrev(size_t i) { return inputLayers_[i]; } | ||
|
||
/** | ||
* Get the size of inputLayer[i]. | ||
*/ | ||
const LayerConfig& getConfig(size_t i) { return config_; } | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 参数 |
||
|
||
/** | ||
* Get the config of inputLayer[i]. | ||
*/ | ||
const LayerConfig& getPrevConfig(size_t i) { | ||
return inputLayers_[i]->getConfig(i); | ||
} | ||
|
||
/** | ||
* Get the forward-output value. | ||
*/ | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -403,6 +403,50 @@ class Parameter { | |
*/ | ||
typedef std::function<void(const VectorPtr vecs[])> ExecFunc; | ||
void exec(ExecFunc func); | ||
|
||
void resize(int size, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这么长的method implementation应该放在.cc文件里,而不是头文件里。 |
||
const std::vector<int>& dim, | ||
MatType matType = MAT_NORMAL) { | ||
CHECK_LE(dim.size(), 2UL) << "parameter only support 2-dimension now."; | ||
int cnt = 1; | ||
// reset ParameterConfig | ||
config_.set_size(size); | ||
config_.clear_dims(); | ||
for (size_t i = 0; i < dim.size(); ++i) { | ||
cnt *= dim[i]; | ||
config_.add_dims(dim[i]); | ||
} | ||
CHECK_EQ(size, cnt); | ||
|
||
// reset PARAMETER_VALUE | ||
auto& valueBuf = getBuf(PARAMETER_VALUE); | ||
valueBuf->resize(size); | ||
if (mats_[PARAMETER_VALUE]) { | ||
mats_[PARAMETER_VALUE] = NULL; | ||
setMat(PARAMETER_VALUE, matType); | ||
} | ||
|
||
// reset PARAMETER_GRADIENT | ||
auto& gradBuf = getBuf(PARAMETER_GRADIENT); | ||
if (gradBuf) { | ||
gradBuf->resize(size); | ||
} | ||
if (mats_[PARAMETER_GRADIENT]) { | ||
mats_[PARAMETER_GRADIENT] = NULL; | ||
setMat(PARAMETER_GRADIENT, matType); | ||
} | ||
|
||
// reset PARAMETER_MOMENTUM | ||
auto& momBuf = getBuf(PARAMETER_MOMENTUM); | ||
if (momBuf) { | ||
momBuf->resize(size); | ||
momBuf->zeroMem(); | ||
} | ||
if (mats_[PARAMETER_MOMENTUM]) { | ||
mats_[PARAMETER_MOMENTUM] = NULL; | ||
setMat(PARAMETER_MOMENTUM, matType); | ||
} | ||
} | ||
}; | ||
|
||
typedef std::map<std::string, ParameterPtr> ParameterMap; | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License | ||
# use caffe plugin
#
# CAFFE_PATH must point at a caffe source tree that has been built in
# ${CAFFE_PATH}/build (headers, generated protos and libcaffe).
set(CAFFE_PATH $ENV{CAFFE_PATH} CACHE PATH "Folder contains caffe")

# FIX: set(... CACHE ...) always defines the variable (possibly empty),
# so `if(NOT DEFINED CAFFE_PATH)` could never fire. Test the value.
if(NOT CAFFE_PATH)
  message(FATAL_ERROR "Please set CAFFE_PATH to point to the caffe source installation")
endif()

list(APPEND CMAKE_MODULE_PATH ${CAFFE_PATH}/cmake)
include_directories(${CAFFE_PATH}/include)
include_directories(${CAFFE_PATH}/build/src)
include_directories(${CMAKE_BINARY_DIR}/caffe/include)

set(CAFFE_LINKER_LIBS "")

find_library(CAFFE_LIBRARY NAMES libcaffe.so # libcaffe.a
    PATHS ${CAFFE_PATH}/build/lib
    NO_DEFAULT_PATH
    DOC "Path to caffe library.")
list(APPEND CAFFE_LINKER_LIBS ${CAFFE_LIBRARY})

link_directories(${CAFFE_PATH}/build/lib)

# FIX: file(GLOB_RECURSE) takes glob *patterns*; the directory must be
# part of the pattern, not a separate argument (the original globbed the
# directory name itself and matched nothing under it).
file(GLOB_RECURSE PLUGINS_HEADERS "${CMAKE_CURRENT_SOURCE_DIR}/*.h")
file(GLOB_RECURSE PLUGINS_SOURCES "${CMAKE_CURRENT_SOURCE_DIR}/*.cpp")

if("${CBLAS_PROVIDER}" STREQUAL "MKL")
  add_definitions(-DUSE_MKL)
endif()

if(CUDNN_FOUND)
  add_definitions(-DUSE_CUDNN)
endif()

if(WITH_GPU)
  cuda_add_library(paddle_plugin_caffe ${PLUGINS_SOURCES})
else()
  add_library(paddle_plugin_caffe STATIC
      ${PLUGINS_SOURCES})
endif()
target_link_libraries(paddle_plugin_caffe ${CAFFE_LINKER_LIBS})
add_dependencies(paddle_plugin_caffe gen_proto_cpp)

add_style_check_target(paddle_plugin_caffe
    ${PLUGINS_SOURCES})
add_style_check_target(paddle_plugin_caffe
    ${PLUGINS_HEADERS})
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#include "CaffeBlob.h" | ||
#include "paddle/parameter/Argument.h" | ||
|
||
#include <caffe/blob.hpp> | ||
#include <caffe/layer.hpp> | ||
|
||
namespace paddle { | ||
|
||
std::vector<int> layerConfigToBlobShape(const int batch, | ||
const LayerConfig& config) { | ||
std::vector<int> shape; | ||
shape.push_back(batch); | ||
int h = config.height(); | ||
int w = config.width(); | ||
int size = config.size(); | ||
if (h > 1 || w > 1) { | ||
int c = size / h / w; | ||
CHECK_EQ(c * h * w, size); | ||
shape.push_back(c); | ||
shape.push_back(h); | ||
shape.push_back(w); | ||
} else { | ||
shape.push_back(size); | ||
} | ||
return shape; | ||
} | ||
|
||
std::vector<int> argShapeToVector(const Argument& arg) { | ||
std::vector<int> shape; | ||
shape.push_back(arg.getBatchSize()); | ||
int frameHeight = arg.getFrameHeight(); | ||
int frameWidth = arg.getFrameWidth(); | ||
int dim = 0; | ||
if (arg.value) { | ||
dim = arg.value->getWidth(); | ||
} else if (arg.grad) { | ||
dim = arg.grad->getWidth(); | ||
} | ||
CHECK(dim); | ||
// Paddle only support 4 dimension at most. | ||
// s1 means channel number for convolution layer, | ||
// means hidden dimension for other layers. | ||
int s1 = dim; | ||
if (frameHeight > 1 || frameWidth > 1) { | ||
s1 = dim / frameHeight / frameWidth; | ||
CHECK(s1); | ||
CHECK_EQ(dim, s1 * frameHeight * frameWidth); | ||
} | ||
shape.push_back(s1); | ||
if (frameHeight) shape.push_back(frameHeight); | ||
if (frameWidth) shape.push_back(frameWidth); | ||
return shape; | ||
} | ||
|
||
/**
 * Attach the raw buffer @p d to @p blob as its data (memType == VALUE)
 * or its diff (otherwise), in GPU or CPU memory depending on @p useGpu.
 * The blob does not take ownership of @p d.
 */
void setBlob(MemoryTypes memType,
             ::caffe::Blob<real>* blob,
             real* d,
             bool useGpu) {
  const bool isValue = (memType == VALUE);
  if (useGpu) {
    if (isValue) {
      blob->set_gpu_data(d);
    } else {
      blob->set_gpu_diff(d);
    }
  } else {
    if (isValue) {
      blob->set_cpu_data(d);
    } else {
      blob->set_cpu_diff(d);
    }
  }
}
|
||
/**
 * Reshape @p blob to match @p arg and attach arg's value (VALUE) or
 * grad matrix memory to it, without copying.
 */
void argToBlob(MemoryTypes memType,
               const Argument& arg,
               ::caffe::Blob<real>* blob,
               bool useGpu) {
  blob->Reshape(argShapeToVector(arg));
  const auto& srcMat = (memType == VALUE) ? arg.value : arg.grad;
  CHECK(srcMat);
  setBlob(memType, blob, srcMat->getData(), useGpu);
}
|
||
/**
 * Wrap a caffe blob's memory as the value (data) or grad (diff) matrix
 * of a Paddle Argument, without copying. The matrix is (shape[0],
 * count-of-remaining-dims); frame height/width are recorded for 4-D
 * blobs.
 */
void blobToArg(MemoryTypes memType,
               ::caffe::Blob<real>* blob,
               Argument& arg,
               bool useGpu) {
  auto& shape = blob->shape();
  // FIX: validate the rank before reading any shape entries; the
  // original performed this check only after using the shape.
  CHECK_LE(shape.size(), 4) << "Now only support 4-dimension at most";
  int h = shape[0];
  int w = blob->count(1);
  if (shape.size() == 4) {
    arg.setFrameHeight(shape[2]);
    arg.setFrameWidth(shape[3]);
  }
  if (memType == VALUE) {
    real* data = useGpu ? blob->mutable_gpu_data() : blob->mutable_cpu_data();
    arg.value = Matrix::create(data, h, w, false, useGpu);
  } else {
    real* data = useGpu ? blob->mutable_gpu_diff() : blob->mutable_cpu_diff();
    arg.grad = Matrix::create(data, h, w, false, useGpu);
  }
}
|
||
/**
 * Copy a caffe blob's data (VALUE) or diff (otherwise) into the
 * corresponding buffer (PARAMETER_VALUE / PARAMETER_GRADIENT) of a
 * Paddle parameter.
 */
void copyBlobToParameter(MemoryTypes memType,
                         ::caffe::Blob<real>* blob,
                         ParameterPtr para,
                         bool useGpu) {
  const int count = blob->count();
  if (memType == VALUE) {
    real* src = useGpu ? blob->mutable_gpu_data() : blob->mutable_cpu_data();
    para->getBuf(PARAMETER_VALUE)->copyFrom(src, count);
  } else {
    real* src = useGpu ? blob->mutable_gpu_diff() : blob->mutable_cpu_diff();
    para->getBuf(PARAMETER_GRADIENT)->copyFrom(src, count);
  }
}
|
||
/**
 * Reshape @p blob to @p shape and attach the parameter's value
 * (PARAMETER_VALUE) or gradient (PARAMETER_GRADIENT) buffer to it,
 * without copying.
 */
void parameterToBlob(MemoryTypes memType,
                     ParameterPtr para,
                     ::caffe::Blob<real>* blob,
                     const std::vector<int>& shape,
                     bool useGpu) {
  blob->Reshape(shape);
  const auto& srcBuf = (memType == VALUE) ? para->getBuf(PARAMETER_VALUE)
                                          : para->getBuf(PARAMETER_GRADIENT);
  setBlob(memType, blob, srcBuf->getData(), useGpu);
}
|
||
} // namespace paddle |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#pragma once | ||
|
||
#include "CaffeUtil.h" | ||
#include "paddle/parameter/Argument.h" | ||
|
||
#include <caffe/blob.hpp> | ||
#include <caffe/layer.hpp> | ||
|
||
namespace paddle { | ||
|
||
std::vector<int> layerConfigToBlobShape(const int batch, | ||
const LayerConfig& preConfig); | ||
std::vector<int> argShapeToVector(const Argument& arg); | ||
|
||
void setBlob(MemoryTypes memType, | ||
::caffe::Blob<real>* blob, | ||
real* d, | ||
bool useGpu); | ||
|
||
void argToBlob(MemoryTypes memType, | ||
const Argument& arg, | ||
::caffe::Blob<real>* blob, | ||
bool useGpu); | ||
|
||
void blobToArg(MemoryTypes memType, | ||
::caffe::Blob<real>* blob, | ||
Argument& arg, | ||
bool useGpu); | ||
|
||
void copyBlobToParameter(MemoryTypes memType, | ||
::caffe::Blob<real>* blob, | ||
ParameterPtr para, | ||
bool useGpu); | ||
|
||
void parameterToBlob(MemoryTypes memType, | ||
ParameterPtr para, | ||
::caffe::Blob<real>* blob, | ||
const std::vector<int>& shape, | ||
bool useGpu); | ||
|
||
} // namespace paddle |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[Reviewer question, translated] Why is this placed in a new "plugin" directory instead of the third_party directory? Is the plugin directory intended to later hold other content that does not belong under "third party"?