Skip to content

Commit a8b21f0

Browse files
committed
initial commit
0 parents  commit a8b21f0

File tree

8 files changed

+1425
-0
lines changed

8 files changed

+1425
-0
lines changed

.gitattributes

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
* text=auto

.gitignore

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
target/
2+
!.mvn/wrapper/maven-wrapper.jar
3+
!**/src/main/**/target/
4+
!**/src/test/**/target/
5+
6+
### IntelliJ IDEA ###
7+
.idea/
8+
*.iws
9+
*.iml
10+
*.ipr
11+
12+
### Eclipse ###
13+
.apt_generated
14+
.classpath
15+
.factorypath
16+
.project
17+
.settings
18+
.springBeans
19+
.sts4-cache
20+
21+
### NetBeans ###
22+
/nbproject/private/
23+
/nbbuild/
24+
/dist/
25+
/nbdist/
26+
/.nb-gradle/
27+
build/
28+
!**/src/main/**/build/
29+
!**/src/test/**/build/
30+
31+
### VS Code ###
32+
.vscode/
33+
34+
### Mac OS ###
35+
.DS_Store

pom.xml

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
<?xml version="1.0" encoding="UTF-8"?>
2+
<project xmlns="http://maven.apache.org/POM/4.0.0"
3+
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
4+
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
5+
<modelVersion>4.0.0</modelVersion>
6+
7+
<groupId>io.github.alexcheng1982</groupId>
8+
<artifactId>bird-classifier</artifactId>
9+
<version>1.0.0-SNAPSHOT</version>
10+
<name>Bird classifier</name>
11+
12+
<properties>
13+
<maven.compiler.source>21</maven.compiler.source>
14+
<maven.compiler.target>21</maven.compiler.target>
15+
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
16+
</properties>
17+
18+
<dependencies>
19+
<dependency>
20+
<groupId>com.microsoft.onnxruntime</groupId>
21+
<artifactId>onnxruntime</artifactId>
22+
<version>1.19.2</version>
23+
</dependency>
24+
<dependency>
25+
<groupId>com.fasterxml.jackson.core</groupId>
26+
<artifactId>jackson-databind</artifactId>
27+
<version>2.18.0</version>
28+
</dependency>
29+
<dependency>
30+
<groupId>info.picocli</groupId>
31+
<artifactId>picocli</artifactId>
32+
<version>4.7.6</version>
33+
</dependency>
34+
<dependency>
35+
<groupId>org.slf4j</groupId>
36+
<artifactId>slf4j-api</artifactId>
37+
<version>2.0.16</version>
38+
</dependency>
39+
<dependency>
40+
<groupId>ch.qos.logback</groupId>
41+
<artifactId>logback-classic</artifactId>
42+
<version>1.5.8</version>
43+
</dependency>
44+
</dependencies>
45+
46+
<build>
47+
<plugins>
48+
<plugin>
49+
<artifactId>maven-assembly-plugin</artifactId>
50+
<version>3.7.1</version>
51+
<configuration>
52+
<finalName>bird-classifier</finalName>
53+
<descriptorRefs>
54+
<descriptorRef>jar-with-dependencies</descriptorRef>
55+
</descriptorRefs>
56+
<archive>
57+
<manifest>
58+
<mainClass>io.github.alexcheng1982.birdclassifier.Cli
59+
</mainClass>
60+
</manifest>
61+
</archive>
62+
</configuration>
63+
<executions>
64+
<execution>
65+
<id>make-assembly</id>
66+
<phase>package</phase>
67+
<goals>
68+
<goal>single</goal>
69+
</goals>
70+
</execution>
71+
</executions>
72+
</plugin>
73+
</plugins>
74+
</build>
75+
76+
</project>
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
package io.github.alexcheng1982.birdclassifier;
2+
3+
import ai.onnxruntime.OnnxTensor;
4+
import ai.onnxruntime.OrtEnvironment;
5+
import ai.onnxruntime.OrtException;
6+
import ai.onnxruntime.OrtSession;
7+
import ai.onnxruntime.OrtSession.SessionOptions;
8+
import com.fasterxml.jackson.core.type.TypeReference;
9+
import com.fasterxml.jackson.databind.ObjectMapper;
10+
import java.awt.Color;
11+
import java.io.IOException;
12+
import java.io.InputStream;
13+
import java.net.URL;
14+
import java.nio.FloatBuffer;
15+
import java.util.Map;
16+
import javax.imageio.ImageIO;
17+
import org.slf4j.Logger;
18+
import org.slf4j.LoggerFactory;
19+
20+
public class Classifier {
21+
22+
private static final OrtEnvironment env = OrtEnvironment.getEnvironment();
23+
24+
private final int imageWidth;
25+
private final int imageHeight;
26+
27+
private static final Logger LOGGER = LoggerFactory.getLogger("Classifier");
28+
29+
public Classifier(int imageWidth, int imageHeight) {
30+
this.imageWidth = imageWidth;
31+
this.imageHeight = imageHeight;
32+
}
33+
34+
public String classify(URL url) throws OrtException, IOException {
35+
LOGGER.info("Classify {}", url);
36+
try (OrtSession.SessionOptions options = new SessionOptions();
37+
InputStream modelStream = getClass().getResourceAsStream("/model.onnx");
38+
OrtSession session = env.createSession(modelStream.readAllBytes(),
39+
options);
40+
InputStream configJsonStream = getClass().getResourceAsStream(
41+
"/config.json")
42+
) {
43+
var tensor = imageDataToTensor(url);
44+
var inputName = session.getInputNames().stream().toList().getFirst();
45+
var outputData = session.run(Map.of(
46+
inputName, tensor
47+
));
48+
var objectMapper = new ObjectMapper();
49+
var config = objectMapper.readValue(
50+
configJsonStream,
51+
new TypeReference<Map<String, Object>>() {
52+
});
53+
var id2label = (Map<?, ?>) config.get("id2label");
54+
try (var output = outputData.get(0)) {
55+
float[][] values = (float[][]) output.getValue();
56+
var result = argmax(values[0]);
57+
return (String) id2label.get(Integer.toString(result));
58+
}
59+
}
60+
}
61+
62+
OnnxTensor imageDataToTensor(URL url) throws IOException, OrtException {
63+
var bufferedImage = ImageUtils.resizeImage(ImageIO.read(url), imageWidth,
64+
imageHeight);
65+
var height = bufferedImage.getHeight();
66+
var width = bufferedImage.getWidth();
67+
var size = width * height;
68+
var r = new int[size];
69+
var g = new int[size];
70+
var b = new int[size];
71+
var index = 0;
72+
for (int h = 0; h < height; h++) {
73+
for (int w = 0; w < width; w++) {
74+
var color = new Color(bufferedImage.getRGB(w, h));
75+
r[index] = color.getRed();
76+
g[index] = color.getGreen();
77+
b[index] = color.getBlue();
78+
index++;
79+
}
80+
}
81+
var data = new int[r.length + g.length + b.length];
82+
System.arraycopy(r, 0, data, 0, r.length);
83+
System.arraycopy(g, 0, data, r.length, g.length);
84+
System.arraycopy(b, 0, data, r.length + g.length, b.length);
85+
int total = data.length;
86+
float[] float32Data = new float[total];
87+
for (int i = 0; i < total; i++) {
88+
float32Data[i] = (float) (data[i] / 255.0);
89+
}
90+
return OnnxTensor.createTensor(env, FloatBuffer.wrap(float32Data),
91+
new long[]{1, 3, imageWidth, imageWidth});
92+
}
93+
94+
int argmax(float[] floatData) {
95+
var index = 0;
96+
var max = Float.MIN_VALUE;
97+
for (int i = 0; i < floatData.length; i++) {
98+
if (floatData[i] > max) {
99+
max = floatData[i];
100+
index = i;
101+
}
102+
}
103+
return index;
104+
}
105+
}
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
package io.github.alexcheng1982.birdclassifier;
2+
3+
import java.net.URL;
4+
import java.util.concurrent.Callable;
5+
import org.slf4j.Logger;
6+
import org.slf4j.LoggerFactory;
7+
import picocli.CommandLine;
8+
import picocli.CommandLine.Parameters;
9+
10+
@CommandLine.Command(
11+
name = "bird-classifier",
12+
mixinStandardHelpOptions = true,
13+
version = "0.1.0",
14+
description = "Classify birds"
15+
)
16+
public class Cli implements Callable<String> {
17+
18+
private static final Logger LOGGER = LoggerFactory.getLogger("Cli");
19+
20+
@Parameters(index = "0")
21+
URL imageUrl;
22+
23+
@CommandLine.Option(
24+
names = {"-w", "--width"},
25+
defaultValue = "260",
26+
description = "Image width"
27+
)
28+
int imageWidth = 260;
29+
30+
@CommandLine.Option(
31+
names = {"-h", "--height"},
32+
defaultValue = "260",
33+
description = "Image height"
34+
)
35+
int imageHeight = 260;
36+
37+
@Override
38+
public String call() throws Exception {
39+
return new Classifier(imageWidth, imageHeight).classify(imageUrl)
40+
.toLowerCase();
41+
}
42+
43+
public static void main(String[] args) {
44+
try {
45+
var cmd = new CommandLine(new Cli());
46+
int exitCode = cmd.execute(args);
47+
String result = cmd.getExecutionResult();
48+
System.out.printf("%nClassified as : %s%n", result);
49+
System.exit(exitCode);
50+
} catch (Exception e) {
51+
LOGGER.error("Internal error: {}", e.getMessage());
52+
}
53+
}
54+
}
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
package io.github.alexcheng1982.birdclassifier;
2+
3+
import java.awt.image.BufferedImage;
4+
5+
public class ImageUtils {
6+
7+
public static BufferedImage resizeImage(BufferedImage sourceImage,
8+
int targetWidth,
9+
int targetHeight) {
10+
var image = sourceImage.getScaledInstance(targetWidth, targetHeight,
11+
BufferedImage.SCALE_AREA_AVERAGING);
12+
var outputImage = new BufferedImage(targetWidth, targetHeight,
13+
BufferedImage.TYPE_INT_ARGB);
14+
outputImage.getGraphics().drawImage(image, 0, 0, null);
15+
return outputImage;
16+
}
17+
}

0 commit comments

Comments
 (0)