@@ -38,6 +38,18 @@ public interface Model {
3838
3939 State createNewState (int batchsize );
4040
41+ default boolean shouldAddBeginOfText () {
42+ return true ;
43+ }
44+
45+ default boolean shouldAddSystemPrompt () {
46+ return true ;
47+ }
48+
49+ default boolean shouldIncludeReasoning () {
50+ return false ;
51+ }
52+
4153 /**
4254 * Wrapper for invoking the model-specific forward pass via InferenceCore.
4355 *
@@ -68,11 +80,11 @@ default void runInteractive(Sampler sampler, Options options) {
6880 ChatFormat chatFormat = chatFormat ();
6981 TornadoVMMasterPlan tornadoVMPlan = null ;
7082
71- if (! getModelType (). equals ( ModelType . QWEN_3 ) && ! getModelType (). equals ( ModelType . PHI_3 )) {
83+ if (shouldAddBeginOfText ( )) {
7284 conversationTokens .add (chatFormat .getBeginOfText ());
7385 }
7486
75- if (options .systemPrompt () != null ) {
87+ if (shouldAddSystemPrompt () && options .systemPrompt () != null ) {
7688 conversationTokens .addAll (chatFormat .encodeMessage (new ChatFormat .Message (ChatFormat .Role .SYSTEM , options .systemPrompt ())));
7789 }
7890
@@ -95,6 +107,18 @@ default void runInteractive(Sampler sampler, Options options) {
95107
96108 conversationTokens .addAll (chatFormat .encodeMessage (new ChatFormat .Message (ChatFormat .Role .USER , userText )));
97109 conversationTokens .addAll (chatFormat .encodeHeader (new ChatFormat .Message (ChatFormat .Role .ASSISTANT , "" )));
110+
111+ // Include reasoning for Deepseek-R1-Distill-Qwen
112+ if (shouldIncludeReasoning ()) {
113+ List <Integer > thinkStartTokens = tokenizer ().encode ("<think>\n " , tokenizer ().getSpecialTokens ().keySet ());
114+ conversationTokens .addAll (thinkStartTokens );
115+
116+ // If streaming, immediately output the think start
117+ if (options .stream ()) {
118+ System .out .print ("<think>\n " );
119+ }
120+ }
121+
98122 Set <Integer > stopTokens = chatFormat .getStopTokens ();
99123
100124 List <Integer > responseTokens ;
@@ -127,6 +151,10 @@ default void runInteractive(Sampler sampler, Options options) {
127151 }
128152 if (!options .stream ()) {
129153 String responseText = tokenizer ().decode (responseTokens );
154+ // Add the forced <think>\n prefix for non-streaming output
155+ if (shouldIncludeReasoning ()) {
156+ responseText = "<think>\n " + responseText ;
157+ }
130158 System .out .println (responseText );
131159 }
132160 if (stopToken == null ) {
@@ -164,11 +192,11 @@ default void runInstructOnce(Sampler sampler, Options options) {
164192
165193 List <Integer > promptTokens = new ArrayList <>();
166194
167- if (! getModelType (). equals ( ModelType . QWEN_3 ) && ! getModelType (). equals ( ModelType . QWEN_2 ) && ! getModelType (). equals ( ModelType . PHI_3 )) {
195+ if (shouldAddBeginOfText ( )) {
168196 promptTokens .add (chatFormat .getBeginOfText ());
169197 }
170198
171- if (options .systemPrompt () != null ) {
199+ if (shouldAddSystemPrompt () && options .systemPrompt () != null ) {
172200 promptTokens .addAll (chatFormat .encodeMessage (new ChatFormat .Message (ChatFormat .Role .SYSTEM , options .systemPrompt ())));
173201 }
174202
@@ -180,6 +208,17 @@ default void runInstructOnce(Sampler sampler, Options options) {
180208 promptTokens .addAll (chatFormat .encodeMessage (new ChatFormat .Message (ChatFormat .Role .USER , options .prompt ())));
181209 promptTokens .addAll (chatFormat .encodeHeader (new ChatFormat .Message (ChatFormat .Role .ASSISTANT , "" )));
182210
211+ // Include reasoning for Deepseek-R1-Distill-Qwen
212+ if (shouldIncludeReasoning ()) {
213+ List <Integer > thinkStartTokens = tokenizer ().encode ("<think>\n " , tokenizer ().getSpecialTokens ().keySet ());
214+ promptTokens .addAll (thinkStartTokens );
215+
216+ // If streaming, immediately output the think start
217+ if (options .stream ()) {
218+ System .out .print ("<think>\n " );
219+ }
220+ }
221+
183222 List <Integer > responseTokens ;
184223
185224 IntConsumer tokenConsumer = token -> {
@@ -206,6 +245,10 @@ default void runInstructOnce(Sampler sampler, Options options) {
206245 }
207246 if (!options .stream ()) {
208247 String responseText = tokenizer ().decode (responseTokens );
248+ // Add the forced <think>\n prefix for non-streaming output
249+ if (shouldIncludeReasoning ()) {
250+ responseText = "<think>\n " + responseText ;
251+ }
209252 System .out .println (responseText );
210253 }
211254
0 commit comments