Skip to content

Commit 0fc9849

Browse files
author
Ben Auffarth
committed
make public to use v2 interface
1 parent 1a384f6 commit 0fc9849

File tree

2 files changed

+25
-24
lines changed

2 files changed

+25
-24
lines changed

src/main/java/org/lightgbm/predict4j/v2/Boosting.java

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ public abstract class Boosting implements Serializable {
1818
private static final Logger logger = LoggerFactory.getLogger(Boosting.class);
1919
private static final long serialVersionUID = -3370589073161617590L;
2020

21-
static Boosting createBoosting(String filename) throws FileNotFoundException, IOException {
21+
public static Boosting createBoosting(String filename) throws IOException {
2222
String type = getBoostingTypeFromModelFile(filename);
2323
Boosting boosting = null;
2424
if (type.equals("tree")) {
@@ -30,7 +30,7 @@ static Boosting createBoosting(String filename) throws FileNotFoundException, IO
3030
return boosting;
3131
}
3232

33-
static Boosting createBoosting(String type, String filename) throws FileNotFoundException, IOException {
33+
public static Boosting createBoosting(String type, String filename) throws FileNotFoundException, IOException {
3434
if (filename == null || filename.length() == 0) {
3535
if (type.equals("gbdt")) {
3636
return new GBDT();
@@ -62,7 +62,7 @@ static Boosting createBoosting(String type, String filename) throws FileNotFound
6262
}
6363
}
6464

65-
static boolean loadFileToBoosting(Boosting boosting, String filename) throws FileNotFoundException, IOException {
65+
public static boolean loadFileToBoosting(Boosting boosting, String filename) throws IOException {
6666
if (boosting != null) {
6767
StringBuilder sb = new StringBuilder();
6868
List<String> lines = IOUtils.readLines(new FileInputStream(filename));
@@ -76,28 +76,28 @@ static boolean loadFileToBoosting(Boosting boosting, String filename) throws Fil
7676
return true;
7777
}
7878

79-
static String getBoostingTypeFromModelFile(String filename) throws FileNotFoundException, IOException {
79+
public static String getBoostingTypeFromModelFile(String filename) throws IOException {
8080
List<String> lines = IOUtils.readLines(new FileInputStream(filename));
8181
return lines.get(0);
8282
}
8383

84-
abstract boolean loadModelFromString(String modelStr);
84+
abstract public boolean loadModelFromString(String modelStr);
8585

86-
abstract boolean needAccuratePrediction();
86+
abstract public boolean needAccuratePrediction();
8787

88-
abstract int numberOfClasses();
88+
abstract public int numberOfClasses();
8989

90-
abstract void initPredict(int num_iteration);
90+
abstract public void initPredict(int num_iteration);
9191

92-
abstract int numPredictOneRow(int num_iteration, boolean is_pred_leaf);
92+
abstract public int numPredictOneRow(int num_iteration, boolean is_pred_leaf);
9393

94-
abstract int getCurrentIteration();
94+
abstract public int getCurrentIteration();
9595

96-
abstract int maxFeatureIdx();
96+
abstract public int maxFeatureIdx();
9797

98-
abstract List<Double> predictLeafIndex(SparseVector vector);
98+
abstract public List<Double> predictLeafIndex(SparseVector vector);
9999

100-
abstract List<Double> predictRaw(SparseVector vector, PredictionEarlyStopInstance early_stop);
100+
abstract public List<Double> predictRaw(SparseVector vector, PredictionEarlyStopInstance early_stop);
101101

102-
abstract List<Double> predict(SparseVector vector, PredictionEarlyStopInstance early_stop);
102+
abstract public List<Double> predict(SparseVector vector, PredictionEarlyStopInstance early_stop);
103103
}

src/main/java/org/lightgbm/predict4j/v2/GBDT.java

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@ public class GBDT extends Boosting {
3434
/* ! \brief current iteration */
3535
int iter_;
3636

37-
boolean loadModelFromString(String model_str) {
37+
@Override
38+
public boolean loadModelFromString(String model_str) {
3839
// use serialized string to restore this object
3940
models_.clear();
4041
String[] lines = model_str.split("\n");
@@ -134,26 +135,26 @@ boolean loadModelFromString(String model_str) {
134135
return true;
135136
}
136137

137-
boolean needAccuratePrediction() {
138+
public boolean needAccuratePrediction() {
138139
if (objective_function_ == null) {
139140
return true;
140141
} else {
141142
return objective_function_.needAccuratePrediction();
142143
}
143144
}
144145

145-
int numberOfClasses() {
146+
public int numberOfClasses() {
146147
return num_class_;
147148
}
148149

149-
void initPredict(int num_iteration) {
150+
public void initPredict(int num_iteration) {
150151
num_iteration_for_pred_ = models_.size() / num_tree_per_iteration_;
151152
if (num_iteration > 0) {
152153
num_iteration_for_pred_ = Math.min(num_iteration + (boost_from_average_ ? 1 : 0), num_iteration_for_pred_);
153154
}
154155
}
155156

156-
int numPredictOneRow(int num_iteration, boolean is_pred_leaf) {
157+
public int numPredictOneRow(int num_iteration, boolean is_pred_leaf) {
157158
int num_preb_in_one_row = num_class_;
158159
if (is_pred_leaf) {
159160
int max_iteration = getCurrentIteration();
@@ -166,15 +167,15 @@ int numPredictOneRow(int num_iteration, boolean is_pred_leaf) {
166167
return num_preb_in_one_row;
167168
}
168169

169-
int getCurrentIteration() {
170+
public int getCurrentIteration() {
170171
return models_.size() / num_tree_per_iteration_;
171172
}
172173

173-
int maxFeatureIdx() {
174+
public int maxFeatureIdx() {
174175
return max_feature_idx_;
175176
}
176177

177-
List<Double> predictLeafIndex(SparseVector vector) {
178+
public List<Double> predictLeafIndex(SparseVector vector) {
178179
List<Double> outputs=new ArrayList<>();
179180
for (int i = 0; i < num_iteration_for_pred_; ++i) {
180181
for (int j = 0; j < num_class_; ++j) {
@@ -184,7 +185,7 @@ List<Double> predictLeafIndex(SparseVector vector) {
184185
return outputs;
185186
}
186187

187-
List<Double> predictRaw(SparseVector features, PredictionEarlyStopInstance early_stop) {
188+
public List<Double> predictRaw(SparseVector features, PredictionEarlyStopInstance early_stop) {
188189
double[] output = new double[num_class_];
189190
int early_stop_round_counter = 0;
190191
for (int i = 0; i < num_iteration_for_pred_; ++i) {
@@ -204,7 +205,7 @@ List<Double> predictRaw(SparseVector features, PredictionEarlyStopInstance early
204205
return ret;
205206
}
206207

207-
List<Double> predict(SparseVector features, PredictionEarlyStopInstance early_stop) {
208+
public List<Double> predict(SparseVector features, PredictionEarlyStopInstance early_stop) {
208209
List<Double> ret=predictRaw(features, early_stop);;
209210
if (objective_function_ != null) {
210211
double[]output=new double[ret.size()];

0 commit comments

Comments
 (0)