Skip to content

Commit 4219088

Browse files
michalharakalclaude
andcommitted
feat(skainet-cli): swap LLaMA/Qwen branch to DSL path
Phase 5b consumer migration. Mirrors PR #122 (kllama CLI) and #123 (KLlamaJava facade). After this merge, no top-level CLI in this repo constructs `LlamaRuntime` for the GGUF path. `skainet-cli` previously routed Gemma + Apertus through DSL but kept LLaMA / Qwen / Mistral on the legacy `LlamaRuntime` + `CpuAttentionBackend` + `LlamaWeightMapper` + `MemSegWeightConverter` chain. This PR collapses the else branch onto the DSL path: - `DecoderGgufWeightLoader(NATIVE_OPTIMIZED, family.architectures + [arch])` → `DecoderGgufMemSegConverter.convert` → per-family network loader → `OptimizedLLMRuntime` DIRECT mode. - Family dispatch on the DSL side: `ModelFamily.QWEN` → `QwenNetworkLoader.fromWeights` (NEOX RoPE + QK-norm), else → `LlamaNetworkLoader.fromWeights`. Previously this CLI handled Qwen via the `LlamaRuntime`-with-detected-flags hybrid that the kllama CLI also used pre-#121 — same architectural collapse here. Imports cleaned: removed `CpuAttentionBackend`, `LlamaRuntime`, `LlamaWeightMapper`, `MemSegWeightConverter`. Added `:llm-inference:qwen` to the build.gradle dependencies (was missing — only the legacy hybrid-Qwen path didn't need it). Numerical equivalence with the legacy path on identical weights is pinned by `QwenDslLegacyParityTest` (#120). Tests pass: `:llm-apps:skainet-cli:build`, `:llm-runtime:kllama:jvmTest`, `:llm-inference:qwen:jvmTest`, `:llm-inference:llama:jvmTest`. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 124064d commit 4219088

2 files changed

Lines changed: 39 additions & 30 deletions

File tree

llm-apps/skainet-cli/build.gradle.kts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ dependencies {
2020

2121
// Inference modules (for network loaders)
2222
implementation(project(":llm-inference:llama"))
23+
implementation(project(":llm-inference:qwen"))
2324
implementation(project(":llm-inference:gemma"))
2425
implementation(project(":llm-inference:apertus"))
2526

llm-apps/skainet-cli/src/main/kotlin/sk/ainet/apps/skainet/cli/Main.kt

Lines changed: 38 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
package sk.ainet.apps.skainet.cli
22

3-
import sk.ainet.apps.kllama.CpuAttentionBackend
43
import sk.ainet.apps.kllama.cli.AgentCli
54
import sk.ainet.apps.kllama.cli.ToolCallingDemo
65
import sk.ainet.apps.llm.InferenceRuntime
@@ -24,10 +23,10 @@ import sk.ainet.io.JvmRandomAccessSource
2423
import sk.ainet.io.model.QuantPolicy
2524
import sk.ainet.lang.tensor.data.MemorySegmentTensorDataFactory
2625
import sk.ainet.lang.types.FP32
27-
import sk.ainet.models.llama.LlamaRuntime
26+
import sk.ainet.models.llama.DecoderGgufMemSegConverter
2827
import sk.ainet.models.llama.DecoderGgufWeightLoader
29-
import sk.ainet.models.llama.LlamaWeightMapper
30-
import sk.ainet.models.llama.MemSegWeightConverter
28+
import sk.ainet.models.llama.LlamaNetworkLoader
29+
import sk.ainet.models.qwen.QwenNetworkLoader
3130
import java.lang.foreign.Arena
3231
import java.nio.file.Path
3332
import kotlin.io.path.exists
@@ -164,15 +163,18 @@ fun main(args: Array<String>) {
164163
memSegFactory.close()
165164
})
166165

