@@ -226,4 +226,108 @@ messages:
226
226
227
227
require .Contains (t , out .String (), reply )
228
228
})
229
+
230
+ t .Run ("cli flags override params set in the prompt.yaml file" , func (t * testing.T ) {
231
+ // Begin setup:
232
+ const yamlBody = `
233
+ name: Example Prompt
234
+ description: Example description
235
+ model: openai/example-model
236
+ modelParameters:
237
+ maxTokens: 300
238
+ temperature: 0.8
239
+ topP: 0.9
240
+ messages:
241
+ - role: system
242
+ content: System message
243
+ - role: user
244
+ content: User message
245
+ `
246
+ tmp , err := os .CreateTemp (t .TempDir (), "*.prompt.yaml" )
247
+ require .NoError (t , err )
248
+ _ , err = tmp .WriteString (yamlBody )
249
+ require .NoError (t , err )
250
+ require .NoError (t , tmp .Close ())
251
+
252
+ client := azuremodels .NewMockClient ()
253
+ modelSummary := & azuremodels.ModelSummary {
254
+ Name : "example-model" ,
255
+ Publisher : "openai" ,
256
+ Task : "chat-completion" ,
257
+ }
258
+ modelSummary2 := & azuremodels.ModelSummary {
259
+ Name : "example-model-4o-mini-plus" ,
260
+ Publisher : "openai" ,
261
+ Task : "chat-completion" ,
262
+ }
263
+
264
+ client .MockListModels = func (ctx context.Context ) ([]* azuremodels.
265
+ ModelSummary , error ) {
266
+ return []* azuremodels.ModelSummary {modelSummary , modelSummary2 }, nil
267
+ }
268
+
269
+ var capturedReq azuremodels.ChatCompletionOptions
270
+ reply := "hello"
271
+ chatCompletion := azuremodels.ChatCompletion {
272
+ Choices : []azuremodels.ChatChoice {{
273
+ Message : & azuremodels.ChatChoiceMessage {
274
+ Content : util .Ptr (reply ),
275
+ Role : util .Ptr (string (azuremodels .ChatMessageRoleAssistant )),
276
+ },
277
+ }},
278
+ }
279
+
280
+ client .MockGetChatCompletionStream = func (ctx context.Context , opt azuremodels.ChatCompletionOptions ) (* azuremodels.ChatCompletionResponse , error ) {
281
+ capturedReq = opt
282
+ return & azuremodels.ChatCompletionResponse {
283
+ Reader : sse .NewMockEventReader ([]azuremodels.ChatCompletion {chatCompletion }),
284
+ }, nil
285
+ }
286
+
287
+ out := new (bytes.Buffer )
288
+ cfg := command .NewConfig (out , out , client , true , 100 )
289
+ runCmd := NewRunCommand (cfg )
290
+
291
+ // End setup.
292
+ // ---
293
+ // We're finally ready to start making assertions.
294
+
295
+ // Test case 1: with no flags, the model params come from the YAML file
296
+ runCmd .SetArgs ([]string {
297
+ "--file" , tmp .Name (),
298
+ })
299
+
300
+ _ , err = runCmd .ExecuteC ()
301
+ require .NoError (t , err )
302
+
303
+ require .Equal (t , "openai/example-model" , capturedReq .Model )
304
+ require .Equal (t , 300 , * capturedReq .MaxTokens )
305
+ require .Equal (t , 0.8 , * capturedReq .Temperature )
306
+ require .Equal (t , 0.9 , * capturedReq .TopP )
307
+
308
+ require .Equal (t , "System message" , * capturedReq .Messages [0 ].Content )
309
+ require .Equal (t , "User message" , * capturedReq .Messages [1 ].Content )
310
+
311
+ // Hooray!
312
+ // Test case 2: values from flags override the params from the YAML file
313
+ runCmd = NewRunCommand (cfg )
314
+ runCmd .SetArgs ([]string {
315
+ "openai/example-model-4o-mini-plus" ,
316
+ "--file" , tmp .Name (),
317
+ "--max-tokens" , "150" ,
318
+ "--temperature" , "0.1" ,
319
+ "--top-p" , "0.3" ,
320
+ })
321
+
322
+ _ , err = runCmd .ExecuteC ()
323
+ require .NoError (t , err )
324
+
325
+ require .Equal (t , "openai/example-model-4o-mini-plus" , capturedReq .Model )
326
+ require .Equal (t , 150 , * capturedReq .MaxTokens )
327
+ require .Equal (t , 0.1 , * capturedReq .Temperature )
328
+ require .Equal (t , 0.3 , * capturedReq .TopP )
329
+
330
+ require .Equal (t , "System message" , * capturedReq .Messages [0 ].Content )
331
+ require .Equal (t , "User message" , * capturedReq .Messages [1 ].Content )
332
+ })
229
333
}
0 commit comments