Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
145 changes: 84 additions & 61 deletions llm-runtime/kllama/src/jvmMain/kotlin/sk/ainet/apps/kllama/cli/Main.kt
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,16 @@ package sk.ainet.apps.kllama.cli

import sk.ainet.apps.kllama.GGUFTokenizer
import sk.ainet.models.llama.LlamaConfigParser
import sk.ainet.apps.kllama.LlamaIngestion
import sk.ainet.apps.kllama.LlamaLoadConfig
import sk.ainet.apps.llm.OptimizedLLMMode
import sk.ainet.apps.llm.OptimizedLLMRuntime
import sk.ainet.apps.llm.Tokenizer
import sk.ainet.apps.llm.tokenizer.TokenizerFactory
import sk.ainet.models.llama.DecoderGgufMemSegConverter
import sk.ainet.models.llama.DecoderSafeTensorsLoader
import sk.ainet.models.llama.LlamaNetworkLoader
import sk.ainet.models.llama.LlamaRuntime
import sk.ainet.apps.kllama.CpuAttentionBackend
import sk.ainet.apps.kllama.Llama2DotCWeightLoader
import sk.ainet.models.llama.MemSegWeightConverter
import sk.ainet.models.qwen.QwenNetworkLoader
import sk.ainet.apps.kllama.TokenizerUtils
import sk.ainet.apps.llm.backend.BackendRegistry
Expand Down Expand Up @@ -42,7 +41,6 @@ import sk.ainet.apps.llm.InferenceRuntime
import sk.ainet.apps.llm.generate
import sk.ainet.io.gguf.StreamingGGUFReader
import sk.ainet.models.llama.DecoderGgufWeightLoader
import sk.ainet.models.llama.LlamaWeightMapper

private enum class ModelFormat { GGUF, SAFETENSORS, BIN }