167-
// Load model based on detected family. Gemma and Apertus route
168-
// through the DSL pipeline (their respective network() builder +
169-
// OptimizedLLMRuntime); everything else (LLaMA, Qwen, ...) takes
170-
// the LlamaRuntime path which supports NATIVE_OPTIMIZED quant
171-
// tensors for low-RAM loads. Apertus had previously fallen
172-
// through to the LlamaRuntime branch — that runtime doesn't
173-
// implement Apertus's xIELU activation, QK-Norm, or ungated FFN,
174-
// so logits silently diverged from the checkpoint's intent. See
175-
// APERTUS_ROLLOUT.md (PR 1) for the rollout context.
166+
// Load model based on detected family. All families route through
167+
// the DSL pipeline (per-family network() builder +
168+
// OptimizedLLMRuntime). The legacy LlamaRuntime path was retired
169+
// for the kllama CLI in #121 / #122; this CLI follows in this PR.
170+
// Numerical equivalence with the legacy path on identical weights
171+
// is pinned by `QwenDslLegacyParityTest` (#120).
172+
//
173+
// Apertus had previously fallen through to the LlamaRuntime
174+
// branch — that runtime doesn't implement Apertus's xIELU
175+
// activation, QK-Norm, or ungated FFN, so logits silently
176+
// diverged from the checkpoint's intent. The DSL path is correct
177+
// for Apertus too. See APERTUS_ROLLOUT.md (PR 1).
176178
val runtime: InferenceRuntime<FP32> = if (modelInfo.family == ModelFamily.GEMMA) {
177179
println("Loading Gemma GGUF model from $modelPath via gemmaNetwork() + OptimizedLLMRuntime (NATIVE_OPTIMIZED)...")
178180
if (cliArgs.contextLength != null) {
@@ -197,38 +199,44 @@ fun main(args: Array<String>) {
197199
).load<FP32, Float>(ctx)
198200
OptimizedLLMRuntime(model, ctx, OptimizedLLMMode.DIRECT, FP32::class)
199201
} else {
202+
// LLaMA / Qwen / Mistral DSL path. DecoderGgufWeightLoader
203+
// streams the GGUF, DecoderGgufMemSegConverter wraps Q4_0/Q8_0
204+
// tensors as packed MemorySegment data, then the per-family
205+
// network loader builds the right module:
206+
// - Qwen → qwenNetwork() (QK-norm + NEOX RoPE)
207+
// - else → llamaNetwork() (LLaMA / Mistral default)
200208
val acceptedArchitectures = modelInfo.family.architectures + setOf(modelInfo.architecture)
201209
val loader = DecoderGgufWeightLoader(
202210
randomAccessProvider = { JvmRandomAccessSource.open(modelPath.toString()) },
203211
quantPolicy = QuantPolicy.NATIVE_OPTIMIZED,
204-
acceptedArchitectures = acceptedArchitectures
212+
acceptedArchitectures = acceptedArchitectures,
205213
)
206214

207-
println("Loading GGUF model from $modelPath (${modelInfo.family.displayName}, streaming)...")
208-
val loaded = loader.loadToMapStreaming<FP32, Float>(ctx, FP32::class)
209-
val rawWeights = LlamaWeightMapper.map(loaded)
215+
println("Loading GGUF model from $modelPath (${modelInfo.family.displayName}, DSL streaming)...")
216+
val rawWeights = loader.loadToMapStreaming<FP32, Float>(ctx)
210217

211-
val runtimeWeights = if (rawWeights.quantTypes.isNotEmpty()) {
218+
val convertedWeights = if (rawWeights.quantTypes.isNotEmpty()) {
212219
println("Converting ${rawWeights.quantTypes.size} quantized tensors to SIMD format...")
213-
MemSegWeightConverter.convert(rawWeights, ctx, quantArena)
220+
DecoderGgufMemSegConverter.convert(rawWeights, ctx, quantArena)
214221
} else {
215222
rawWeights
216223
}
217224

218225
if (cliArgs.contextLength != null) {
219-
println("Context length capped to ${cliArgs.contextLength} (model default: ${runtimeWeights.metadata.contextLength})")
226+
println("Context length capped to ${cliArgs.contextLength} (model default: ${convertedWeights.metadata.contextLength})")
220227
}
221228

222-
val backend = CpuAttentionBackend<FP32>(
223-
ctx, runtimeWeights, FP32::class,
224-
ropeFreqBase = runtimeWeights.metadata.ropeFreqBase,
225-
maxContextLength = cliArgs.contextLength
226-
)
227-
228-
@Suppress("DEPRECATION")
229-
LlamaRuntime<FP32>(
230-
ctx, runtimeWeights, backend, FP32::class,
231-
eps = runtimeWeights.metadata.rmsNormEps
229+
val model = if (modelInfo.family == ModelFamily.QWEN) {
230+
QwenNetworkLoader.fromWeights(convertedWeights)
231+
} else {
232+
LlamaNetworkLoader.fromWeights(convertedWeights)
233+
}
234+
OptimizedLLMRuntime(
235+
model = model,
236+
ctx = ctx,
237+
mode = OptimizedLLMMode.DIRECT,
238+
dtype = FP32::class,
239+
bos = convertedWeights.metadata.bosTokenId,
232240
)
233241
}
234242

0 commit comments

Comments
 (0)