A pure Java implementation of OpenAI's gpt-oss inference in ~1000 lines of code, optimized for CPU execution.
Inspired by llama.cpp and llama2.c, this repo ports the gpt-oss PyTorch model.py to efficient Java code, emphasizing minimalism, simplicity, and educational value.
- Pure Java - No native dependencies, runs anywhere with Java 23+
- Complete gpt-oss architecture - Full implementation of the MoE transformer with GQA, sliding-window attention, RoPE, and SwiGLU
- CPU inference - No GPU required; designed for commodity hardware, whether local machines or cloud compute instances
- Memory efficient - Runs in 16GB+ RAM using memory-mapped weights via the Java Foreign Memory API
- MXFP4 dequantization - Handles the original MXFP4-quantized MoE weights (see the dequantization sketch after this list)
- Performance optimized - Supports a KV cache and exploits the modern JDK GC/JIT, parallel processing, the SIMD Vector API, and fused operations (see the dot-product sketch after this list)
- Educational - Clean, readable code for understanding LLM Transformer internals
- Handy CLI - Interactive chat and single-shot generation modes
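As a taste of the Vector API use mentioned above, here is a minimal sketch of a SIMD dot product with a fused multiply-add; the class and method names are illustrative rather than the repo's actual API:

```java
import jdk.incubator.vector.FloatVector;
import jdk.incubator.vector.VectorOperators;
import jdk.incubator.vector.VectorSpecies;

// Minimal SIMD dot product; requires --add-modules jdk.incubator.vector.
final class DotProduct {
    private static final VectorSpecies<Float> SPECIES = FloatVector.SPECIES_PREFERRED;

    static float dot(float[] a, float[] b) {
        var acc = FloatVector.zero(SPECIES);
        int i = 0;
        for (; i < SPECIES.loopBound(a.length); i += SPECIES.length()) {
            var va = FloatVector.fromArray(SPECIES, a, i);
            var vb = FloatVector.fromArray(SPECIES, b, i);
            acc = va.fma(vb, acc); // fused multiply-add across all lanes
        }
        float sum = acc.reduceLanes(VectorOperators.ADD);
        for (; i < a.length; i++) sum += a[i] * b[i]; // scalar tail
        return sum;
    }
}
```

This is also why every java invocation below passes --add-modules jdk.incubator.vector: the Vector API is still incubating.

And here is a sketch of MXFP4 dequantization as the OCP Microscaling format defines it (blocks of 32 E2M1 elements sharing one E8M0 scale); the nibble order and names below are assumptions, not necessarily this repo's layout:

```java
// Hedged sketch of MXFP4 (E2M1 elements + E8M0 block scale) dequantization.
final class Mxfp4 {
    // The 16 E2M1 code points: sign x {0, 0.5, 1, 1.5, 2, 3, 4, 6}.
    private static final float[] FP4 = {
        0f, 0.5f, 1f, 1.5f, 2f, 3f, 4f, 6f,
        -0f, -0.5f, -1f, -1.5f, -2f, -3f, -4f, -6f
    };

    // Expands one 32-element block: 16 packed bytes plus one E8M0 scale byte.
    static void dequantBlock(byte[] packed, int off, int scaleByte, float[] out, int outOff) {
        float scale = Math.scalb(1f, (scaleByte & 0xFF) - 127); // E8M0 scale = 2^(e - 127)
        for (int i = 0; i < 16; i++) {
            int b = packed[off + i] & 0xFF;
            out[outOff + 2 * i]     = FP4[b & 0x0F] * scale;         // low nibble (order assumed)
            out[outOff + 2 * i + 1] = FP4[(b >>> 4) & 0x0F] * scale; // high nibble
        }
    }
}
```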
Requirements:
- Java 23+
- Minimum 16GB memory, ideally 24GB+
Download the gpt-oss model weights from Hugging Face Hub. Start with gpt-oss-20b as it runs efficiently on consumer CPU-based hardware.
```bash
pip install -U "huggingface_hub[cli]"

# gpt-oss-20b
huggingface-cli download openai/gpt-oss-20b --include "original/*" --local-dir gpt-oss-20b/

# gpt-oss-120b
huggingface-cli download openai/gpt-oss-120b --include "original/*" --local-dir gpt-oss-120b/
```

Build the project to generate the executable JAR located at `build/libs/gpt-oss-java-1.0.0-all.jar`.
```bash
./gradlew build shadowJar
```

Note: you can download a JDK and configure the Java version using either method below:
- Create a `gradle.properties` file and add `org.gradle.java.home=/path/to/jdk-23+`.
- Set the environment variable: `export JAVA_HOME=/path/to/jdk-23+`
Run the model with the Vector API module enabled:

```bash
java --add-modules jdk.incubator.vector -jar build/libs/gpt-oss-java-1.0.0-all.jar /path/to/gpt-oss/original/model.safetensors
```

Command Line Options:
```text
Usage: java GPTOSSCli <model_path> [options]

Examples:
  java --add-modules jdk.incubator.vector -jar gpt-oss-java-1.0.0-all.jar /path/to/gpt-oss/model.safetensors -m generate -p "Hello world" -n 50
  java --add-modules jdk.incubator.vector -jar gpt-oss-java-1.0.0-all.jar /path/to/gpt-oss/model.safetensors -m chat -t 0.1
  java --add-modules jdk.incubator.vector -jar gpt-oss-java-1.0.0-all.jar /path/to/gpt-oss/model.safetensors -t 0.2 -n 32768 --multi-turn

Options:
  -m <mode>         Inference mode: 'generate' (single shot) | 'chat' (interactive multi-turn) [default: chat]
  -p <prompt>       Input prompt (required for generate mode)
  -n <tokens>       Maximum tokens to generate [default: 100]
  -t <temperature>  Sampling temperature (0 to inf) [default: 0.1]
  -s <ids>          Stop token IDs (comma-separated) [default: 0,199999,200002]
  --debug           Enable debug logging [default: false]
  --multi-turn      Enable multi-turn conversation (chat mode only) [default: false]
  --model-size      gpt-oss model size: 20b or 120b [default: 20b]
```

Examples:

```bash
# Interactive chat (default if -m not set)
java --add-modules jdk.incubator.vector \
-jar gpt-oss-java-1.0.0-all.jar /path/to/gpt-oss-20b/original/model.safetensors \
-m chat
# Keeps conversation history in one session
java --add-modules jdk.incubator.vector \
-jar gpt-oss-java-1.0.0-all.jar /path/to/gpt-oss-20b/original/model.safetensors \
-m chat \
--multi-turn
# Single-shot generation with max of 100 tokens and temperature of 0.2
java --add-modules jdk.incubator.vector \
-jar gpt-oss-java-1.0.0-all.jar /path/to/gpt-oss-20b/original/model.safetensors \
-m generate \
-t 0.2 \
-p "Why do people use umbrellas when it rains?" \
-n 100
# Override stop IDs (default: 0,199999,200002)
java --add-modules jdk.incubator.vector \
-jar gpt-oss-java-1.0.0-all.jar /path/to/gpt-oss-20b/original/model.safetensors \
-m generate \
-p "Write a short story about AI." \
-s 3392
# Debug logging to show performance metrics
java --add-modules jdk.incubator.vector \
-jar gpt-oss-java-1.0.0-all.jar /path/to/gpt-oss-20b/original/model.safetensors \
-m generate \
-p "Explain TCP vs UDP" \
--debug
# Switch to 120B
java --add-modules jdk.incubator.vector \
-jar gpt-oss-java-1.0.0-all.jar /path/to/gpt-oss-120b/original/model.safetensors \
--model-size 120b
```

Control matrix multiplication and scaled dot-product attention parallelism by setting the ForkJoinPool size. By default, it uses all available vCPU cores.

```bash
# Use 16 threads
java -Djava.util.concurrent.ForkJoinPool.common.parallelism=16 --add-modules jdk.incubator.vector \
-jar gpt-oss-java-1.0.0-all.jar /path/to/gpt-oss/model.safetensors \
-m chat
```

You can specify `-Xmx` in the JVM options; inference normally requires about 16GB of JVM heap memory.
The memory-mapped MLP weights require an additional ~8GB of system memory.
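Those weights stay off the JVM heap because the safetensors file is memory-mapped; below is a minimal sketch of that idea with the Foreign Memory API (the class and method are illustrative, not the repo's actual loader):

```java
import java.io.IOException;
import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.nio.channels.FileChannel;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;

// Maps a weights file into an off-heap MemorySegment. Pages are faulted in
// lazily by the OS, so the heap holds only activations and the KV cache.
final class MmapWeights {
    static MemorySegment map(Path safetensors) throws IOException {
        try (FileChannel ch = FileChannel.open(safetensors, StandardOpenOption.READ)) {
            return ch.map(FileChannel.MapMode.READ_ONLY, 0, ch.size(), Arena.ofAuto());
        }
    }
}
```

Tensor values can then be read straight from the segment (e.g. `MemorySegment.get` with `ValueLayout.JAVA_FLOAT`) without copying them onto the heap.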
KV cache allocation scales with the max-tokens `-n` CLI parameter, with a default lower bound of 4096 tokens.
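For a rough sense of how the cache grows with `-n`, here is a back-of-the-envelope sizing sketch; the layer and head dimensions are assumptions about gpt-oss-20b, and an fp32 cache is assumed:

```java
// Back-of-the-envelope KV cache sizing; all model dimensions are assumptions.
public class KvCacheSize {
    public static void main(String[] args) {
        long layers = 24, kvHeads = 8, headDim = 64; // assumed gpt-oss-20b shape
        long maxTokens = 4096;                       // default -n lower bound
        long bytes = 2 * layers * kvHeads * headDim * maxTokens * 4; // K and V, fp32
        System.out.printf("KV cache @ %d tokens: %.2f GiB%n",
                maxTokens, bytes / (1024.0 * 1024 * 1024));
        // ~0.38 GiB at 4096 tokens; grows linearly with the -n value.
    }
}
```

Under these assumptions the cache is a few hundred MiB at the default window, so the weights, not the KV cache, dominate the memory budget.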
This Java implementation delivers the following CPU inference performance on gpt-oss-20b:
- Apple M3 Pro (12 cores, 36GB RAM):
  - Decode: avg 8.7 tokens/sec
  - Prefill: avg 11.8 tokens/sec
- AWS EC2 m5.4xlarge (Intel Xeon Platinum 8175M, 8 physical cores, 16 vCPUs, 64GB RAM):
  - Decode: avg 6.8 tokens/sec
  - Prefill: avg 10 tokens/sec
For detailed benchmark results and hardware configurations, see benchmark/README.md.
This project is licensed under the Apache-2.0 License.