small code and docs refactoring (#3681)
* small code and docs refactoring

* Update CMakeLists.txt

* Update .vsts-ci.yml

* Update test.sh

* continue

* continue

* revert stable sort for all-unique values
StrikerRUS authored Dec 28, 2020
1 parent be1202d commit 5a46084
Showing 13 changed files with 70 additions and 64 deletions.
9 changes: 3 additions & 6 deletions CMakeLists.txt
@@ -127,12 +127,6 @@ endif(USE_CUDA)
if(USE_OPENMP)
find_package(OpenMP REQUIRED)
SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
else()
# Ignore unknown #pragma warning
if((CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
OR (CMAKE_CXX_COMPILER_ID STREQUAL "GNU"))
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unknown-pragmas")
endif()
endif(USE_OPENMP)

if(USE_GPU)
@@ -272,6 +266,9 @@ if(UNIX OR MINGW OR CYGWIN)
if(USE_SWIG)
SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fno-strict-aliasing")
endif()
if(NOT USE_OPENMP)
SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unknown-pragmas -Wno-unused-private-field")
endif()
endif()

if(WIN32 AND MINGW)
16 changes: 9 additions & 7 deletions docs/Parameters.rst
@@ -119,23 +119,25 @@ Core Parameters

- ``linear_tree`` :raw-html:`<a id="linear_tree" title="Permalink to this parameter" href="#linear_tree">&#x1F517;&#xFE0E;</a>`, default = ``false``, type = bool

- fit piecewise linear gradient boosting tree, only works with cpu and serial tree learner
- fit piecewise linear gradient boosting tree

- tree splits are chosen in the usual way, but the model at each leaf is linear instead of constant

- the linear model at each leaf includes all the numerical features in that leaf's branch

- categorical features are used for splits as normal but are not used in the linear models

- missing values must be encoded as ``np.nan`` (Python) or ``NA`` (cli), not ``0``
- missing values must be encoded as ``np.nan`` (Python) or ``NA`` (CLI), not ``0``

- it is recommended to rescale data before training so that features have similar mean and standard deviation

- not yet supported in R-package
- **Note**: only works with CPU and ``serial`` tree learner

- ``regression_l1`` objective is not supported with linear tree boosting
- **Note**: not yet supported in R-package

- setting ``linear_tree = True`` significantly increases the memory use of LightGBM
- **Note**: ``regression_l1`` objective is not supported with linear tree boosting

- **Note**: setting ``linear_tree=true`` significantly increases the memory use of LightGBM

- ``data`` :raw-html:`<a id="data" title="Permalink to this parameter" href="#data">&#x1F517;&#xFE0E;</a>`, default = ``""``, type = string, aliases: ``train``, ``train_data``, ``train_data_file``, ``data_filename``

@@ -406,7 +408,7 @@ Learning Control Parameters

- ``linear_lambda`` :raw-html:`<a id="linear_lambda" title="Permalink to this parameter" href="#linear_lambda">&#x1F517;&#xFE0E;</a>`, default = ``0.0``, type = double, constraints: ``linear_lambda >= 0.0``

- Linear tree regularisation, the parameter `lambda` in Eq 3 of <https://arxiv.org/pdf/1802.05640.pdf>
- linear tree regularization, corresponds to the parameter ``lambda`` in Eq. 3 of `Gradient Boosting with Piece-Wise Linear Regression Trees <https://arxiv.org/pdf/1802.05640.pdf>`__

- ``min_gain_to_split`` :raw-html:`<a id="min_gain_to_split" title="Permalink to this parameter" href="#min_gain_to_split">&#x1F517;&#xFE0E;</a>`, default = ``0.0``, type = double, aliases: ``min_split_gain``, constraints: ``min_gain_to_split >= 0.0``

@@ -580,7 +582,7 @@ Learning Control Parameters

- if ``path_smooth > 0`` then ``min_data_in_leaf`` must be at least ``2``

- larger values give stronger regularisation
- larger values give stronger regularization

- the weight of each node is ``(n / path_smooth) * w + w_p / (n / path_smooth + 1)``, where ``n`` is the number of samples in the node, ``w`` is the optimal node weight to minimise the loss (approximately ``-sum_gradients / sum_hessians``), and ``w_p`` is the weight of the parent node

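
The reworded ``linear_tree`` and ``linear_lambda`` entries above can be exercised directly from the Python package; a minimal sketch with synthetic data (parameter values and data are illustrative only, not part of this commit):

```python
import numpy as np
import lightgbm as lgb

# Piecewise-linear leaves on data with a strong linear trend.
rng = np.random.default_rng(0)
X = rng.uniform(0, 100, size=(1000, 1))
y = 2 * X[:, 0] + rng.normal(0, 0.1, size=1000)

params = {
    "objective": "regression",
    "linear_tree": True,    # fit a linear model in each leaf
    "linear_lambda": 0.1,   # regularization on the leaf-level linear models
    "num_leaves": 4,
    "verbose": -1,
}
train_set = lgb.Dataset(X, label=y, params=params)
booster = lgb.train(params, train_set, num_boost_round=10)
print(booster.predict(X[:5]))
```
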
16 changes: 8 additions & 8 deletions include/LightGBM/c_api.h
@@ -389,14 +389,6 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetGetNumData(DatasetHandle handle,
LIGHTGBM_C_EXPORT int LGBM_DatasetGetNumFeature(DatasetHandle handle,
int* out);

/*!
* \brief Get boolean representing whether booster is fitting linear trees.
* \param handle Handle of dataset
* \param[out] out The address to hold linear indicator
* \return 0 when succeed, -1 when failure happens
*/
LIGHTGBM_C_EXPORT int LGBM_BoosterGetLinear(BoosterHandle handle, bool* out);

/*!
* \brief Add features from ``source`` to ``target``.
* \param target The handle of the dataset to add features to
@@ -408,6 +400,14 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetAddFeaturesFrom(DatasetHandle target,

// --- start Booster interfaces

/*!
* \brief Get boolean representing whether booster is fitting linear trees.
* \param handle Handle of booster
* \param[out] out The address to hold linear trees indicator
* \return 0 when succeed, -1 when failure happens
*/
LIGHTGBM_C_EXPORT int LGBM_BoosterGetLinear(BoosterHandle handle, bool* out);

/*!
* \brief Create a new boosting learner.
* \param train_data Training dataset
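
The relocated ``LGBM_BoosterGetLinear`` entry point can be called straight through the C API; a hedged ctypes sketch (the shared-library filename and the use of the Python Booster's raw ``handle`` are assumptions, not part of this diff):

```python
import ctypes

# Sketch only: load the LightGBM shared library (name/path varies by platform and build).
_LIB = ctypes.CDLL("lib_lightgbm.so")

def booster_is_linear(booster_handle):
    """Return True if the booster was trained with linear_tree, via LGBM_BoosterGetLinear."""
    out = ctypes.c_bool(False)
    ret = _LIB.LGBM_BoosterGetLinear(booster_handle, ctypes.byref(out))
    if ret != 0:
        raise RuntimeError("LGBM_BoosterGetLinear failed")
    return bool(out.value)

# Assumed usage: booster_is_linear(booster.handle) with a Booster's raw ctypes handle.
```
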
15 changes: 8 additions & 7 deletions include/LightGBM/config.h
@@ -148,15 +148,16 @@ struct Config {
// descl2 = **Note**: internally, LightGBM uses ``gbdt`` mode for the first ``1 / learning_rate`` iterations
std::string boosting = "gbdt";

// desc = fit piecewise linear gradient boosting tree, only works with cpu and serial tree learner
// desc = fit piecewise linear gradient boosting tree
// descl2 = tree splits are chosen in the usual way, but the model at each leaf is linear instead of constant
// descl2 = the linear model at each leaf includes all the numerical features in that leaf's branch
// descl2 = categorical features are used for splits as normal but are not used in the linear models
// descl2 = missing values must be encoded as ``np.nan`` (Python) or ``NA`` (cli), not ``0``
// descl2 = missing values must be encoded as ``np.nan`` (Python) or ``NA`` (CLI), not ``0``
// descl2 = it is recommended to rescale data before training so that features have similar mean and standard deviation
// descl2 = not yet supported in R-package
// descl2 = ``regression_l1`` objective is not supported with linear tree boosting
// descl2 = setting ``linear_tree = True`` significantly increases the memory use of LightGBM
// descl2 = **Note**: only works with CPU and ``serial`` tree learner
// descl2 = **Note**: not yet supported in R-package
// descl2 = **Note**: ``regression_l1`` objective is not supported with linear tree boosting
// descl2 = **Note**: setting ``linear_tree=true`` significantly increases the memory use of LightGBM
bool linear_tree = false;

// alias = train, train_data, train_data_file, data_filename
@@ -378,7 +379,7 @@ struct Config {
double lambda_l2 = 0.0;

// check = >=0.0
// desc = Linear tree regularisation, the parameter `lambda` in Eq 3 of <https://arxiv.org/pdf/1802.05640.pdf>
// desc = linear tree regularization, corresponds to the parameter ``lambda`` in Eq. 3 of `Gradient Boosting with Piece-Wise Linear Regression Trees <https://arxiv.org/pdf/1802.05640.pdf>`__
double linear_lambda = 0.0;

// alias = min_split_gain
@@ -530,7 +531,7 @@ struct Config {
// desc = helps prevent overfitting on leaves with few samples
// desc = if set to zero, no smoothing is applied
// desc = if ``path_smooth > 0`` then ``min_data_in_leaf`` must be at least ``2``
// desc = larger values give stronger regularisation
// desc = larger values give stronger regularization
// descl2 = the weight of each node is ``(n / path_smooth) * w + w_p / (n / path_smooth + 1)``, where ``n`` is the number of samples in the node, ``w`` is the optimal node weight to minimise the loss (approximately ``-sum_gradients / sum_hessians``), and ``w_p`` is the weight of the parent node
// descl2 = note that the parent output ``w_p`` itself has smoothing applied, unless it is the root node, so that the smoothing effect accumulates with the tree depth
double path_smooth = 0;
2 changes: 1 addition & 1 deletion src/c_api.cpp
@@ -290,7 +290,7 @@ class Booster {
"the `min_data_in_leaf`.");
}
if (new_param.count("linear_tree") && (new_config.linear_tree != old_config.linear_tree)) {
Log:: Fatal("Cannot change between gbdt_linear boosting and other boosting types after Dataset handle has been constructed.");
Log::Fatal("Cannot change linear_tree after constructed Dataset handle.");
}
}

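
The shortened fatal message above is what surfaces when ``linear_tree`` is flipped for an already-constructed Dataset; a rough sketch of triggering it from Python (the exception type and the exact point of failure are assumptions based on the test change further below):

```python
import numpy as np
import lightgbm as lgb

X = np.random.random((100, 2))
y = np.random.random(100)

train_set = lgb.Dataset(X, label=y)
lgb.train({"verbose": -1}, train_set, num_boost_round=2)  # constructs the Dataset handle

try:
    # Changing linear_tree on the same constructed Dataset should be rejected.
    lgb.train({"verbose": -1, "linear_tree": True}, train_set, num_boost_round=2)
except lgb.basic.LightGBMError as err:
    print("rejected:", err)
```
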
4 changes: 2 additions & 2 deletions src/io/config.cpp
@@ -340,9 +340,9 @@ void Config::CheckParamConflict() {
Log::Warning("CUDA currently requires double precision calculations.");
gpu_use_dp = true;
}
// linear tree learner must be serial type and cpu device
// linear tree learner must be serial type and run on cpu device
if (linear_tree) {
if (device_type == std::string("gpu")) {
if (device_type != std::string("cpu")) {
device_type = "cpu";
Log::Warning("Linear tree learner only works with CPU.");
}
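
The broadened device check above now resets any non-CPU ``device_type`` to CPU when ``linear_tree`` is enabled; a hedged sketch of what a user would observe (whether this runs cleanly depends on the build, so treat it as illustrative only):

```python
import numpy as np
import lightgbm as lgb

X = np.random.random((200, 3))
y = np.random.random(200)

# Per the check above, this should log "Linear tree learner only works with CPU."
# and continue training with device_type reset to "cpu".
params = {"linear_tree": True, "device_type": "gpu"}
booster = lgb.train(params, lgb.Dataset(X, label=y, params=params), num_boost_round=5)
```
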
1 change: 0 additions & 1 deletion src/io/dataset_loader.cpp
@@ -600,7 +600,6 @@ Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* b
}
mem_ptr = buffer.data();
const float* tmp_ptr_raw_row = reinterpret_cast<const float*>(mem_ptr);
std::vector<float> curr_row(dataset->num_numeric_features_, 0);
for (int j = 0; j < dataset->num_features(); ++j) {
int feat_ind = dataset->numeric_feature_map_[j];
if (feat_ind >= 0) {
8 changes: 6 additions & 2 deletions src/io/tree.cpp
@@ -697,7 +697,9 @@ Tree::Tree(const char* str, size_t* used_len) {
is_linear_ = static_cast<bool>(is_linear_int);
}

if ((num_leaves_ <= 1) && !is_linear_) { return; }
if ((num_leaves_ <= 1) && !is_linear_) {
return;
}

if (key_vals.count("left_child")) {
left_child_ = CommonC::StringToArrayFast<int>(key_vals["left_child"], num_leaves_ - 1);
@@ -780,7 +782,9 @@ Tree::Tree(const char* str, size_t* used_len) {
leaf_features_inner_.resize(num_leaves_);
if (num_feat.size() > 0) {
int total_num_feat = 0;
for (size_t i = 0; i < num_feat.size(); ++i) { total_num_feat += num_feat[i]; }
for (size_t i = 0; i < num_feat.size(); ++i) {
total_num_feat += num_feat[i];
}
std::vector<int> all_leaf_features;
if (key_vals.count("leaf_features")) {
all_leaf_features = Common::StringToArrayFast<int>(key_vals["leaf_features"], total_num_feat);
2 changes: 1 addition & 1 deletion src/treelearner/linear_tree_learner.cpp
@@ -1,5 +1,5 @@
/*!
* Copyright (c) 2016 Microsoft Corporation. All rights reserved.
* Copyright (c) 2020 Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See LICENSE file in the project root for license information.
*/
#include "linear_tree_learner.h"
2 changes: 1 addition & 1 deletion src/treelearner/linear_tree_learner.h
@@ -1,5 +1,5 @@
/*!
* Copyright (c) 2016 Microsoft Corporation. All rights reserved.
* Copyright (c) 2020 Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See LICENSE file in the project root for license information.
*/
#ifndef LIGHTGBM_TREELEARNER_LINEAR_TREE_LEARNER_H_
2 changes: 1 addition & 1 deletion src/treelearner/tree_learner.cpp
@@ -6,9 +6,9 @@

#include "cuda_tree_learner.h"
#include "gpu_tree_learner.h"
#include "linear_tree_learner.h"
#include "parallel_tree_learner.h"
#include "serial_tree_learner.h"
#include "linear_tree_learner.h"

namespace LightGBM {

22 changes: 11 additions & 11 deletions tests/python_package_test/test_basic.py
@@ -118,18 +118,18 @@ def test_save_and_load_linear(self):
X_train[:X_train.shape[0] // 2, 0] = 0
y_train[:X_train.shape[0] // 2] = 1
params = {'linear_tree': True}
train_data = lgb.Dataset(X_train, label=y_train, params=params)
est = lgb.train(params, train_data, num_boost_round=10, categorical_feature=[0])
pred1 = est.predict(X_train)
train_data.save_binary('temp_dataset.bin')
train_data_1 = lgb.Dataset(X_train, label=y_train, params=params)
est_1 = lgb.train(params, train_data_1, num_boost_round=10, categorical_feature=[0])
pred_1 = est_1.predict(X_train)
train_data_1.save_binary('temp_dataset.bin')
train_data_2 = lgb.Dataset('temp_dataset.bin')
est = lgb.train(params, train_data_2, num_boost_round=10)
pred2 = est.predict(X_train)
np.testing.assert_allclose(pred1, pred2)
est.save_model('temp_model.txt')
est2 = lgb.Booster(model_file='temp_model.txt')
pred3 = est2.predict(X_train)
np.testing.assert_allclose(pred2, pred3)
est_2 = lgb.train(params, train_data_2, num_boost_round=10)
pred_2 = est_2.predict(X_train)
np.testing.assert_allclose(pred_1, pred_2)
est_2.save_model('temp_model.txt')
est_3 = lgb.Booster(model_file='temp_model.txt')
pred_3 = est_3.predict(X_train)
np.testing.assert_allclose(pred_2, pred_3)

def test_subset_group(self):
X_train, y_train = load_svmlight_file(os.path.join(os.path.dirname(os.path.realpath(__file__)),
35 changes: 19 additions & 16 deletions tests/python_package_test/test_engine.py
@@ -2232,6 +2232,7 @@ def test_dataset_update_params(self):
"group_column": 0,
"ignore_column": 0,
"min_data_in_leaf": 10,
"linear_tree": False,
"verbose": -1}
unchangeable_params = {"max_bin": 150,
"max_bin_by_feature": [30, 5],
@@ -2252,7 +2253,8 @@
"group_column": 1,
"ignore_column": 1,
"forcedbins_filename": "/some/path/forcedbins.json",
"min_data_in_leaf": 2}
"min_data_in_leaf": 2,
"linear_tree": True}
X = np.random.random((100, 2))
y = np.random.random(100)

@@ -2420,45 +2422,46 @@ def test_interaction_constraints(self):
[1] + list(range(2, num_features))]),
train_data, num_boost_round=10)

def test_linear(self):
# check that setting boosting=gbdt_linear fits better than boosting=gbdt when data has linear relationship
def test_linear_trees(self):
# check that setting linear_tree=True fits better than ordinary trees when data has linear relationship
np.random.seed(0)
x = np.arange(0, 100, 0.1)
y = 2 * x + np.random.normal(0, 0.1, len(x))
lgb_train = lgb.Dataset(x[:, np.newaxis], label=y)
x = x[:, np.newaxis]
lgb_train = lgb.Dataset(x, label=y)
params = {'verbose': -1,
'metric': 'mse',
'seed': 0,
'num_leaves': 2}
est = lgb.train(params, lgb_train, num_boost_round=10)
pred1 = est.predict(x[:, np.newaxis])
lgb_train = lgb.Dataset(x[:, np.newaxis], label=y)
pred1 = est.predict(x)
lgb_train = lgb.Dataset(x, label=y)
res = {}
est = lgb.train(dict(params, linear_tree=True), lgb_train, num_boost_round=10, evals_result=res,
valid_sets=[lgb_train], valid_names=['train'])
pred2 = est.predict(x[:, np.newaxis])
pred2 = est.predict(x)
np.testing.assert_allclose(res['train']['l2'][-1], mean_squared_error(y, pred2), atol=10**(-1))
self.assertLess(mean_squared_error(y, pred2), mean_squared_error(y, pred1))
# test again with nans in data
x[:10] = np.nan
lgb_train = lgb.Dataset(x[:, np.newaxis], label=y)
lgb_train = lgb.Dataset(x, label=y)
est = lgb.train(params, lgb_train, num_boost_round=10)
pred1 = est.predict(x[:, np.newaxis])
lgb_train = lgb.Dataset(x[:, np.newaxis], label=y)
pred1 = est.predict(x)
lgb_train = lgb.Dataset(x, label=y)
res = {}
est = lgb.train(dict(params, linear_tree=True), lgb_train, num_boost_round=10, evals_result=res,
valid_sets=[lgb_train], valid_names=['train'])
pred2 = est.predict(x[:, np.newaxis])
pred2 = est.predict(x)
np.testing.assert_allclose(res['train']['l2'][-1], mean_squared_error(y, pred2), atol=10**(-1))
self.assertLess(mean_squared_error(y, pred2), mean_squared_error(y, pred1))
# test again with bagging
res = {}
est = lgb.train(dict(params, linear_tree=True, subsample=0.8, bagging_freq=1), lgb_train,
num_boost_round=10, evals_result=res, valid_sets=[lgb_train], valid_names=['train'])
pred = est.predict(x[:, np.newaxis])
pred = est.predict(x)
np.testing.assert_allclose(res['train']['l2'][-1], mean_squared_error(y, pred), atol=10**(-1))
# test with a feature that has only one non-nan value
x = np.concatenate([np.ones([x.shape[0], 1]), x[:, np.newaxis]], 1)
x = np.concatenate([np.ones([x.shape[0], 1]), x], 1)
x[500:, 1] = np.nan
y[500:] += 10
lgb_train = lgb.Dataset(x, label=y)
@@ -2486,11 +2489,11 @@ def test_linear(self):
p2 = est2.predict(x)
self.assertLess(np.mean(np.abs(p1 - p2)), 2)
# test refit: different results training on different data
est2 = est.refit(x[:100, :], label=y[:100])
p3 = est2.predict(x)
est3 = est.refit(x[:100, :], label=y[:100])
p3 = est3.predict(x)
self.assertGreater(np.mean(np.abs(p2 - p1)), np.abs(np.max(p3 - p1)))
# test when num_leaves - 1 < num_features and when num_leaves - 1 > num_features
X_train, X_test, y_train, y_test = train_test_split(*load_breast_cancer(return_X_y=True), test_size=0.1, random_state=2)
X_train, _, y_train, _ = train_test_split(*load_breast_cancer(return_X_y=True), test_size=0.1, random_state=2)
params = {'linear_tree': True,
'verbose': -1,
'metric': 'mse',