Expand Down Expand Up @@ -404,62 +402,86 @@ fun main(args: Array<String>) {
bos = convertedWeights.metadata.bosTokenId,
)
eosTokenId = convertedWeights.metadata.eosTokenId
} else if (format == ModelFormat.GGUF) {
// --- Llama / Mistral GGUF: DSL path. Mirrors the Qwen branch
// above. LlamaNetworkLoader builds llamaNetwork() (RoPE
// INTERLEAVED, no QK-norm — the LLaMA family default).
val loader = DecoderGgufWeightLoader(
randomAccessProvider = { JvmRandomAccessSource.open(modelPath.toString()) },
quantPolicy = QuantPolicy.NATIVE_OPTIMIZED,
acceptedArchitectures = LLAMA_COMPATIBLE_ARCHITECTURES,
)
println("Loading GGUF model from $modelPath (Llama, DSL streaming mode)...")
val rawWeights = loader.loadToMapStreaming<FP32, Float>(ctx)

val convertedWeights = if (rawWeights.quantTypes.isNotEmpty()) {
println("Converting ${rawWeights.quantTypes.size} quantized tensors to MemorySegment-backed SIMD format...")
DecoderGgufMemSegConverter.convert(rawWeights, ctx, quantArena)
} else {
rawWeights
}

if (cliArgs.contextLength != null) {
println("Context length capped to ${cliArgs.contextLength} (model default: ${convertedWeights.metadata.contextLength})")
}
val llamaModel = LlamaNetworkLoader.fromWeights(convertedWeights)
runtime = OptimizedLLMRuntime(
model = llamaModel,
ctx = ctx,
mode = OptimizedLLMMode.DIRECT,
dtype = FP32::class,
bos = convertedWeights.metadata.bosTokenId,
)
eosTokenId = convertedWeights.metadata.eosTokenId
binVocabSize = convertedWeights.metadata.vocabSize
} else if (format == ModelFormat.SAFETENSORS) {
// --- Llama SafeTensors: DSL path via DecoderSafeTensorsLoader
// (HF tensor names → GGUF-canonical names, BF16/F16 → FP32).
val modelDir = resolveModelDir(modelPath)
val safetensorsFile = if (modelPath.isDirectory()) {
modelDir.resolve("model.safetensors")
} else {
modelPath
}
val configFile = modelDir.resolve("config.json")
if (!configFile.exists()) error("config.json not found in $modelDir")

println("Loading SafeTensors model from $safetensorsFile...")
val configJson = configFile.readText()
val safeMetadata = LlamaConfigParser.parse(configJson)
val tiedEmbeddings = LlamaConfigParser.isTiedEmbeddings(configJson)
println(" Architecture: ${safeMetadata.architecture}, layers=${safeMetadata.blockCount}, " +
"dim=${safeMetadata.embeddingLength}, heads=${safeMetadata.headCount}, " +
"kvHeads=${safeMetadata.kvHeadCount}, vocab=${safeMetadata.vocabSize}")
if (tiedEmbeddings) println(" Tied word embeddings: output.weight = embed_tokens.weight")

val safeLoader = DecoderSafeTensorsLoader<FP32>(ctx, FP32::class, safeMetadata, tiedEmbeddings)
val safeWeights = safeLoader.loadToMap {
JvmRandomAccessSource.open(safetensorsFile.toString())
}

if (cliArgs.contextLength != null) {
println("Context length capped to ${cliArgs.contextLength} (model default: ${safeWeights.metadata.contextLength})")
}
val llamaModel = LlamaNetworkLoader.fromWeights(safeWeights)
runtime = OptimizedLLMRuntime(
model = llamaModel,
ctx = ctx,
mode = OptimizedLLMMode.DIRECT,
dtype = FP32::class,
bos = safeWeights.metadata.bosTokenId,
)
eosTokenId = safeWeights.metadata.eosTokenId
binVocabSize = safeWeights.metadata.vocabSize
} else {
// --- Llama / SafeTensors / BIN: legacy LlamaRuntime path ---
val runtimeWeights = when (format) {
ModelFormat.GGUF -> {
val ingestion = LlamaIngestion<FP32>(
ctx = ctx,
dtype = FP32::class,
config = LlamaLoadConfig(
quantPolicy = QuantPolicy.NATIVE_OPTIMIZED,
allowQuantized = true,
acceptedArchitectures = LLAMA_COMPATIBLE_ARCHITECTURES
)
)
println("Loading GGUF model from $modelPath (streaming mode)...")
val rawWeights = ingestion.loadStreaming {
JvmRandomAccessSource.open(modelPath.toString())
}
if (rawWeights.quantTypes.isNotEmpty()) {
println("Converting ${rawWeights.quantTypes.size} quantized tensors to MemorySegment-backed SIMD format...")
MemSegWeightConverter.convert(rawWeights, ctx, quantArena)
} else {
rawWeights
}
}
ModelFormat.SAFETENSORS -> {
val modelDir = resolveModelDir(modelPath)
val safetensorsFile = if (modelPath.isDirectory()) {
modelDir.resolve("model.safetensors")
} else {
modelPath
}
val configFile = modelDir.resolve("config.json")
if (!configFile.exists()) error("config.json not found in $modelDir")

println("Loading SafeTensors model from $safetensorsFile...")
val configJson = configFile.readText()
val metadata = LlamaConfigParser.parse(configJson)
val tiedEmbeddings = LlamaConfigParser.isTiedEmbeddings(configJson)
println(" Architecture: ${metadata.architecture}, layers=${metadata.blockCount}, " +
"dim=${metadata.embeddingLength}, heads=${metadata.headCount}, " +
"kvHeads=${metadata.kvHeadCount}, vocab=${metadata.vocabSize}")
if (tiedEmbeddings) println(" Tied word embeddings: output.weight = embed_tokens.weight")

val ingestion = LlamaIngestion<FP32>(ctx = ctx, dtype = FP32::class)
ingestion.loadSafeTensors(
randomAccessProvider = { JvmRandomAccessSource.open(safetensorsFile.toString()) },
metadata = metadata,
tiedEmbeddings = tiedEmbeddings
)
}
ModelFormat.BIN -> {
println("Loading Karpathy .bin model from $modelPath...")
modelPath.inputStream().use { input ->
Llama2DotCWeightLoader.load(ctx, input.asSource().buffered())
}
}
// --- BIN (Karpathy llama2.c format): legacy LlamaRuntime path.
// The .bin loader returns LlamaRuntimeWeights directly; the DSL
// path requires DecoderGgufWeights, so this format stays on
// legacy until either Llama2DotCWeightLoader is migrated or
// .bin support is dropped.
println("Loading Karpathy .bin model from $modelPath...")
val runtimeWeights = modelPath.inputStream().use { input ->
Llama2DotCWeightLoader.load(ctx, input.asSource().buffered())
}

if (cliArgs.contextLength != null) {
Expand All @@ -468,11 +490,12 @@ fun main(args: Array<String>) {
val backend = CpuAttentionBackend<FP32>(
ctx, runtimeWeights, FP32::class,
ropeFreqBase = runtimeWeights.metadata.ropeFreqBase,
maxContextLength = cliArgs.contextLength
maxContextLength = cliArgs.contextLength,
)
@Suppress("DEPRECATION")
runtime = LlamaRuntime<FP32>(
ctx, runtimeWeights, backend, FP32::class,
eps = runtimeWeights.metadata.rmsNormEps
eps = runtimeWeights.metadata.rmsNormEps,
)
eosTokenId = runtimeWeights.metadata.eosTokenId
binVocabSize = runtimeWeights.metadata.vocabSize
Expand Down