Skip to content

Commit 12a2f2d

Browse files
authored
Values from flags should override values in prompt yaml files. (#49)
2 parents 81f5047 + cacabc8 commit 12a2f2d

File tree

2 files changed

+109
-4
lines changed

2 files changed

+109
-4
lines changed

cmd/run/run.go

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -314,17 +314,18 @@ func NewRunCommand(cfg *command.Config) *cobra.Command {
314314
}
315315

316316
mp := ModelParameters{}
317-
err = mp.PopulateFromFlags(cmd.Flags())
318-
if err != nil {
319-
return err
320-
}
321317

322318
if pf != nil {
323319
mp.maxTokens = pf.ModelParameters.MaxTokens
324320
mp.temperature = pf.ModelParameters.Temperature
325321
mp.topP = pf.ModelParameters.TopP
326322
}
327323

324+
err = mp.PopulateFromFlags(cmd.Flags())
325+
if err != nil {
326+
return err
327+
}
328+
328329
for {
329330
prompt := ""
330331
if initialPrompt != "" {

cmd/run/run_test.go

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,4 +226,108 @@ messages:
226226

227227
require.Contains(t, out.String(), reply)
228228
})
229+
230+
t.Run("cli flags override params set in the prompt.yaml file", func(t *testing.T) {
231+
// Begin setup:
232+
const yamlBody = `
233+
name: Example Prompt
234+
description: Example description
235+
model: openai/example-model
236+
modelParameters:
237+
maxTokens: 300
238+
temperature: 0.8
239+
topP: 0.9
240+
messages:
241+
- role: system
242+
content: System message
243+
- role: user
244+
content: User message
245+
`
246+
tmp, err := os.CreateTemp(t.TempDir(), "*.prompt.yaml")
247+
require.NoError(t, err)
248+
_, err = tmp.WriteString(yamlBody)
249+
require.NoError(t, err)
250+
require.NoError(t, tmp.Close())
251+
252+
client := azuremodels.NewMockClient()
253+
modelSummary := &azuremodels.ModelSummary{
254+
Name: "example-model",
255+
Publisher: "openai",
256+
Task: "chat-completion",
257+
}
258+
modelSummary2 := &azuremodels.ModelSummary{
259+
Name: "example-model-4o-mini-plus",
260+
Publisher: "openai",
261+
Task: "chat-completion",
262+
}
263+
264+
client.MockListModels = func(ctx context.Context) ([]*azuremodels.
265+
ModelSummary, error) {
266+
return []*azuremodels.ModelSummary{modelSummary, modelSummary2}, nil
267+
}
268+
269+
var capturedReq azuremodels.ChatCompletionOptions
270+
reply := "hello"
271+
chatCompletion := azuremodels.ChatCompletion{
272+
Choices: []azuremodels.ChatChoice{{
273+
Message: &azuremodels.ChatChoiceMessage{
274+
Content: util.Ptr(reply),
275+
Role: util.Ptr(string(azuremodels.ChatMessageRoleAssistant)),
276+
},
277+
}},
278+
}
279+
280+
client.MockGetChatCompletionStream = func(ctx context.Context, opt azuremodels.ChatCompletionOptions) (*azuremodels.ChatCompletionResponse, error) {
281+
capturedReq = opt
282+
return &azuremodels.ChatCompletionResponse{
283+
Reader: sse.NewMockEventReader([]azuremodels.ChatCompletion{chatCompletion}),
284+
}, nil
285+
}
286+
287+
out := new(bytes.Buffer)
288+
cfg := command.NewConfig(out, out, client, true, 100)
289+
runCmd := NewRunCommand(cfg)
290+
291+
// End setup.
292+
// ---
293+
// We're finally ready to start making assertions.
294+
295+
// Test case 1: with no flags, the model params come from the YAML file
296+
runCmd.SetArgs([]string{
297+
"--file", tmp.Name(),
298+
})
299+
300+
_, err = runCmd.ExecuteC()
301+
require.NoError(t, err)
302+
303+
require.Equal(t, "openai/example-model", capturedReq.Model)
304+
require.Equal(t, 300, *capturedReq.MaxTokens)
305+
require.Equal(t, 0.8, *capturedReq.Temperature)
306+
require.Equal(t, 0.9, *capturedReq.TopP)
307+
308+
require.Equal(t, "System message", *capturedReq.Messages[0].Content)
309+
require.Equal(t, "User message", *capturedReq.Messages[1].Content)
310+
311+
// Hooray!
312+
// Test case 2: values from flags override the params from the YAML file
313+
runCmd = NewRunCommand(cfg)
314+
runCmd.SetArgs([]string{
315+
"openai/example-model-4o-mini-plus",
316+
"--file", tmp.Name(),
317+
"--max-tokens", "150",
318+
"--temperature", "0.1",
319+
"--top-p", "0.3",
320+
})
321+
322+
_, err = runCmd.ExecuteC()
323+
require.NoError(t, err)
324+
325+
require.Equal(t, "openai/example-model-4o-mini-plus", capturedReq.Model)
326+
require.Equal(t, 150, *capturedReq.MaxTokens)
327+
require.Equal(t, 0.1, *capturedReq.Temperature)
328+
require.Equal(t, 0.3, *capturedReq.TopP)
329+
330+
require.Equal(t, "System message", *capturedReq.Messages[0].Content)
331+
require.Equal(t, "User message", *capturedReq.Messages[1].Content)
332+
})
229333
}

0 commit comments

Comments
 (0)