Skip to content

Commit b6d8477

Browse files
Distinct Deepseek-R1-Distill-Qwen from Qwen2
1 parent 0b60c63 commit b6d8477

File tree

3 files changed

+15
-2
lines changed

3 files changed

+15
-2
lines changed

src/main/java/com/example/model/ModelType.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,13 @@ public Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, bo
5050
}
5151
},
5252

53+
DEEPSEEK_R1_DISTILL_QWEN {
54+
@Override
55+
public Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights) {
56+
return new Qwen2ModelLoader(fileChannel, gguf, contextLength, loadWeights).loadModel();
57+
}
58+
},
59+
5360
UNKNOWN {
5461
@Override
5562
public Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights) {
@@ -59,4 +66,8 @@ public Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, bo
5966

6067
// Abstract method that each enum constant must implement
6168
public abstract Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights);
69+
70+
public boolean isDeepSeekR1() {
71+
return this == DEEPSEEK_R1_DISTILL_QWEN;
72+
}
6273
}

src/main/java/com/example/model/loader/ModelLoader.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,10 +61,12 @@ private static ModelType detectModelType(Map<String, Object> metadata) {
6161
return ModelType.MISTRAL;
6262
} else if (lowerName.contains("llama")) {
6363
return ModelType.LLAMA_3;
64-
} else if (lowerName.contains("qwen2") || lowerName.contains("deepseek r1 distill")) {
64+
} else if (lowerName.contains("qwen2")) {
6565
return ModelType.QWEN_2;
6666
} else if (lowerName.contains("qwen3")) {
6767
return ModelType.QWEN_3;
68+
} else if (lowerName.contains("deepseek r1 distill")) {
69+
return ModelType.DEEPSEEK_R1_DISTILL_QWEN;
6870
}
6971
}
7072

src/main/java/com/example/tornadovm/TornadoVMMasterPlan.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ public static TornadoVMMasterPlan initializeTornadoVMPlan(State state, Model mod
9898
TornadoVMLayerPlanner createPlanner(State state, Model model) {
9999
return switch (model.getModelType()) {
100100
case LLAMA_3, MISTRAL -> new TornadoVMLayerPlanner(state, model);
101-
case QWEN_2 -> throw new UnsupportedOperationException("TornadoVM QWEN 2 not supported");
101+
case QWEN_2, DEEPSEEK_R1_DISTILL_QWEN -> throw new UnsupportedOperationException("TornadoVM QWEN 2 not supported");
102102
case QWEN_3 -> new Qwen3TornadoVMLayerPlanner((Qwen3State) state, model);
103103
case UNKNOWN -> throw new UnsupportedOperationException("Unknown model type");
104104
};

0 commit comments

Comments
 (0)