Skip to content

Commit cd425a0

Browse files
michalharakalclaude
andcommitted
Remove Karpathy .bin format support, simplify to GGUF-only
- Remove Format enum and format parameter from LlamaWeightLoader - Remove loadFromKarpathyBin() and all Karpathy-specific helper methods - Update loadLlamaRuntimeWeights functions to remove format parameter - Simplify CLI to only support .gguf files with embedded tokenizer - Fix Q8_1 dequantization bug (bytesPerBlock was 40, should be 36) - Update tests to use GGUF format tensor shapes Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent ee99c6e commit cd425a0

10 files changed

Lines changed: 86 additions & 559 deletions

File tree

skainet-apps/skainet-kllama/src/commonMain/kotlin/sk/ainet/apps/kllama/LlamaIngestion.kt

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,10 @@ import sk.ainet.io.gguf.llama.LlamaWeightLoader
77
import sk.ainet.io.gguf.llama.loadLlamaRuntimeWeights
88

99
/**
10-
* Thin facade around the GGUF/Karpathy loader that sets sensible defaults for the KLLama app.
10+
* Thin facade around the GGUF loader that sets sensible defaults for the KLLama app.
1111
* Default policy dequantizes to FP32 to ensure parity before quant-aware kernels are wired.
1212
*/
1313
public data class LlamaLoadConfig(
14-
val format: LlamaWeightLoader.Format = LlamaWeightLoader.Format.GGUF,
1514
val quantPolicy: LlamaWeightLoader.QuantPolicy = LlamaWeightLoader.QuantPolicy.DEQUANTIZE_TO_FP32,
1615
val allowQuantized: Boolean = false
1716
)
@@ -21,7 +20,7 @@ public class LlamaIngestion(
2120
private val config: LlamaLoadConfig = LlamaLoadConfig()
2221
) {
2322
/**
24-
* Load LLaMA runtime weights from the provided source (GGUF by default).
23+
* Load LLaMA runtime weights from the provided GGUF source.
2524
*
2625
* @throws IllegalStateException if metadata/tensors are missing or quantized tensors are present
2726
* when [config.allowQuantized] is false.
@@ -30,7 +29,6 @@ public class LlamaIngestion(
3029
return loadLlamaRuntimeWeights(
3130
ctx = ctx,
3231
sourceProvider = sourceProvider,
33-
format = config.format,
3432
quantPolicy = config.quantPolicy,
3533
allowQuantized = config.allowQuantized
3634
)

skainet-apps/skainet-kllama/src/jvmMain/kotlin/sk/ainet/apps/kllama/cli/Main.kt

Lines changed: 16 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,12 @@ import sk.ainet.apps.kllama.GGUFTokenizer
1313
import sk.ainet.apps.kllama.LlamaIngestion
1414
import sk.ainet.apps.kllama.LlamaLoadConfig
1515
import sk.ainet.apps.kllama.Tokenizer
16-
import sk.ainet.apps.kllama.TokenizerUtils
1716
import sk.ainet.apps.kllama.LlamaRuntime
1817
import sk.ainet.context.DirectCpuExecutionContext
1918
import sk.ainet.io.gguf.llama.LlamaWeightLoader
2019

2120
private fun usage(): Nothing {
22-
println("Usage: kllama <model-path> <prompt> [tokenizer-path] [steps=64] [temperature=0.8]")
23-
println(" For GGUF models, tokenizer-path is optional (uses embedded tokenizer)")
21+
println("Usage: kllama <model.gguf> <prompt> [steps=64] [temperature=0.8]")
2422
exitProcess(1)
2523
}
2624

@@ -30,78 +28,47 @@ fun main(args: Array<String>) {
3028

3129
val modelPath = Path.of(args[0])
3230
val prompt = args[1]
33-
34-
// Parse remaining args: tokenizer-path is optional for GGUF
35-
var tokenizerPath: Path? = null
36-
var steps = 64
37-
var temperature = 0.8f
38-
39-
// Check if args[2] is a file path or a number (steps)
40-
if (args.size > 2) {
41-
val arg2 = args[2]
42-
if (arg2.toIntOrNull() != null) {
43-
// It's steps
44-
steps = arg2.toInt()
45-
temperature = args.getOrNull(3)?.toFloatOrNull() ?: 0.8f
46-
} else {
47-
// It's tokenizer path
48-
tokenizerPath = Path.of(arg2)
49-
steps = args.getOrNull(3)?.toIntOrNull() ?: 64
50-
temperature = args.getOrNull(4)?.toFloatOrNull() ?: 0.8f
51-
}
52-
}
31+
val steps = args.getOrNull(2)?.toIntOrNull() ?: 64
32+
val temperature = args.getOrNull(3)?.toFloatOrNull() ?: 0.8f
5333

5434
if (!modelPath.exists()) error("Model not found: $modelPath")
5535

56-
val format = when (modelPath.extension.lowercase()) {
57-
"gguf" -> LlamaWeightLoader.Format.GGUF
58-
"bin" -> LlamaWeightLoader.Format.KARPATHY_BIN
59-
else -> error("Unknown model extension: ${modelPath.extension}. Use .gguf or .bin")
60-
}
61-
62-
// For .bin format, tokenizer is required
63-
if (format == LlamaWeightLoader.Format.KARPATHY_BIN && tokenizerPath == null) {
64-
error("Tokenizer path is required for .bin format models")
65-
}
66-
if (tokenizerPath != null && !tokenizerPath.exists()) {
67-
error("Tokenizer not found: $tokenizerPath")
36+
if (modelPath.extension.lowercase() != "gguf") {
37+
error("Only GGUF format is supported. Use a .gguf model file.")
6838
}
6939

7040
val ctx = DirectCpuExecutionContext()
7141
val ingestion = LlamaIngestion(
7242
ctx = ctx,
7343
config = LlamaLoadConfig(
74-
format = format,
7544
quantPolicy = LlamaWeightLoader.QuantPolicy.DEQUANTIZE_TO_FP32,
7645
allowQuantized = false
7746
)
7847
)
7948

49+
println("Loading model from $modelPath...")
8050
val runtimeWeights = ingestion.load {
8151
Files.newInputStream(modelPath).asSource().buffered()
8252
}
8353
val runtime = LlamaRuntime(ctx, runtimeWeights)
8454

85-
// Load tokenizer: use embedded GGUF tokenizer if no external path provided
86-
val tokenizer: Tokenizer = if (tokenizerPath != null) {
87-
loadTokenizer(tokenizerPath, runtimeWeights.metadata.vocabSize)
88-
} else {
89-
println("Using embedded GGUF tokenizer...")
90-
GGUFTokenizer.fromSource(Files.newInputStream(modelPath).asSource().buffered())
91-
}
55+
// Load embedded GGUF tokenizer
56+
println("Loading embedded GGUF tokenizer...")
57+
val tokenizer: Tokenizer = GGUFTokenizer.fromSource(Files.newInputStream(modelPath).asSource().buffered())
9258

9359
val promptTokens = tokenizer.encode(prompt)
9460

61+
println("Generating $steps tokens with temperature=$temperature...")
62+
println("---")
63+
9564
val elapsed = measureTime {
9665
runtime.generate(prompt = promptTokens, steps = steps, temperature = temperature) { id ->
9766
print(tokenizer.decode(id))
9867
}
9968
}.inWholeMilliseconds
100-
println("\n\ntok/s: ${steps / elapsed.toDouble() * 1000}")
101-
}
102-
}
10369

104-
private fun loadTokenizer(path: Path, vocabSize: Int): Tokenizer {
105-
val source = Files.newInputStream(path).asSource().buffered()
106-
return TokenizerUtils.buildTokenizer(source, vocabSize)
70+
val tokPerSec = steps / elapsed.toDouble() * 1000
71+
println("\n---")
72+
println("tok/s: $tokPerSec")
73+
}
10774
}

skainet-apps/skainet-kllama/src/jvmTest/kotlin/sk/ainet/apps/kllama/LlamaIngestionTest.kt

Lines changed: 0 additions & 54 deletions
This file was deleted.

skainet-apps/skainet-kllama/src/jvmTest/kotlin/sk/ainet/apps/kllama/LlamaRuntimeTest.kt

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,16 @@ class LlamaRuntimeTest {
2121
val seqLen = 4
2222
val vocab = 3
2323

24+
// GGUF format shapes:
25+
// - attention weights: [dim, dim] or [dim, kv_dim]
26+
// - ffn gate/up: [dim, ff_dim]
27+
// - ffn down: [ff_dim, dim]
28+
// - token embedding: [dim, vocab]
29+
// - output weight: [dim, vocab]
2430
val ones1d = ctx.full<FP32, Float>(Shape(dim), FP32::class, 1f)
2531
val ones2d = ctx.full<FP32, Float>(Shape(dim, dim), FP32::class, 0.25f)
26-
val gate = ctx.full<FP32, Float>(Shape(hidden, dim), FP32::class, 0.1f)
27-
val down = ctx.full<FP32, Float>(Shape(dim, hidden), FP32::class, 0.05f)
32+
val gateUp = ctx.full<FP32, Float>(Shape(dim, hidden), FP32::class, 0.1f) // [dim, ff_dim]
33+
val down = ctx.full<FP32, Float>(Shape(hidden, dim), FP32::class, 0.05f) // [ff_dim, dim]
2834
val ropeReal = ctx.full<FP32, Float>(Shape(seqLen, headSize / 2), FP32::class, 1f)
2935
val ropeImag = ctx.full<FP32, Float>(Shape(seqLen, headSize / 2), FP32::class, 0f)
3036

@@ -35,9 +41,9 @@ class LlamaRuntimeTest {
3541
wv = ones2d,
3642
wo = ones2d,
3743
ffnNorm = ones1d,
38-
ffnGate = gate,
44+
ffnGate = gateUp,
3945
ffnDown = down,
40-
ffnUp = gate
46+
ffnUp = gateUp
4147
)
4248

4349
val weights = LlamaRuntimeWeights(
@@ -52,12 +58,12 @@ class LlamaRuntimeTest {
5258
ropeDimensionCount = headSize,
5359
vocabSize = vocab
5460
),
55-
tokenEmbedding = ctx.full(Shape(vocab, dim), FP32::class, 0.2f),
61+
tokenEmbedding = ctx.full(Shape(dim, vocab), FP32::class, 0.2f), // [dim, vocab]
5662
ropeFreqReal = ropeReal,
5763
ropeFreqImag = ropeImag,
5864
layers = listOf(layer),
5965
outputNorm = ones1d,
60-
outputWeight = ctx.full(Shape(vocab, dim), FP32::class, 0.3f)
66+
outputWeight = ctx.full(Shape(dim, vocab), FP32::class, 0.3f) // [dim, vocab]
6167
)
6268

6369
val runtime = LlamaRuntime(ctx, weights)
@@ -75,10 +81,11 @@ class LlamaRuntimeTest {
7581
val seqLen = 6
7682
val vocab = 4
7783

84+
// GGUF format shapes
7885
val ones1d = ctx.full<FP32, Float>(Shape(dim), FP32::class, 1f)
7986
val ones2d = ctx.full<FP32, Float>(Shape(dim, dim), FP32::class, 0.1f)
80-
val gate = ctx.full<FP32, Float>(Shape(hidden, dim), FP32::class, 0.05f)
81-
val down = ctx.full<FP32, Float>(Shape(dim, hidden), FP32::class, 0.05f)
87+
val gateUp = ctx.full<FP32, Float>(Shape(dim, hidden), FP32::class, 0.05f) // [dim, ff_dim]
88+
val down = ctx.full<FP32, Float>(Shape(hidden, dim), FP32::class, 0.05f) // [ff_dim, dim]
8289
val ropeReal = ctx.full<FP32, Float>(Shape(seqLen, dim / 2), FP32::class, 1f)
8390
val ropeImag = ctx.full<FP32, Float>(Shape(seqLen, dim / 2), FP32::class, 0f)
8491

@@ -89,9 +96,9 @@ class LlamaRuntimeTest {
8996
wv = ones2d,
9097
wo = ones2d,
9198
ffnNorm = ones1d,
92-
ffnGate = gate,
99+
ffnGate = gateUp,
93100
ffnDown = down,
94-
ffnUp = gate
101+
ffnUp = gateUp
95102
)
96103

97104
val weights = LlamaRuntimeWeights(
@@ -106,12 +113,12 @@ class LlamaRuntimeTest {
106113
ropeDimensionCount = dim,
107114
vocabSize = vocab
108115
),
109-
tokenEmbedding = ctx.full(Shape(vocab, dim), FP32::class, 0.2f),
116+
tokenEmbedding = ctx.full(Shape(dim, vocab), FP32::class, 0.2f), // [dim, vocab]
110117
ropeFreqReal = ropeReal,
111118
ropeFreqImag = ropeImag,
112119
layers = listOf(layer),
113120
outputNorm = ones1d,
114-
outputWeight = ctx.full(Shape(vocab, dim), FP32::class, 0.3f)
121+
outputWeight = ctx.full(Shape(dim, vocab), FP32::class, 0.3f) // [dim, vocab]
115122
)
116123

117124
val runtime = LlamaRuntime(ctx, weights)

skainet-apps/skainet-kllama/src/nativeMain/kotlin/sk/ainet/apps/kllama/cli/Main.kt

Lines changed: 8 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,12 @@ import sk.ainet.apps.kllama.GGUFTokenizer
99
import sk.ainet.apps.kllama.LlamaIngestion
1010
import sk.ainet.apps.kllama.LlamaLoadConfig
1111
import sk.ainet.apps.kllama.Tokenizer
12-
import sk.ainet.apps.kllama.TokenizerUtils
1312
import sk.ainet.apps.kllama.LlamaRuntime
1413
import sk.ainet.context.DirectCpuExecutionContext
1514
import sk.ainet.io.gguf.llama.LlamaWeightLoader
1615

1716
private fun usage(): Nothing {
18-
println("Usage: kllama <model-path> <prompt> [tokenizer-path] [steps=64] [temperature=0.8]")
19-
println(" For GGUF models, tokenizer-path is optional (uses embedded tokenizer)")
17+
println("Usage: kllama <model.gguf> <prompt> [steps=64] [temperature=0.8]")
2018
throw IllegalArgumentException("Invalid arguments")
2119
}
2220

@@ -25,52 +23,23 @@ fun main(args: Array<String>) = runBlocking {
2523

2624
val modelPathStr = args[0]
2725
val prompt = args[1]
28-
29-
// Parse remaining args: tokenizer-path is optional for GGUF
30-
var tokenizerPathStr: String? = null
31-
var steps = 64
32-
var temperature = 0.8f
33-
34-
// Check if args[2] is a file path or a number (steps)
35-
if (args.size > 2) {
36-
val arg2 = args[2]
37-
if (arg2.toIntOrNull() != null) {
38-
// It's steps
39-
steps = arg2.toInt()
40-
temperature = args.getOrNull(3)?.toFloatOrNull() ?: 0.8f
41-
} else {
42-
// It's tokenizer path
43-
tokenizerPathStr = arg2
44-
steps = args.getOrNull(3)?.toIntOrNull() ?: 64
45-
temperature = args.getOrNull(4)?.toFloatOrNull() ?: 0.8f
46-
}
47-
}
26+
val steps = args.getOrNull(2)?.toIntOrNull() ?: 64
27+
val temperature = args.getOrNull(3)?.toFloatOrNull() ?: 0.8f
4828

4929
val modelPath = Path(modelPathStr)
5030

5131
if (!SystemFileSystem.exists(modelPath)) {
5232
error("Model not found: $modelPathStr")
5333
}
5434

55-
val modelFormat = when {
56-
modelPathStr.endsWith(".gguf", ignoreCase = true) -> LlamaWeightLoader.Format.GGUF
57-
modelPathStr.endsWith(".bin", ignoreCase = true) -> LlamaWeightLoader.Format.KARPATHY_BIN
58-
else -> error("Unknown model extension. Use .gguf or .bin")
59-
}
60-
61-
// For .bin format, tokenizer is required
62-
if (modelFormat == LlamaWeightLoader.Format.KARPATHY_BIN && tokenizerPathStr == null) {
63-
error("Tokenizer path is required for .bin format models")
64-
}
65-
if (tokenizerPathStr != null && !SystemFileSystem.exists(Path(tokenizerPathStr))) {
66-
error("Tokenizer not found: $tokenizerPathStr")
35+
if (!modelPathStr.endsWith(".gguf", ignoreCase = true)) {
36+
error("Only GGUF format is supported. Use a .gguf model file.")
6737
}
6838

6939
val ctx = DirectCpuExecutionContext()
7040
val ingestion = LlamaIngestion(
7141
ctx = ctx,
7242
config = LlamaLoadConfig(
73-
format = modelFormat,
7443
quantPolicy = LlamaWeightLoader.QuantPolicy.DEQUANTIZE_TO_FP32,
7544
allowQuantized = false
7645
)
@@ -81,14 +50,9 @@ fun main(args: Array<String>) = runBlocking {
8150
SystemFileSystem.source(modelPath).buffered()
8251
}
8352

84-
// Load tokenizer: use embedded GGUF tokenizer if no external path provided
85-
val tokenizer: Tokenizer = if (tokenizerPathStr != null) {
86-
println("Loading tokenizer from $tokenizerPathStr...")
87-
loadTokenizer(Path(tokenizerPathStr), runtimeWeights.metadata.vocabSize)
88-
} else {
89-
println("Using embedded GGUF tokenizer...")
90-
GGUFTokenizer.fromSource(SystemFileSystem.source(modelPath).buffered())
91-
}
53+
// Load embedded GGUF tokenizer
54+
println("Loading embedded GGUF tokenizer...")
55+
val tokenizer: Tokenizer = GGUFTokenizer.fromSource(SystemFileSystem.source(modelPath).buffered())
9256

9357
val runtime = LlamaRuntime(ctx, runtimeWeights)
9458
val promptTokens = tokenizer.encode(prompt)
@@ -106,8 +70,3 @@ fun main(args: Array<String>) = runBlocking {
10670
println("\n---")
10771
println("tok/s: $tokPerSec")
10872
}
109-
110-
private fun loadTokenizer(path: Path, vocabSize: Int): Tokenizer {
111-
val source = SystemFileSystem.source(path).buffered()
112-
return TokenizerUtils.buildTokenizer(source, vocabSize)
113-
}

0 commit comments

Comments
 (0)