Skip to content

Commit 61993fc

Browse files
committed
Replace loadTornadoTensorAsFP32 with loadTornadoTensor across model loaders for consistent tensor loading.
1 parent d0966eb commit 61993fc

File tree

3 files changed

+3
-3
lines changed

3 files changed

+3
-3
lines changed

src/main/java/org/beehive/gpullama3/model/loader/Phi3ModelLoader.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ protected Weights createTornadoVMWeights(Map<String, GGMLTensorEntry> tensorEntr
140140

141141
// Load all tensors uniformly as TornadoTensor hierarchy
142142
return new Phi3TornadoWeights(
143-
loadTornadoTensorAsFP32(tokenEmbeddings),
143+
loadTornadoTensor(tokenEmbeddings),
144144
loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_norm.weight")), // fp32
145145
loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_qkv.weight")),
146146
loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_output.weight")),

src/main/java/org/beehive/gpullama3/model/loader/Qwen2ModelLoader.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ protected Weights createTornadoVMWeights(Map<String, GGMLTensorEntry> tensorEntr
137137

138138
// Load all tensors uniformly as TornadoTensor hierarchy
139139
return new Qwen2TornadoWeights(
140-
loadTornadoTensorAsFP32(tokenEmbeddings),
140+
loadTornadoTensor(tokenEmbeddings),
141141
loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_norm.weight")), // fp32
142142
loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_q.weight")),
143143
loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_k.weight")),

src/main/java/org/beehive/gpullama3/model/loader/Qwen3ModelLoader.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ protected Weights createTornadoVMWeights(Map<String, GGMLTensorEntry> tensorEntr
137137
final int nl = config.numberOfLayers();
138138

139139
return new Qwen3TornadoWeights(
140-
loadTornadoTensorAsFP32(tokenEmbeddings),
140+
loadTornadoTensor(tokenEmbeddings),
141141
loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_norm.weight")), // fp32
142142
loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_q.weight")),
143143
loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_k.weight")),

0 commit comments

Comments
 (0)