Skip to content

Commit ed5d8a3

Browse files
author
hzlinyanggang
committed
predicting code
1 parent 2c55dc0 commit ed5d8a3

File tree

18 files changed

+37706
-0
lines changed

18 files changed

+37706
-0
lines changed

.gitignore

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
11
*.class
2+
.classpath
3+
.project
4+
/.settings/
25

36
# Mobile Tools for Java (J2ME)
47
.mtj.tmp/
@@ -10,3 +13,4 @@
1013

1114
# virtual machine crash logs, see http://www.java.com/en/download/help/error_hotspot.xml
1215
hs_err_pid*
16+
/target/

pom.xml

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
<?xml version="1.0"?>
<project
    xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"
    xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance">
  <modelVersion>4.0.0</modelVersion>

  <groupId>org.lightgbm.predict4j</groupId>
  <artifactId>lightgbm_predict4j</artifactId>
  <version>1.0</version>
  <packaging>jar</packaging>

  <name>lightgbm_predict4j</name>
  <url>http://maven.apache.org</url>

  <dependencies>
    <!-- File/stream helpers (IOUtils) used when reading model and config files -->
    <dependency>
      <groupId>commons-io</groupId>
      <artifactId>commons-io</artifactId>
      <version>2.4</version>
    </dependency>
    <!-- https://mvnrepository.com/artifact/org.slf4j/slf4j-api -->
    <dependency>
      <groupId>org.slf4j</groupId>
      <artifactId>slf4j-api</artifactId>
      <version>1.7.21</version>
    </dependency>
    <!-- Declared once: a second identical declaration makes Maven warn that
         dependency coordinates must be unique within a POM. -->
    <dependency>
      <groupId>org.slf4j</groupId>
      <artifactId>log4j-over-slf4j</artifactId>
      <version>1.7.7</version>
    </dependency>
    <dependency>
      <groupId>ch.qos.logback</groupId>
      <artifactId>logback-core</artifactId>
      <version>1.1.2</version>
    </dependency>
    <dependency>
      <groupId>ch.qos.logback</groupId>
      <artifactId>logback-access</artifactId>
      <version>1.1.2</version>
    </dependency>
    <dependency>
      <groupId>ch.qos.logback</groupId>
      <artifactId>logback-classic</artifactId>
      <version>1.1.2</version>
    </dependency>
    <dependency>
      <groupId>junit</groupId>
      <artifactId>junit</artifactId>
      <version>4.12</version>
      <scope>test</scope>
    </dependency>
  </dependencies>

  <build>
    <finalName>lightgbm_predict4j</finalName>
    <resources>
      <resource>
        <directory>src/main/resources</directory>
      </resource>
    </resources>
    <testResources>
      <testResource>
        <directory>src/test/resources</directory>
      </testResource>
    </testResources>
    <plugins>
      <!-- Sources are compiled for Java 7 -->
      <plugin>
        <groupId>org.apache.maven.plugins</groupId>
        <artifactId>maven-compiler-plugin</artifactId>
        <configuration>
          <compilerVersion>1.7</compilerVersion>
          <source>1.7</source>
          <target>1.7</target>
        </configuration>
      </plugin>
    </plugins>
  </build>

