16 changes: 16 additions & 0 deletions README.md
@@ -244,6 +244,13 @@ Download `FP16` quantized `Qwen3` .gguf files from:
- https://huggingface.co/ggml-org/Qwen3-4B-GGUF
- https://huggingface.co/ggml-org/Qwen3-8B-GGUF

Download `FP16` quantized `Qwen2.5` .gguf files from:
- https://huggingface.co/bartowski/Qwen2.5-0.5B-Instruct-GGUF
- https://huggingface.co/Qwen/Qwen2.5-1.5B-Instruct-GGUF

Download `FP16` quantized `DeepSeek-R1-Distill-Qwen` .gguf files from:
- https://huggingface.co/hdnh2006/DeepSeek-R1-Distill-Qwen-1.5B-GGUF

Please be gentle with [huggingface.co](https://huggingface.co) servers:

**Note** FP16 models are first-class citizens for the current version.
@@ -274,6 +281,15 @@ wget https://huggingface.co/ggml-org/Qwen3-0.6B-GGUF/resolve/main/Qwen3-8B-f16.g

# Phi-3-mini-4k - FP16
wget https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-gguf/resolve/main/Phi-3-mini-4k-instruct-fp16.gguf

# Qwen2.5 (0.5B)
wget https://huggingface.co/bartowski/Qwen2.5-0.5B-Instruct-GGUF/resolve/main/Qwen2.5-0.5B-Instruct-f16.gguf

# Qwen2.5 (1.5B)
wget https://huggingface.co/Qwen/Qwen2.5-1.5B-Instruct-GGUF/resolve/main/qwen2.5-1.5b-instruct-fp16.gguf

# DeepSeek-R1-Distill-Qwen (1.5B)
wget https://huggingface.co/hdnh2006/DeepSeek-R1-Distill-Qwen-1.5B-GGUF/resolve/main/DeepSeek-R1-Distill-Qwen-1.5B-F16.gguf
```

**[Experimental]** you can download the Q8 and Q4 quantized models used in the original implementation of Llama3.java, but for now they are dequantized to FP16 for TornadoVM support:
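As a rough illustration of the dequantization mentioned above: GGML's `Q8_0` format packs weights into blocks of 32 signed bytes preceded by a single FP16 scale, so converting back to floats is one multiply per value. The sketch below is illustrative only; the class and method names are assumptions, not this repository's actual tensor API.

```java
// Minimal sketch: dequantize GGML Q8_0 blocks (one fp16 scale + 32 int8 quants)
// back to float. Assumes a little-endian host (GGUF files are little-endian)
// and Java 20+ for Float.float16ToFloat. Hypothetical helper, not repo API.
import java.lang.foreign.MemorySegment;
import java.lang.foreign.ValueLayout;

public final class Q8_0Dequant {
    static final int QK = 32;              // quants per block
    static final int BLOCK_BYTES = 2 + QK; // 2-byte fp16 scale + 32 int8 values

    static float[] dequantize(MemorySegment src, int count) {
        float[] out = new float[count];
        for (int b = 0; b < count / QK; b++) {
            long base = (long) b * BLOCK_BYTES;
            // per-block scale is stored as IEEE-754 half precision
            float scale = Float.float16ToFloat(src.get(ValueLayout.JAVA_SHORT, base));
            for (int i = 0; i < QK; i++) {
                byte q = src.get(ValueLayout.JAVA_BYTE, base + 2 + i);
                out[b * QK + i] = scale * q; // x = d * q
            }
        }
        return out;
    }
}
```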
134 changes: 134 additions & 0 deletions src/main/java/org/beehive/gpullama3/inference/InferenceCore.java
@@ -4,15 +4,18 @@
import org.beehive.gpullama3.core.model.tensor.FloatTensor;
import org.beehive.gpullama3.inference.state.Phi3State;
import org.beehive.gpullama3.inference.state.State;
import org.beehive.gpullama3.inference.weights.standard.Qwen2StandardWeights;
import org.beehive.gpullama3.inference.weights.standard.Phi3StandardWeights;
import org.beehive.gpullama3.inference.weights.standard.Qwen3StandardWeights;
import org.beehive.gpullama3.inference.weights.standard.StandardWeights;
import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights;
import org.beehive.gpullama3.model.Configuration;
import org.beehive.gpullama3.model.Model;
import org.beehive.gpullama3.model.qwen2.Qwen2Configuration;
import org.beehive.gpullama3.model.phi3.Phi3Configuration;
import org.beehive.gpullama3.model.qwen3.Qwen3Configuration;
import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan;

import uk.ac.manchester.tornado.api.types.arrays.FloatArray;

import java.lang.foreign.MemorySegment;
@@ -176,6 +179,137 @@ public static FloatTensor forwardJava(Model model, State state, int token, int pos
return state.logits;
}

public static FloatTensor forwardJavaQwen2(Model model, State state, int token, int position) {
final Qwen2Configuration config = (Qwen2Configuration) model.configuration();
final Qwen2StandardWeights weights = (Qwen2StandardWeights) model.weights();
int dim = config.dim();
int headSize = config.headSize();
int kvDim = (config.dim() * config.numberOfKeyValueHeads()) / config.numberOfHeads();
int kvMul = config.numberOfHeads() / config.numberOfKeyValueHeads(); // integer multiplier of the kv sharing in multiquery
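// e.g., for Qwen2.5-0.5B (dim = 896, 14 query heads, 2 kv heads):
// kvDim = 896 * 2 / 14 = 128, and kvMul = 7 query heads share each kv head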
float sqrtHeadSize = (float) Math.sqrt(headSize);

weights.token_embedding_table.copyTo(token * dim, state.x, 0, dim);

// forward all the layers
for (int l = 0; l < config.numberOfLayers(); l++) {
// attention rmsnorm
final int curLayer = l;
rmsnorm(state.xb, state.x, weights.rms_att_weight[curLayer], 0, dim, config.rmsNormEps());

// qkv matmuls for this position
weights.wq[l].matmul(state.xb, state.q, dim, dim);
weights.wk[l].matmul(state.xb, state.k, kvDim, dim);
weights.wv[l].matmul(state.xb, state.v, kvDim, dim);

// qkv additions with qkv bias
state.q.addInPlace(weights.q_bias[curLayer]);
state.k.addInPlace(weights.k_bias[curLayer]);
state.v.addInPlace(weights.v_bias[curLayer]);

// RoPE relative positional encoding: complex-valued rotate q and k in each head
// GPT-NeoX style RoPE, real/imaginary components are stored with a headSize/2 offset per head, instead of consecutive.
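// Each pair (v0, v1), stored headSize/2 apart, is rotated by the angle for frequency index ic:
//   v0' = v0 * cos - v1 * sin
//   v1' = v0 * sin + v1 * cos
// with cos/sin read from the precomputed freq_cis_real/freq_cis_imag tables.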
for (int h = 0; h < config.numberOfHeads(); ++h) {
int rotn = h < config.numberOfKeyValueHeads() ? 2 : 1; // how many vectors? 2 = q & k, 1 = q only
int poffset = h * headSize;
for (int i0 = 0; i0 < headSize; i0 += 2) {
int ic = i0 / 2;
float fcr = weights.freq_cis_real.getFloat((position) * (headSize / 2) + ic);
float fci = weights.freq_cis_imag.getFloat((position) * (headSize / 2) + ic);
for (int vi = 0; vi < rotn; vi++) {
FloatTensor vec = (vi == 0) ? state.q : state.k; // the vector to rotate (query or key)
float v0 = vec.getFloat(poffset + ic);
float v1 = vec.getFloat(poffset + ic + headSize/2);
vec.setFloat(poffset + ic, v0 * fcr - v1 * fci);
vec.setFloat(poffset + ic + headSize/2, v0 * fci + v1 * fcr);
}
}
}

// save key,value at this time step (position) to our kv cache
//int loff = l * config.seq_len * kvDim; // kv cache layer offset for convenience
state.k.copyTo(0, state.keyCache[curLayer], position * kvDim, kvDim);
state.v.copyTo(0, state.valueCache[curLayer], position * kvDim, kvDim);

// multihead attention. iterate over all heads
Parallel.parallelFor(0, config.numberOfHeads(), h -> {
// get the query vector for this head
// float* q = s.q + h * headSize;
int qOffset = h * headSize;

// attention scores for this head
// float* att = s.att + h * config.seq_len;
int attOffset = h * config.contextLength();

// iterate over all timesteps, including the current one
for (int t = 0; t <= position; t++) {
// get the key vector for this head and at this timestep
// float* k = s.key_cache + loff + t * dim + h * headSize;
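// GQA: query head h reads the cache of its shared kv head, h / kvMul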
int keyCacheOffset = /* loff + */ t * kvDim + (h / kvMul) * headSize;
// calculate the attention score as the dot product of q and k
float score = state.q.dot(qOffset, state.keyCache[curLayer], keyCacheOffset, headSize);
score /= sqrtHeadSize;
// save the score to the attention buffer
state.att.setFloat(attOffset + t, score);
}

// softmax the scores to get attention weights, from 0..position inclusively
state.att.softmaxInPlace(attOffset, position + 1);

// weighted sum of the values, store back into xb
// float* xb = s.xb + h * headSize;
int xbOffset = h * headSize;
// memset(xb, 0, headSize * sizeof(float));
state.xb.fillInPlace(xbOffset, headSize, 0f);

for (int t = 0; t <= position; t++) {
// get the value vector for this head and at this timestep
// float* v = s.value_cache + loff + t * dim + h * headSize;
int vOffset = /* loff + */ t * kvDim + (h / kvMul) * headSize;
// get the attention weight for this timestep
float a = state.att.getFloat(attOffset + t);
// accumulate the weighted value into xb
state.xb.saxpyInPlace(xbOffset, state.valueCache[curLayer], vOffset, headSize, a);
}
});

// final matmul to get the output of the attention
weights.wo[l].matmul(state.xb, state.xb2, dim, dim);

// residual connection back into x
state.x.addInPlace(state.xb2);

// ffn rmsnorm
rmsnorm(state.xb, state.x, weights.rms_ffn_weight[curLayer], 0, dim, config.rmsNormEps());

// Now for FFN in PyTorch we have: self.w2(F.silu(self.w1(x)) * self.w3(x))
// first calculate self.w1(x) and self.w3(x)
weights.w1[l].matmul(state.xb, state.hb, config.hiddenDim(), dim);
weights.w3[l].matmul(state.xb, state.hb2, config.hiddenDim(), dim);

// SwiGLU non-linearity
// silu(x)=x*σ(x), where σ(x) is the logistic sigmoid
state.hb.mapInPlace(value -> value / (float) (1.0 + Math.exp(-value)));

// elementwise multiply with w3(x)
state.hb.multiplyInPlace(state.hb2);

// final matmul to get the output of the ffn
weights.w2[l].matmul(state.hb, state.xb, dim, config.hiddenDim());

// residual connection
state.x.addInPlace(state.xb);

}

// final rmsnorm
rmsnorm(state.x, state.x, weights.rms_final_weight, 0, dim, config.rmsNormEps());

// classifier into logits
weights.wcls.matmul(state.x, state.logits, config.vocabularySize(), dim);

return state.logits;
}

public static FloatTensor forwardJavaQwen3(Model model, State state, int token, int position) {
// a few convenience variables
final Qwen3Configuration config = (Qwen3Configuration) model.configuration();
69 changes: 69 additions & 0 deletions src/main/java/org/beehive/gpullama3/inference/state/Qwen2State.java
@@ -0,0 +1,69 @@
package org.beehive.gpullama3.inference.state;

import org.beehive.gpullama3.core.model.tensor.ArrayFloatTensor;
import org.beehive.gpullama3.core.model.tensor.FloatTensor;
import org.beehive.gpullama3.model.Configuration;
import org.beehive.gpullama3.model.qwen2.Qwen2Configuration;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
import uk.ac.manchester.tornado.api.types.arrays.IntArray;

import java.util.stream.Stream;

public class Qwen2State extends State {

public Qwen2State(Configuration config, int batchsize) {
super(config, batchsize);
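// work-group size used below to size the TornadoVM reduction scratch arrays (temp, tempFFN, tempLogits)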
this.localSize = 32;
}

@Override
protected StateFields createStateFields(Configuration configuration) {
StateFields fields = new StateFields();

Qwen2Configuration config = (Qwen2Configuration) configuration;

int nEmbdGqa = config.kvDim();
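// per-position width of the kv cache: kvDim = dim * numberOfKeyValueHeads / numberOfHeads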

// with Qwen2-specific sizes
fields.x = ArrayFloatTensor.allocate(config.dim());
fields.xb = ArrayFloatTensor.allocate(config.dim());
fields.xb2 = ArrayFloatTensor.allocate(config.dim());
fields.hb = ArrayFloatTensor.allocate(config.hiddenDim());
fields.hb2 = ArrayFloatTensor.allocate(config.hiddenDim());
fields.q = ArrayFloatTensor.allocate(config.dim());
fields.k = ArrayFloatTensor.allocate(config.kvDim());
fields.v = ArrayFloatTensor.allocate(config.kvDim());
fields.att = ArrayFloatTensor.allocate(config.numberOfHeads(), config.contextLength());
fields.logits = ArrayFloatTensor.allocate(config.vocabularySize());

// Key-value cache with Qwen2 dimensions
fields.keyCache = Stream.generate(() -> ArrayFloatTensor.allocate(config.contextLength(), nEmbdGqa)).limit(config.numberOfLayers()).toArray(FloatTensor[]::new);
fields.valueCache = Stream.generate(() -> ArrayFloatTensor.allocate(config.contextLength(), nEmbdGqa)).limit(config.numberOfLayers()).toArray(FloatTensor[]::new);

// TornadoVM wrappers with Qwen2 dimensions
fields.wrapX = new FloatArray(config.dim());
fields.wrapXb = new FloatArray(config.dim());
fields.wrapXb2 = new FloatArray(config.dim());
fields.wrapHb = new FloatArray(config.hiddenDim());
fields.wrapHb2 = new FloatArray(config.hiddenDim());

fields.wrapLogits = new FloatArray(config.vocabularySize());
fields.wrapQ = new FloatArray(config.dim());
fields.wrapK = new FloatArray(config.kvDim());
fields.wrapV = new FloatArray(config.kvDim());

fields.wrapKeyCache = new FloatArray(config.contextLength() * nEmbdGqa * config.numberOfLayers());
fields.wrapValueCache = new FloatArray(config.contextLength() * nEmbdGqa * config.numberOfLayers());
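// each cache is flattened layer-major: numberOfLayers * contextLength * nEmbdGqa floats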
fields.wrapValueCache.init(0.f);
fields.wrapKeyCache.init(0.f);
fields.wrapAtt = new FloatArray(config.numberOfHeads() * config.contextLength());
fields.positionHolder = new IntArray(1);

// Temporary arrays
fields.temp = new FloatArray(1 + ((config.dim() + localSize - 1) / localSize));
fields.tempFFN = new FloatArray(1 + ((config.dim() + localSize - 1) / localSize));
fields.tempLogits = new FloatArray(1 + ((config.dim() + localSize - 1) / localSize));

return fields;
}
}
57 changes: 57 additions & 0 deletions src/main/java/org/beehive/gpullama3/inference/weights/standard/Qwen2StandardWeights.java
@@ -0,0 +1,57 @@
package org.beehive.gpullama3.inference.weights.standard;

import org.beehive.gpullama3.core.model.GGMLType;
import org.beehive.gpullama3.core.model.tensor.ArrayFloatTensor;
import org.beehive.gpullama3.core.model.tensor.FloatTensor;
import org.beehive.gpullama3.inference.weights.Weights;

public class Qwen2StandardWeights extends StandardWeights {
// Qwen2-specific weights
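// Unlike Llama-style models, Qwen2 applies learned bias terms on the Q/K/V projections.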
public final FloatTensor[] q_bias, k_bias, v_bias;

public Qwen2StandardWeights(
FloatTensor token_embedding_table,
FloatTensor[] rms_att_weight,
FloatTensor[] wq,
FloatTensor[] wk,
FloatTensor[] wv,
FloatTensor[] q_bias,
FloatTensor[] k_bias,
FloatTensor[] v_bias,
FloatTensor[] wo,
FloatTensor[] rms_ffn_weight,
FloatTensor[] w1,
FloatTensor[] w2,
FloatTensor[] w3,
FloatTensor rms_final_weight,
ArrayFloatTensor freq_cis_real,
ArrayFloatTensor freq_cis_imag,
FloatTensor wcls,
GGMLType weightType) {
// call to StandardWeights constructor
super(token_embedding_table,
rms_att_weight,
wq,
wk,
wv,
wo,
rms_ffn_weight,
w1,
w2,
w3,
rms_final_weight,
freq_cis_real,
freq_cis_imag,
wcls,
weightType);
// init Qwen2-specific fields
this.q_bias = q_bias;
this.k_bias = k_bias;
this.v_bias = v_bias;
}

@Override
public GGMLType getWeightType() {
return weightType;
}
}
42 changes: 42 additions & 0 deletions src/main/java/org/beehive/gpullama3/inference/weights/tornado/Qwen2TornadoWeights.java
@@ -0,0 +1,42 @@
package org.beehive.gpullama3.inference.weights.tornado;

import org.beehive.gpullama3.core.model.GGMLType;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;

public class Qwen2TornadoWeights extends TornadoWeights {

// Qwen2-specific tornado weights
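// per-layer GPU-side counterparts of Qwen2StandardWeights.q_bias/k_bias/v_bias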
public FloatArray[] q_biasLayered;
public FloatArray[] k_biasLayered;
public FloatArray[] v_biasLayered;

public Qwen2TornadoWeights(FloatArray tokenEmbeddingTable,
FloatArray[] rms_att_weightLayered,
HalfFloatArray[] wqLayered,
HalfFloatArray[] wkLayered,
HalfFloatArray[] wvLayered,
FloatArray[] wqBiasLayered,
FloatArray[] wkBiasLayered,
FloatArray[] wvBiasLayered,
HalfFloatArray[] woLayered,
FloatArray[] rms_ffn_weightLayered,
HalfFloatArray[] w1Layered,
HalfFloatArray[] w2Layered,
HalfFloatArray[] w3Layered,
FloatArray rms_final_weight_as_floatArray,
FloatArray freq_cis_realFlat,
FloatArray freq_cis_imagFlat,
HalfFloatArray wclsByteArray,
GGMLType weightType) {
// call to TornadoWeights constructor
super(tokenEmbeddingTable,
rms_att_weightLayered,
wqLayered,
wkLayered,
wvLayered,
woLayered,
rms_ffn_weightLayered,
w1Layered,
w2Layered,
w3Layered,
rms_final_weight_as_floatArray,
freq_cis_realFlat,
freq_cis_imagFlat,
wclsByteArray,
weightType);
// init qwen2-specific fields
this.q_biasLayered = wqBiasLayered;
this.k_biasLayered = wkBiasLayered;
this.v_biasLayered = wvBiasLayered;
}
}