Skip to content

Commit 2f31074

Browse files
authored
Allow templating {{input}} in prompt files (#46)
2 parents 731b885 + 903c656 commit 2f31074

File tree

2 files changed

+95
-12
lines changed

2 files changed

+95
-12
lines changed

cmd/run/run.go

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,7 @@ func NewRunCommand(cfg *command.Config) *cobra.Command {
260260

261261
initialPrompt := ""
262262
singleShot := false
263+
pipedContent := ""
263264

264265
if len(args) > 1 {
265266
initialPrompt = strings.Join(args[1:], " ")
@@ -269,8 +270,11 @@ func NewRunCommand(cfg *command.Config) *cobra.Command {
269270
if isPipe(os.Stdin) {
270271
promptFromPipe, _ := io.ReadAll(os.Stdin)
271272
if len(promptFromPipe) > 0 {
272-
initialPrompt = initialPrompt + "\n" + string(promptFromPipe)
273-
singleShot = true
273+
pipedContent = strings.TrimSpace(string(promptFromPipe))
274+
if pf == nil {
275+
initialPrompt = initialPrompt + "\n" + pipedContent
276+
singleShot = true
277+
}
274278
}
275279
}
276280

@@ -283,22 +287,28 @@ func NewRunCommand(cfg *command.Config) *cobra.Command {
283287
systemPrompt: systemPrompt,
284288
}
285289

286-
// preload conversation & parameters from YAML
290+
// If a prompt file is passed, load the messages from the file, templating {{input}} from stdin
287291
if pf != nil {
288292
for _, m := range pf.Messages {
293+
content := m.Content
294+
if pipedContent != "" && strings.ToLower(m.Role) == "user" {
295+
content = strings.ReplaceAll(content, "{{input}}", pipedContent)
296+
}
289297
switch strings.ToLower(m.Role) {
290298
case "system":
291299
if conversation.systemPrompt == "" {
292-
conversation.systemPrompt = m.Content
300+
conversation.systemPrompt = content
293301
} else {
294-
conversation.AddMessage(azuremodels.ChatMessageRoleSystem, m.Content)
302+
conversation.AddMessage(azuremodels.ChatMessageRoleSystem, content)
295303
}
296304
case "user":
297-
conversation.AddMessage(azuremodels.ChatMessageRoleUser, m.Content)
305+
conversation.AddMessage(azuremodels.ChatMessageRoleUser, content)
298306
case "assistant":
299-
conversation.AddMessage(azuremodels.ChatMessageRoleAssistant, m.Content)
307+
conversation.AddMessage(azuremodels.ChatMessageRoleAssistant, content)
300308
}
301309
}
310+
311+
initialPrompt = ""
302312
}
303313

304314
mp := ModelParameters{}
@@ -320,7 +330,7 @@ func NewRunCommand(cfg *command.Config) *cobra.Command {
320330
initialPrompt = ""
321331
}
322332

323-
if prompt == "" {
333+
if prompt == "" && pf == nil {
324334
fmt.Printf(">>> ")
325335
reader := bufio.NewReader(os.Stdin)
326336
prompt, err = reader.ReadString('\n')
@@ -331,7 +341,7 @@ func NewRunCommand(cfg *command.Config) *cobra.Command {
331341

332342
prompt = strings.TrimSpace(prompt)
333343

334-
if prompt == "" {
344+
if prompt == "" && pf == nil {
335345
continue
336346
}
337347

@@ -419,7 +429,7 @@ func NewRunCommand(cfg *command.Config) *cobra.Command {
419429

420430
conversation.AddMessage(azuremodels.ChatMessageRoleAssistant, messageBuilder.String())
421431

422-
if singleShot {
432+
if singleShot || pf != nil {
423433
break
424434
}
425435
}

cmd/run/run_test.go

Lines changed: 75 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,6 @@ messages:
134134
runCmd.SetArgs([]string{
135135
"--file", tmp.Name(),
136136
azuremodels.FormatIdentifier("openai", "test-model"),
137-
"foo?",
138137
})
139138

140139
_, err = runCmd.ExecuteC()
@@ -143,11 +142,85 @@ messages:
143142
require.Equal(t, 3, len(capturedReq.Messages))
144143
require.Equal(t, "You are a text summarizer.", *capturedReq.Messages[0].Content)
145144
require.Equal(t, "Hello there!", *capturedReq.Messages[1].Content)
146-
require.Equal(t, "foo?", *capturedReq.Messages[2].Content)
147145

148146
require.NotNil(t, capturedReq.Temperature)
149147
require.Equal(t, 0.5, *capturedReq.Temperature)
150148

151149
require.Contains(t, out.String(), reply) // response streamed to output
152150
})
151+
152+
t.Run("--file with {{input}} placeholder is substituted with stdin", func(t *testing.T) {
153+
const yamlBody = `
154+
name: Summarizer
155+
description: Summarizes input text
156+
model: openai/test-model
157+
messages:
158+
- role: system
159+
content: You are a text summarizer.
160+
- role: user
161+
content: "{{input}}"
162+
`
163+
164+
tmp, err := os.CreateTemp(t.TempDir(), "*.prompt.yml")
165+
require.NoError(t, err)
166+
_, err = tmp.WriteString(yamlBody)
167+
require.NoError(t, err)
168+
require.NoError(t, tmp.Close())
169+
170+
client := azuremodels.NewMockClient()
171+
modelSummary := &azuremodels.ModelSummary{
172+
Name: "test-model",
173+
Publisher: "openai",
174+
Task: "chat-completion",
175+
}
176+
client.MockListModels = func(ctx context.Context) ([]*azuremodels.ModelSummary, error) {
177+
return []*azuremodels.ModelSummary{modelSummary}, nil
178+
}
179+
180+
var capturedReq azuremodels.ChatCompletionOptions
181+
reply := "Summary - bar"
182+
chatCompletion := azuremodels.ChatCompletion{
183+
Choices: []azuremodels.ChatChoice{{
184+
Message: &azuremodels.ChatChoiceMessage{
185+
Content: util.Ptr(reply),
186+
Role: util.Ptr(string(azuremodels.ChatMessageRoleAssistant)),
187+
},
188+
}},
189+
}
190+
client.MockGetChatCompletionStream = func(ctx context.Context, opt azuremodels.ChatCompletionOptions) (*azuremodels.ChatCompletionResponse, error) {
191+
capturedReq = opt
192+
return &azuremodels.ChatCompletionResponse{
193+
Reader: sse.NewMockEventReader([]azuremodels.ChatCompletion{chatCompletion}),
194+
}, nil
195+
}
196+
197+
// create a pipe to fake stdin so that isPipe(os.Stdin)==true
198+
r, w, err := os.Pipe()
199+
require.NoError(t, err)
200+
oldStdin := os.Stdin
201+
os.Stdin = r
202+
defer func() { os.Stdin = oldStdin }()
203+
piped := "Hello there!"
204+
go func() {
205+
_, _ = w.Write([]byte(piped))
206+
_ = w.Close()
207+
}()
208+
209+
out := new(bytes.Buffer)
210+
cfg := command.NewConfig(out, out, client, true, 100)
211+
runCmd := NewRunCommand(cfg)
212+
runCmd.SetArgs([]string{
213+
"--file", tmp.Name(),
214+
azuremodels.FormatIdentifier("openai", "test-model"),
215+
})
216+
217+
_, err = runCmd.ExecuteC()
218+
require.NoError(t, err)
219+
220+
require.Len(t, capturedReq.Messages, 3)
221+
require.Equal(t, "You are a text summarizer.", *capturedReq.Messages[0].Content)
222+
require.Equal(t, piped, *capturedReq.Messages[1].Content) // {{input}} -> "Hello there!"
223+
224+
require.Contains(t, out.String(), reply)
225+
})
153226
}

0 commit comments

Comments
 (0)