[optional] support protobuf (microsoft#908)
wxchan authored and guolinke committed Oct 19, 2017
1 parent fa45a97 commit 53b9985
Showing 21 changed files with 400 additions and 55 deletions.
3 changes: 3 additions & 0 deletions .travis.yml
@@ -23,6 +23,7 @@ env:
- TASK=if-else
- TASK=sdist PYTHON_VERSION=3.4
- TASK=bdist PYTHON_VERSION=3.5
- TASK=proto
- TASK=gpu METHOD=source
- TASK=gpu METHOD=pip

@@ -38,6 +39,8 @@ matrix:
env: TASK=pylint
- os: osx
env: TASK=check-docs
- os: osx
env: TASK=proto

before_install:
- test -n $CC && unset CC
14 changes: 13 additions & 1 deletion .travis/test.sh
@@ -50,12 +50,24 @@ if [[ ${TASK} == "if-else" ]]; then
conda create -q -n test-env python=$PYTHON_VERSION numpy
source activate test-env
mkdir build && cd build && cmake .. && make lightgbm || exit -1
cd $TRAVIS_BUILD_DIR/tests/cpp_test && ../../lightgbm config=train.conf && ../../lightgbm config=predict.conf output_result=origin.pred || exit -1
cd $TRAVIS_BUILD_DIR/tests/cpp_test && ../../lightgbm config=train.conf convert_model_language=cpp convert_model=../../src/boosting/gbdt_prediction.cpp && ../../lightgbm config=predict.conf output_result=origin.pred || exit -1
cd $TRAVIS_BUILD_DIR/build && make lightgbm || exit -1
cd $TRAVIS_BUILD_DIR/tests/cpp_test && ../../lightgbm config=predict.conf output_result=ifelse.pred && python test.py || exit -1
exit 0
fi

if [[ ${TASK} == "proto" ]]; then
conda create -q -n test-env python=$PYTHON_VERSION numpy
source activate test-env
mkdir build && cd build && cmake .. && make lightgbm || exit -1
cd $TRAVIS_BUILD_DIR/tests/cpp_test && ../../lightgbm config=train.conf && ../../lightgbm config=predict.conf output_result=origin.pred || exit -1
cd $TRAVIS_BUILD_DIR && git clone https://github.com/google/protobuf && cd protobuf && ./autogen.sh && ./configure && make && sudo make install && sudo ldconfig
cd $TRAVIS_BUILD_DIR/build && rm -rf * && cmake -DUSE_PROTO=ON .. && make lightgbm || exit -1
cd $TRAVIS_BUILD_DIR/tests/cpp_test && ../../lightgbm config=train.conf model_format=proto && ../../lightgbm config=predict.conf output_result=proto.pred model_format=proto || exit -1
cd $TRAVIS_BUILD_DIR/tests/cpp_test && python test.py || exit -1
exit 0
fi

conda create -q -n test-env python=$PYTHON_VERSION numpy nose scipy scikit-learn pandas matplotlib pytest
source activate test-env

20 changes: 18 additions & 2 deletions CMakeLists.txt
@@ -124,8 +124,24 @@ file(GLOB SOURCES
src/treelearner/*.cpp
)

add_executable(lightgbm src/main.cpp ${SOURCES})
add_library(_lightgbm SHARED src/c_api.cpp src/lightgbm_R.cpp ${SOURCES})
if (USE_PROTO)
find_package(Protobuf REQUIRED)
PROTOBUF_GENERATE_CPP(PROTO_SRCS PROTO_HDRS proto/model.proto)
include_directories(${PROTOBUF_INCLUDE_DIRS})
include_directories(${CMAKE_CURRENT_BINARY_DIR})
SET(PROTO_FILES src/proto/gbdt_model_proto.cpp ${PROTO_HDRS} ${PROTO_SRCS})
else()
include_directories(src/proto/not_implemented)
SET(PROTO_FILES src/proto/not_implemented/gbdt_model_proto.cpp)
endif(USE_PROTO)

add_executable(lightgbm src/main.cpp ${SOURCES} ${PROTO_FILES})
add_library(_lightgbm SHARED src/c_api.cpp src/lightgbm_R.cpp ${SOURCES} ${PROTO_FILES})

if (USE_PROTO)
TARGET_LINK_LIBRARIES(lightgbm ${PROTOBUF_LIBRARIES})
TARGET_LINK_LIBRARIES(_lightgbm ${PROTOBUF_LIBRARIES})
endif(USE_PROTO)

if(MSVC)
set_target_properties(_lightgbm PROPERTIES OUTPUT_NAME "lib_lightgbm")
14 changes: 14 additions & 0 deletions docs/Parameters.rst
@@ -309,6 +309,20 @@ IO Parameters

- file name of prediction result in ``prediction`` task

- ``model_format``, default=\ ``text``, type=string

- format used to save and load the model.

- ``text``, use text string.

- ``proto``, use protocol buffer binary format.

- to save in multiple formats, join them with commas, e.g. ``text,proto``; in this case each format name is added as a suffix after ``output_model`` (for example, ``my_model.txt`` becomes ``my_model.txt.text`` and ``my_model.txt.proto``).

- loading a model saved in multiple formats is not supported.

- Note: you need to run cmake with ``-DUSE_PROTO=ON`` to use this parameter.

- ``is_pre_partition``, default=\ ``false``, type=bool

- used for parallel learning (not include feature parallel)
31 changes: 20 additions & 11 deletions include/LightGBM/boosting.h
@@ -3,6 +3,7 @@

#include <LightGBM/meta.h>
#include <LightGBM/config.h>
#include "model.pb.h"

#include <vector>
#include <string>
@@ -166,7 +167,7 @@ class LIGHTGBM_EXPORT Boosting {

/*!
* \brief Save model to file
* \param num_used_model Number of model that want to save, -1 means save all
* \param num_iterations Number of model that want to save, -1 means save all
* \param is_finish Is training finished or not
* \param filename Filename that want to save to
* \return true if succeeded
@@ -175,7 +176,7 @@

/*!
* \brief Save model to string
* \param num_used_model Number of model that want to save, -1 means save all
* \param num_iterations Number of model that want to save, -1 means save all
* \return Non-empty string if succeeded
*/
virtual std::string SaveModelToString(int num_iterations) const = 0;
@@ -187,6 +188,20 @@
*/
virtual bool LoadModelFromString(const std::string& model_str) = 0;

/*!
* \brief Save model to a protobuf file
* \param num_iteration Number of iterations to save, -1 means save all
* \param filename Filename to save to
*/
virtual void SaveModelToProto(int num_iteration, const char* filename) const = 0;

/*!
* \brief Restore from a serialized protobuf file
* \param filename Filename to restore from
* \return true if succeeded
*/
virtual bool LoadModelFromProto(const char* filename) = 0;

/*!
* \brief Calculate feature importances
* \param num_iteration Number of model that want to use for feature importance, -1 means use all
@@ -251,23 +266,17 @@
/*! \brief Disable copy */
Boosting(const Boosting&) = delete;

static bool LoadFileToBoosting(Boosting* boosting, const char* filename);
static bool LoadFileToBoosting(Boosting* boosting, const std::string& format, const char* filename);

/*!
* \brief Create boosting object
* \param type Type of boosting
* \param format Format of model
* \param config config for boosting
* \param filename name of model file, if existing will continue to train from this model
* \return The boosting object
*/
static Boosting* CreateBoosting(const std::string& type, const char* filename);

/*!
* \brief Create boosting object from model file
* \param filename name of model file
* \return The boosting object
*/
static Boosting* CreateBoosting(const char* filename);
static Boosting* CreateBoosting(const std::string& type, const std::string& format, const char* filename);

};
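Taken together, the new SaveModelToProto hook and the format-aware CreateBoosting are meant to be used roughly as in the sketch below (not part of the commit). It assumes a pointer to an already-trained model obtained elsewhere, and "model.bin" is a hypothetical output path.

#include <LightGBM/boosting.h>
#include <memory>

// Sketch only: save a trained model in the protobuf format, then reload it
// through the factory by passing the format explicitly.
void RoundTripProtoModel(const LightGBM::Boosting* booster) {
  booster->SaveModelToProto(-1, "model.bin");  // -1 keeps every iteration
  std::unique_ptr<LightGBM::Boosting> restored(
      LightGBM::Boosting::CreateBoosting("gbdt", "proto", "model.bin"));
}

Passing the format explicitly is what lets proto models skip the text-based type sniffing that GetBoostingTypeFromModelFile performs on text model files.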

3 changes: 2 additions & 1 deletion include/LightGBM/config.h
@@ -105,6 +105,7 @@ struct IOConfig: public ConfigBase {
std::string output_result = "LightGBM_predict_result.txt";
std::string convert_model = "gbdt_prediction.cpp";
std::string input_model = "";
std::string model_format = "text";
int verbosity = 1;
int num_iteration_predict = -1;
bool is_pre_partition = false;
@@ -445,7 +446,7 @@ struct ParameterAlias {
const std::unordered_set<std::string> parameter_set({
"config", "config_file", "task", "device",
"num_threads", "seed", "boosting_type", "objective", "data",
"output_model", "input_model", "output_result", "valid_data",
"output_model", "input_model", "output_result", "model_format", "valid_data",
"is_enable_sparse", "is_pre_partition", "is_training_metric",
"ndcg_eval_at", "min_data_in_leaf", "min_sum_hessian_in_leaf",
"num_leaves", "feature_fraction", "num_iterations",
10 changes: 10 additions & 0 deletions include/LightGBM/tree.h
@@ -3,6 +3,7 @@

#include <LightGBM/meta.h>
#include <LightGBM/dataset.h>
#include "model.pb.h"

#include <string>
#include <vector>
@@ -31,6 +32,12 @@ class Tree {
*/
explicit Tree(const std::string& str);

/*!
* \brief Constructor, from a protobuf object
* \param model_tree Model protobuf object
*/
explicit Tree(const LightGBM::Model_Tree& model_tree);

~Tree();

/*!
@@ -165,6 +172,9 @@ class Tree {
/*! \brief Serialize this object to if-else statement*/
std::string ToIfElse(int index, bool is_predict_leaf_index) const;

/*! \brief Serialize this object to a protobuf object*/
void ToProto(Model_Tree& model_tree) const;

inline static bool IsZero(double fval) {
if (fval > -kZeroAsMissingValueRange && fval <= kZeroAsMissingValueRange) {
return true;
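The new constructor and ToProto above are the per-tree halves of the protobuf round trip. A minimal sketch of how they fit together (not from the commit), assuming a build configured with -DUSE_PROTO=ON and an existing tree taken from a trained model:

#include <LightGBM/tree.h>

// Sketch: fill the nested Model.Tree message from an in-memory tree,
// then rebuild an equivalent tree from that message.
void RoundTripTree(const LightGBM::Tree& tree) {
  LightGBM::Model_Tree pb_tree;
  tree.ToProto(pb_tree);
  LightGBM::Tree restored(pb_tree);
  (void)restored;  // a real caller would keep using `restored`
}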
33 changes: 33 additions & 0 deletions proto/model.proto
@@ -0,0 +1,33 @@
syntax = "proto3";

package LightGBM;

message Model {
string name = 1;
uint32 num_class = 2;
uint32 num_tree_per_iteration = 3;
uint32 label_index = 4;
uint32 max_feature_idx = 5;
string objective = 6;
bool average_output = 7;
repeated string feature_names = 8;
repeated string feature_infos = 9;
message Tree {
uint32 num_leaves = 1;
uint32 num_cat = 2;
repeated uint32 split_feature = 3;
repeated double split_gain = 4;
repeated double threshold = 5;
repeated uint32 decision_type = 6;
repeated sint32 left_child = 7;
repeated sint32 right_child = 8;
repeated double leaf_value = 9;
repeated uint32 leaf_count = 10;
repeated double internal_value = 11;
repeated double internal_count = 12;
repeated sint32 cat_boundaries = 13;
repeated uint32 cat_threshold = 14;
double shrinkage = 15;
}
repeated Tree trees = 10;
}
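For reference, a short sketch (not part of the commit) of inspecting a saved model with the C++ classes protoc generates from this schema; the nested Tree message maps to the LightGBM::Model_Tree class used elsewhere in the commit. The file name is a placeholder, and the snippet assumes the file contains a standard binary-serialized Model message, i.e. a model saved with model_format=proto.

#include <fstream>
#include <iostream>
#include "model.pb.h"  // generated from proto/model.proto by protoc

int main() {
  LightGBM::Model model;
  std::ifstream input("LightGBM_model.proto", std::ios::binary);  // placeholder path
  if (!model.ParseFromIstream(&input)) {
    std::cerr << "failed to parse model file" << std::endl;
    return 1;
  }
  std::cout << model.name() << ": " << model.trees_size() << " trees, num_class="
            << model.num_class() << std::endl;
  for (const LightGBM::Model_Tree& tree : model.trees()) {
    std::cout << "  tree: " << tree.num_leaves() << " leaves, shrinkage="
              << tree.shrinkage() << std::endl;
  }
  return 0;
}

Since the schema is plain proto3, the same file can be read from any language with protobuf support, not only through the C++ API.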
21 changes: 20 additions & 1 deletion src/application/application.cpp
@@ -180,6 +180,7 @@ void Application::InitTrain() {
// create boosting
boosting_.reset(
Boosting::CreateBoosting(config_.boosting_type,
config_.io_config.model_format.c_str(),
config_.io_config.input_model.c_str()));
// create objective function
objective_fun_.reset(
@@ -203,6 +204,22 @@
void Application::Train() {
Log::Info("Started training...");
boosting_->Train(config_.io_config.snapshot_freq, config_.io_config.output_model);
std::vector<std::string> model_formats = Common::Split(config_.io_config.model_format.c_str(), ',');
bool save_with_multiple_format = (model_formats.size() > 1);
for (auto model_format: model_formats) {
std::string save_file_name = config_.io_config.output_model;
if (save_with_multiple_format) {
// use a suffix to distinguish different model formats
save_file_name += "." + model_format;
}
if (model_format == std::string("text")) {
boosting_->SaveModelToFile(-1, save_file_name.c_str());
} else if (model_format == std::string("proto")) {
boosting_->SaveModelToProto(-1, save_file_name.c_str());
} else {
Log::Fatal("Unknown model format during saving: %s", model_format.c_str());
}
}
// convert model to if-else statement code
if (config_.convert_model_language == std::string("cpp")) {
boosting_->SaveModelToIfElse(-1, config_.io_config.convert_model.c_str());
@@ -223,13 +240,15 @@

void Application::InitPredict() {
boosting_.reset(
Boosting::CreateBoosting(config_.io_config.input_model.c_str()));
Boosting::CreateBoosting("gbdt", config_.io_config.model_format.c_str(),
config_.io_config.input_model.c_str()));
Log::Info("Finished initializing prediction");
}

void Application::ConvertModel() {
boosting_.reset(
Boosting::CreateBoosting(config_.boosting_type,
config_.io_config.model_format.c_str(),
config_.io_config.input_model.c_str()));
boosting_->SaveModelToIfElse(-1, config_.io_config.convert_model.c_str());
}
46 changes: 21 additions & 25 deletions src/boosting/boosting.cpp
@@ -12,21 +12,30 @@ std::string GetBoostingTypeFromModelFile(const char* filename) {
return type;
}

bool Boosting::LoadFileToBoosting(Boosting* boosting, const char* filename) {
bool Boosting::LoadFileToBoosting(Boosting* boosting, const std::string& format, const char* filename) {
if (boosting != nullptr) {
TextReader<size_t> model_reader(filename, true);
model_reader.ReadAllLines();
std::stringstream str_buf;
for (auto& line : model_reader.Lines()) {
str_buf << line << '\n';
if (format == std::string("text")) {
TextReader<size_t> model_reader(filename, true);
model_reader.ReadAllLines();
std::stringstream str_buf;
for (auto& line : model_reader.Lines()) {
str_buf << line << '\n';
}
if (!boosting->LoadModelFromString(str_buf.str())) {
return false;
}
} else if (format == std::string("proto")) {
if (!boosting->LoadModelFromProto(filename)) {
return false;
}
} else {
Log::Fatal("Unknown model format during loading: %s", format.c_str());
}
if (!boosting->LoadModelFromString(str_buf.str()))
return false;
}
return true;
}

Boosting* Boosting::CreateBoosting(const std::string& type, const char* filename) {
Boosting* Boosting::CreateBoosting(const std::string& type, const std::string& format, const char* filename) {
if (filename == nullptr || filename[0] == '\0') {
if (type == std::string("gbdt")) {
return new GBDT();
@@ -41,8 +50,7 @@ Boosting* Boosting::CreateBoosting(const std::string& type, const char* filename
}
} else {
std::unique_ptr<Boosting> ret;
auto type_in_file = GetBoostingTypeFromModelFile(filename);
if (type_in_file == std::string("tree")) {
if (format == std::string("proto") || GetBoostingTypeFromModelFile(filename) == std::string("tree")) {
if (type == std::string("gbdt")) {
ret.reset(new GBDT());
} else if (type == std::string("dart")) {
@@ -54,24 +62,12 @@
} else {
Log::Fatal("unknown boosting type %s", type.c_str());
}
LoadFileToBoosting(ret.get(), filename);
LoadFileToBoosting(ret.get(), format, filename);
} else {
Log::Fatal("unknown submodel type in model file %s", filename);
Log::Fatal("unknown model format or submodel type in model file %s", filename);
}
return ret.release();
}
}

Boosting* Boosting::CreateBoosting(const char* filename) {
auto type = GetBoostingTypeFromModelFile(filename);
std::unique_ptr<Boosting> ret;
if (type == std::string("tree")) {
ret.reset(new GBDT());
} else {
Log::Fatal("unknown submodel type in model file %s", filename);
}
LoadFileToBoosting(ret.get(), filename);
return ret.release();
}

} // namespace LightGBM
1 change: 0 additions & 1 deletion src/boosting/gbdt.cpp
@@ -352,7 +352,6 @@ void GBDT::Train(int snapshot_freq, const std::string& model_output_path) {
SaveModelToFile(-1, snapshot_out.c_str());
}
}
SaveModelToFile(-1, model_output_path.c_str());
}

double GBDT::BoostFromAverage() {