Skip to content

Commit

Permalink
Merge pull request #16 from titsuki/train-parameter
Browse files Browse the repository at this point in the history
Make XGBoost.train enable to set parameter
  • Loading branch information
titsuki authored Aug 10, 2021
2 parents bcf78c0 + 97062e1 commit c31bea5
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 3 deletions.
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,16 @@ METHODS

Defined as:

method train(Algorithm::XGBoost::DMatrix $dmat, Int $num-iteration --> Algorithm::XGBoost::Model)
method train(Algorithm::XGBoost::DMatrix $dmat, Int $num-iteration, %param --> Algorithm::XGBoost::Model)

Trains a XGBoost model.

* `$dmat` The instance of Algorithm::XGBoost::DMatrix.

* `$num-iteration` The number of iterations for training.

* `%param` The parameter for training.

### version

Defined as:
Expand Down
10 changes: 8 additions & 2 deletions lib/Algorithm/XGBoost.rakumod
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,15 @@ method version(--> Version) {

my sub XGBoosterCreate(Algorithm::XGBoost::DMatrix is rw, ulong, Algorithm::XGBoost::Booster is rw --> int32) is native($library) { * }
my sub XGBoosterUpdateOneIter(Algorithm::XGBoost::Booster, int32, Algorithm::XGBoost::DMatrix --> int32) is native($library) { * }
my sub XGBoosterSetParam(Algorithm::XGBoost::Booster, Str, Str --> int32) is native($library) { * }

method train(Algorithm::XGBoost::DMatrix $dmat, Int $num-iteration --> Algorithm::XGBoost::Model) {
method train(Algorithm::XGBoost::DMatrix $dmat, Int $num-iteration, %param? --> Algorithm::XGBoost::Model) {
my $h = Pointer.new;
XGBoosterCreate($dmat, 1, $h);
my $booster = nativecast(Algorithm::XGBoost::Booster, $h);
for %param {
XGBoosterSetParam($booster, .key.Str, .value.Str);
}

for ^$num-iteration -> $iter {
XGBoosterUpdateOneIter($booster, $iter, $dmat);
Expand Down Expand Up @@ -84,14 +88,16 @@ Algorithm::XGBoost is a Raku bindings for XGBoost ( https://github.com/dmlc/xgbo
Defined as:
method train(Algorithm::XGBoost::DMatrix $dmat, Int $num-iteration --> Algorithm::XGBoost::Model)
method train(Algorithm::XGBoost::DMatrix $dmat, Int $num-iteration, %param --> Algorithm::XGBoost::Model)
Trains a XGBoost model.
=item C<$dmat> The instance of Algorithm::XGBoost::DMatrix.
=item C<$num-iteration> The number of iterations for training.
=item C<%param> The parameter for training.
=head3 version
Defined as:
Expand Down
12 changes: 12 additions & 0 deletions t/01-basic.rakutest
Original file line number Diff line number Diff line change
Expand Up @@ -57,4 +57,16 @@ subtest {
is $actual, $expected;
}, "When a model is given, Then .save/.load should retain the model";

subtest {
Algorithm::XGBoost.global-config(q[{"verbosity": 4}]);
my @train[3;2] = [[0e0,0e0],[0e0,1e0],[1e0,0e0]];
my @y = [1e0, 0e0, 1e0];
my $dmat = Algorithm::XGBoost::DMatrix.from-matrix(@train, @y);
# TODO: Get param from the model
lives-ok {
my %param = (:booster("dart"));
my $model = Algorithm::XGBoost.train($dmat, 10, %param);
}
}, "When a %param is given, Then .train should use it";

done-testing;

0 comments on commit c31bea5

Please sign in to comment.