</project>
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
package org.lightgbm.predict4j;
2+
3+
import java.io.FileInputStream;
4+
import java.io.FileNotFoundException;
5+
import java.io.IOException;
6+
import java.util.HashMap;
7+
import java.util.List;
8+
import java.util.Map;
9+
10+
import org.apache.commons.io.IOUtils;
11+
import org.slf4j.Logger;
12+
import org.slf4j.LoggerFactory;
13+
14+
/**
15+
* @author lyg5623
16+
*/
17+
public class Application {
18+
private static final Logger logger = LoggerFactory.getLogger(Application.class);
19+
private OverallConfig config = new OverallConfig();
20+
21+
public static void main(String[] args) throws FileNotFoundException, IOException {
22+
args = "config=cluster_test.conf".split("\\s+");
23+
String modelPath = "LightGBM_model.txt";
24+
Application app = new Application(args);
25+
String dataPath = "lightgbm_test.txt";
26+
String outputPath = "LightGBM_predict_result.txt";
27+
app.run(modelPath, dataPath, outputPath);
28+
}
29+
30+
public Application(String[] argv) throws FileNotFoundException, IOException {
31+
loadParameters(argv);
32+
}
33+
34+
private void loadParameters(String[] argv) throws FileNotFoundException, IOException {
35+
Map<String, String> params = new HashMap<String, String>();
36+
for (int i = 0; i < argv.length; ++i) {
37+
String[] tmp_strs = argv[i].split("=");
38+
if (tmp_strs.length == 2) {
39+
String key = tmp_strs[0].trim();
40+
String value = tmp_strs[1].trim();
41+
if (key.length() <= 0) {
42+
continue;
43+
}
44+
params.put(key, value);
45+
} else {
46+
logger.warn(String.format("Unknown parameter in command line: %s", argv[i]));
47+
}
48+
}
49+
// check for alias
50+
ParameterAlias.keyAliasTransform(params);
51+
// read parameters from config file
52+
if (params.containsKey("config_file")) {
53+
List<String> lines = IOUtils.readLines(new FileInputStream(params.get("config_file")));
54+
if (lines != null) {
55+
for (String line : lines) {
56+
line = line.trim();
57+
// remove str after "#"
58+
if (line.startsWith("#"))
59+
continue;
60+
if (line.length() == 0) {
61+
continue;
62+
}
63+
String[] tmp_strs = line.split("=");
64+
if (tmp_strs.length == 2) {
65+
String key = tmp_strs[0].trim();
66+
String value = tmp_strs[1].trim();
67+
if (key.length() <= 0) {
68+
continue;
69+
}
70+
// Command-line has higher priority
71+
if (!params.containsKey(key))
72+
params.put(key, value);
73+
} else {
74+
logger.warn("Unknown parameter in config file: " + line);
75+
}
76+
}
77+
}
78+
}
79+
// check for alias again
80+
ParameterAlias.keyAliasTransform(params);
81+
// load configs
82+
config.set(params);
83+
logger.info("Finished loading parameters");
84+
}
85+
86+
87+
private void run(String modelPath, String dataPath, String outputPath) throws FileNotFoundException, IOException {
88+
Boosting boosting = initPredict(modelPath);
89+
predict(dataPath, outputPath, boosting);
90+
}
91+
92+
private Boosting initPredict(String modelPath) throws FileNotFoundException, IOException {
93+
Boosting boosting = Boosting.createBoosting(modelPath);
94+
logger.info("Finished initializing prediction");
95+
return boosting;
96+
}
97+
98+
private void predict(String dataPath, String outputPath, Boosting boosting) throws IOException {
99+
boosting.setNumIterationForPred(config.getIoConfig().getNumIterationPredict());
100+
// create predictor
101+
Predictor predictor = new Predictor(boosting, config.getIoConfig().isPredictRawScore(),
102+
config.getIoConfig().isPredictLeafIndex());
103+
predictor.Predict(dataPath, outputPath);
104+
logger.info("Finished prediction");
105+
}
106+
107+
}
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
package org.lightgbm.predict4j;
2+
3+
import java.io.FileInputStream;
4+
import java.io.FileNotFoundException;
5+
import java.io.IOException;
6+
import java.io.Serializable;
7+
import java.util.List;
8+
9+
import org.apache.commons.io.IOUtils;
10+
import org.slf4j.Logger;
11+
import org.slf4j.LoggerFactory;
12+
13+
/**
14+
* @author lyg5623
15+
*/
16+
public abstract class Boosting implements Serializable {
17+
private static final Logger logger = LoggerFactory.getLogger(Boosting.class);
18+
private static final long serialVersionUID = -661844499486913306L;
19+
20+
public static Boosting createBoosting(String modelPath) throws FileNotFoundException, IOException {
21+
Boosting boosting = null;
22+
String type = getBoostingTypeFromModelFile(modelPath);
23+
if (type.equals("tree")) {
24+
boosting = new GBDT();
25+
} else {
26+
logger.error("unknow submodel type in model file " + modelPath);
27+
}
28+
loadFileToBoosting(boosting, modelPath);
29+
return boosting;
30+
}
31+
32+
33+
private static boolean loadFileToBoosting(Boosting boosting, String modelPath)
34+
throws FileNotFoundException, IOException {
35+
if (boosting != null) {
36+
StringBuilder sb = new StringBuilder();
37+
List<String> lines = IOUtils.readLines(new FileInputStream(modelPath));
38+
for (String line : lines) {
39+
sb.append(line).append("\n");
40+
}
41+
if (!boosting.loadModelFromString(sb.toString()))
42+
return false;
43+
}
44+
45+
return true;
46+
}
47+
48+
public abstract boolean loadModelFromString(String modelStr);
49+
50+
private static String getBoostingTypeFromModelFile(String modelPath) throws FileNotFoundException, IOException {
51+
List<String> lines = IOUtils.readLines(new FileInputStream(modelPath));
52+
return lines.get(0);
53+
}
54+
55+
public abstract void setNumIterationForPred(int numIteration);
56+
57+
public abstract List<Integer> predictLeafIndex(SparseVector vector);
58+
59+
public abstract List<Double> predictRaw(SparseVector vector);
60+
61+
public abstract List<Double> predict(SparseVector vector);
62+
}
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
package org.lightgbm.predict4j;
2+
3+
import java.util.List;
4+
5+
import org.slf4j.Logger;
6+
import org.slf4j.LoggerFactory;
7+
8+
/**
9+
* @author lyg5623
10+
*/
11+
public class Common {
12+
private static final Logger logger = LoggerFactory.getLogger(Common.class);
13+
14+
public static String findFromLines(String[] lines, String keyWord) {
15+
for (String line : lines) {
16+
if (line.contains(keyWord))
17+
return line;
18+
}
19+
return "";
20+
}
21+
22+
public static String join(String[] strs, String delimiter) {
23+
if (strs == null || strs.length == 0) {
24+
return "";
25+
}
26+
StringBuilder strBuf = new StringBuilder();
27+
strBuf.append(strs[0]);
28+
for (int i = 1; i < strs.length; ++i) {
29+
strBuf.append(delimiter).append(strs[i]);
30+
}
31+
return strBuf.toString();
32+
}
33+
34+
public static String join(List<Double> strs, String delimiter) {
35+
if (strs == null || strs.size() == 0) {
36+
return "";
37+
}
38+
StringBuilder strBuf = new StringBuilder();
39+
strBuf.append(strs.get(0));
40+
for (int i = 1; i < strs.size(); ++i) {
41+
strBuf.append(delimiter).append(strs.get(i));
42+
}
43+
return strBuf.toString();
44+
}
45+
46+
public static String join(String[] strs, int start, int end, String delimiter) {
47+
if (end - start <= 0) {
48+
return "";
49+
}
50+
start = Math.min(start, strs.length - 1);
51+
end = Math.min(end, strs.length);
52+
StringBuilder strBuf = new StringBuilder();
53+
strBuf.append(strs[start]);
54+
for (int i = start + 1; i < end; ++i) {
55+
strBuf.append(delimiter).append(strs[i]);
56+
}
57+
return strBuf.toString();
58+
}
59+
60+
61+
public static int[] stringToArrayInt(String str, String delimiter, int n) {
62+
String[] strs = str.split(delimiter);
63+
if (strs.length != n) {
64+
logger.error("StringToArray error, size doesn't match.");
65+
}
66+
int[] ret = new int[n];
67+
for (int i = 0; i < n; ++i) {
68+
ret[i] = Integer.parseInt(strs[i]);
69+
}
70+
return ret;
71+
}
72+
73+
public static double[] stringToArrayDouble(String str, String delimiter, int n) {
74+
String[] strs = str.split(delimiter);
75+
if (strs.length != n) {
76+
logger.error("StringToArray error, size doesn't match.");
77+
}
78+
double[] ret = new double[n];
79+
for (int i = 0; i < n; ++i) {
80+
ret[i] = Double.parseDouble(strs[i]);
81+
}
82+
return ret;
83+
}
84+
85+
86+
/*
87+
* ! \brief Do inplace softmax transformaton on p_rec \param p_rec The input/output vector of the values.
88+
*/
89+
public static void softmax(double[] pRec) {
90+
double[] rec = pRec;
91+
double wmax = rec[0];
92+
for (int i = 1; i < rec.length; ++i) {
93+
wmax = Math.max(rec[i], wmax);
94+
}
95+
double wsum = 0.0f;
96+
for (int i = 0; i < rec.length; ++i) {
97+
rec[i] = Math.exp(rec[i] - wmax);
98+
wsum += rec[i];
99+
}
100+
for (int i = 0; i < rec.length; ++i) {
101+
rec[i] /= wsum;
102+
}
103+
}
104+
}

0 commit comments

Comments
 (0)