GPT-OSS Java Inference

A pure Java implementation of OpenAI's gpt-oss inference in ~1000 lines of code optimized for CPU execution.

Inspired by llama.cpp and llama2.c, this repo ports the gpt-oss PyTorch model.py to efficient Java code, emphasizing minimalism, simplicity, and educational value.

Key Features

  • Pure Java - No native dependencies, runs anywhere with Java 23+
  • Complete gpt-oss architecture - Full implementation of MoE transformer with GQA, sliding window attention, RoPE, and SwiGLU
  • CPU inference - No GPU required, designed for consumer-grade commodity hardware on local machines or cloud compute instances
  • Memory efficient - Runs on 16GB+ RAM using memory-mapped weights via the Java Foreign Memory API
  • MXFP4 dequantization - Handles the original MXFP4-quantized MoE weights (see the sketch after this list)
  • Performance optimized - Supports KV caching and exploits modern JDK GC/JIT, parallel processing, the SIMD Vector API, and fused operations
  • Educational - Clean, readable code for understanding LLM Transformer internals
  • Handy CLI - Interactive chat and single-shot generation modes
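
The MXFP4 handling mentioned above can be illustrated with a small dequantization sketch. It assumes the OCP Microscaling layout (32-element blocks of FP4 E2M1 values plus one 8-bit power-of-two scale per block); the nibble ordering and method names are illustrative assumptions, not the repo's exact loader:

// FP4 E2M1 magnitudes indexed by the 4-bit code (sign in the nibble's high bit).
static final float[] FP4_LUT = {
    +0.0f, +0.5f, +1.0f, +1.5f, +2.0f, +3.0f, +4.0f, +6.0f,
    -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f
};

// Dequantize one 32-element block: 16 packed bytes plus one E8M0 scale byte.
static void dequantBlock(byte[] packed, int off, int scaleByte, float[] out, int outOff) {
    float scale = Math.scalb(1.0f, (scaleByte & 0xFF) - 127);   // scale = 2^(e - 127)
    for (int i = 0; i < 16; i++) {
        int b = packed[off + i] & 0xFF;
        out[outOff + 2 * i]     = FP4_LUT[b & 0x0F] * scale;    // low nibble (assumed first)
        out[outOff + 2 * i + 1] = FP4_LUT[b >>> 4] * scale;     // high nibble
    }
}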

Requirements

  • Java 23+
  • Minimum 16GB memory, ideally 24GB+

Quick Start

1. Model Preparation

Download the gpt-oss model weights from Hugging Face Hub. Start with gpt-oss-20b as it runs efficiently on consumer CPU-based hardware.

pip install -U "huggingface_hub[cli]"

# gpt-oss-20b
huggingface-cli download openai/gpt-oss-20b --include "original/*" --local-dir gpt-oss-20b/

# gpt-oss-120b
huggingface-cli download openai/gpt-oss-120b --include "original/*" --local-dir gpt-oss-120b/

2. Build

Build the project to generate the executable JAR located at build/libs/gpt-oss-java-1.0.0-all.jar.

./gradlew build shadowJar

Note: you can download a JDK and tell the build which Java version to use with either method below:

  1. Create a gradle.properties and add org.gradle.java.home=/path/to/jdk-23+.
  2. Set the environment variable:
export JAVA_HOME=/path/to/jdk-23+

3. Run

java --add-modules jdk.incubator.vector -jar build/libs/gpt-oss-java-1.0.0-all.jar /path/to/gpt-oss/original/model.safetensors

Command Line Options:

Usage: java GPTOSSCli <model_path> [options]

Examples:
  java --add-modules jdk.incubator.vector -jar gpt-oss-java-1.0.0-all.jar /path/to/gpt-oss/model.safetensors -m generate -p "Hello world" -n 50
  java --add-modules jdk.incubator.vector -jar gpt-oss-java-1.0.0-all.jar /path/to/gpt-oss/model.safetensors -m chat -t 0.1
  java --add-modules jdk.incubator.vector -jar gpt-oss-java-1.0.0-all.jar /path/to/gpt-oss/model.safetensors -t 0.2 -n 32768 --multi-turn

Options:
  -m <mode>        Inference mode: 'generate' (single shot) | 'chat' (interactive multi-turn) [default: chat]
  -p <prompt>      Input prompt (required for generate mode)
  -n <tokens>      Maximum tokens to generate [default: 100]
  -t <temperature> Sampling temperature (0 to inf) [default: 0.1]
  -s <ids>         Stop token IDs (comma-separated) [default: 0,199999,200002]
  --debug          Enable debug logging [default: false]
  --multi-turn     Enable multi-turn conversation (chat mode only) [default: false]
  --model-size     gpt-oss model size 20b or 120b [default: 20b]

Examples:

# Interactive chat (default if -m not set)
java --add-modules jdk.incubator.vector \
  -jar gpt-oss-java-1.0.0-all.jar /path/to/gpt-oss-20b/original/model.safetensors \
  -m chat

# Keeps conversation history in one session
java --add-modules jdk.incubator.vector \
  -jar gpt-oss-java-1.0.0-all.jar /path/to/gpt-oss-20b/original/model.safetensors \
  -m chat \
  --multi-turn

# Single-shot generation with max of 100 tokens and temperature of 0.2
java --add-modules jdk.incubator.vector \
  -jar gpt-oss-java-1.0.0-all.jar /path/to/gpt-oss-20b/original/model.safetensors \
  -m generate \
  -t 0.2 \
  -p "Why do people use umbrellas when it rains?" \
  -n 100

# Override stop IDs (default: 0,199999,200002)
java --add-modules jdk.incubator.vector \
  -jar gpt-oss-java-1.0.0-all.jar /path/to/gpt-oss-20b/original/model.safetensors \
  -m generate \
  -p "Write a short story about AI." \
  -s 3392

# Debug logging to show performance metrics
java --add-modules jdk.incubator.vector \
  -jar gpt-oss-java-1.0.0-all.jar /path/to/gpt-oss-20b/original/model.safetensors \
  -m generate \
  -p "Explain TCP vs UDP" \
  --debug

# Switch to 120B
java --add-modules jdk.incubator.vector \
  -jar gpt-oss-java-1.0.0-all.jar /path/to/gpt-oss-120b/original/model.safetensors \
  --model-size 120b
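
The -t temperature flag used in the examples above scales the logits before softmax sampling; lower values make generation more deterministic. A minimal sketch of temperature sampling (hypothetical method, not the repo's actual sampler; treating t <= 0 as greedy is an assumption):

import java.util.random.RandomGenerator;

// Sample a token id from raw logits with temperature scaling (sketch only).
static int sampleToken(float[] logits, double temperature, RandomGenerator rng) {
    if (temperature <= 0) {                               // treat t <= 0 as greedy argmax (assumption)
        int best = 0;
        for (int i = 1; i < logits.length; i++) if (logits[i] > logits[best]) best = i;
        return best;
    }
    double max = Double.NEGATIVE_INFINITY;
    for (float l : logits) max = Math.max(max, l);        // subtract max for numerical stability
    double[] p = new double[logits.length];
    double sum = 0;
    for (int i = 0; i < logits.length; i++) {
        p[i] = Math.exp((logits[i] - max) / temperature); // softmax with temperature
        sum += p[i];
    }
    double r = rng.nextDouble() * sum;                    // inverse-CDF draw
    for (int i = 0; i < p.length; i++) {
        r -= p[i];
        if (r <= 0) return i;
    }
    return p.length - 1;
}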

Tuning knobs

Thread and memory configuration

Control the parallelism of matrix multiplication and scaled dot-product attention by setting the common ForkJoinPool size. By default, all available vCPU cores are used.

# Use 16 threads
java -Djava.util.concurrent.ForkJoinPool.common.parallelism=16 --add-modules jdk.incubator.vector \
  -jar gpt-oss-java-1.0.0-all.jar /path/to/gpt-oss/model.safetensors \
  -m chat
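
Parallel streams run on the common ForkJoinPool, which is why the flag above bounds the number of threads used for matmul. A minimal sketch of a row-parallel matrix-vector product with a Vector API inner loop (illustrative only, not the repo's actual kernel):

import java.util.stream.IntStream;
import jdk.incubator.vector.FloatVector;
import jdk.incubator.vector.VectorOperators;
import jdk.incubator.vector.VectorSpecies;

static final VectorSpecies<Float> F = FloatVector.SPECIES_PREFERRED;

// out[row] = dot(w[row], x); rows are distributed over the common ForkJoinPool.
static void matVec(float[][] w, float[] x, float[] out) {
    IntStream.range(0, w.length).parallel().forEach(row -> {
        var acc = FloatVector.zero(F);
        int j = 0;
        for (; j < F.loopBound(x.length); j += F.length()) {
            acc = FloatVector.fromArray(F, w[row], j)
                             .fma(FloatVector.fromArray(F, x, j), acc);  // fused multiply-add
        }
        float sum = acc.reduceLanes(VectorOperators.ADD);
        for (; j < x.length; j++) sum += w[row][j] * x[j];               // scalar tail
        out[row] = sum;
    });
}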

You can set the JVM heap size with -Xmx; about 16GB of heap is normally required. The memory-mapped MLP weights require an additional ~8GB of system memory outside the heap.
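
The mapped weights stay out of the Java heap because they are accessed through the Foreign Memory API. A minimal sketch of how such a mapping can be opened (hypothetical helper, not the repo's actual loader):

import java.io.IOException;
import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.ValueLayout;
import java.nio.channels.FileChannel;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;

// Map the whole safetensors file read-only; the OS pages the weights in on demand,
// so the mapping does not count against the -Xmx heap limit.
static MemorySegment mapWeights(Path file, Arena arena) throws IOException {
    try (FileChannel ch = FileChannel.open(file, StandardOpenOption.READ)) {
        return ch.map(FileChannel.MapMode.READ_ONLY, 0, ch.size(), arena);
    }
}

// Read a single unaligned float32 at a byte offset into the mapped segment.
static float readFloatAt(MemorySegment weights, long byteOffset) {
    return weights.get(ValueLayout.JAVA_FLOAT_UNALIGNED, byteOffset);
}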

KV Cache

KV cache allocation is sized from the max-tokens -n CLI parameter, with a default lower bound of 4096 tokens.
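
A minimal sketch of that sizing rule (field and parameter names here are hypothetical, not the repo's actual cache class):

// Allocate per-layer key/value caches large enough for max(-n, 4096) tokens.
record KvCache(float[][][] keys, float[][][] values) {
    static KvCache allocate(int maxTokens, int layers, int kvHeads, int headDim) {
        int contextLen = Math.max(4096, maxTokens);   // -n with a 4096-token floor
        return new KvCache(
            new float[layers][contextLen][kvHeads * headDim],
            new float[layers][contextLen][kvHeads * headDim]);
    }
}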

Performance Benchmarks

This Java implementation delivers the following CPU inference performance on gpt-oss-20b:

  • Apple M3 Pro (12 cores, 36GB RAM):
    • Decode: avg 8.7 tokens/sec
    • Prefill: avg 11.8 tokens/sec
  • AWS EC2 m5.4xlarge (Intel Xeon Platinum 8175M, 8 physical cores, 16 vCPUs, 64GB RAM):
    • Decode: avg 6.8 tokens/sec
    • Prefill: avg 10 tokens/sec

For detailed benchmark results, hardware, and performance data, see benchmark/README.md.

License

This project is licensed under the Apache-2.0 License.
