Skip to content

Commit a46680e

Browse files
Copilotpelikhan
andcommitted
Refactor HTTP log to use Context instead of function parameters
Co-authored-by: pelikhan <4175913+pelikhan@users.noreply.github.com>
1 parent fca40f4 commit a46680e

File tree

15 files changed

+97
-67
lines changed

15 files changed

+97
-67
lines changed

cmd/eval/eval.go

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -127,10 +127,15 @@ func NewEvalCommand(cfg *command.Config) *cobra.Command {
127127
evalFile: evalFile,
128128
jsonOutput: jsonOutput,
129129
org: org,
130-
httpLog: httpLog,
131130
}
132131

133-
err = handler.runEvaluation(cmd.Context())
132+
ctx := cmd.Context()
133+
// Add HTTP log filename to context if provided
134+
if httpLog != "" {
135+
ctx = azuremodels.WithHTTPLogFile(ctx, httpLog)
136+
}
137+
138+
err = handler.runEvaluation(ctx)
134139
if err == FailedTests {
135140
// Cobra by default will show the help message when an error occurs,
136141
// which is not what we want for failed evaluations.
@@ -153,7 +158,6 @@ type evalCommandHandler struct {
153158
evalFile *prompt.File
154159
jsonOutput bool
155160
org string
156-
httpLog string
157161
}
158162

159163
func loadEvaluationPromptFile(filePath string) (*prompt.File, error) {
@@ -378,7 +382,7 @@ func (h *evalCommandHandler) callModelWithRetry(ctx context.Context, req azuremo
378382
const maxRetries = 3
379383

380384
for attempt := 0; attempt <= maxRetries; attempt++ {
381-
resp, err := h.client.GetChatCompletionStream(ctx, req, h.org, h.httpLog)
385+
resp, err := h.client.GetChatCompletionStream(ctx, req, h.org)
382386
if err != nil {
383387
var rateLimitErr *azuremodels.RateLimitError
384388
if errors.As(err, &rateLimitErr) {

cmd/eval/eval_test.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ evaluators:
162162
cfg := command.NewConfig(out, out, client, true, 100)
163163

164164
// Mock a response that returns "4" for the LLM evaluator
165-
client.MockGetChatCompletionStream = func(ctx context.Context, req azuremodels.ChatCompletionOptions, org, httpLogFile string) (*azuremodels.ChatCompletionResponse, error) {
165+
client.MockGetChatCompletionStream = func(ctx context.Context, req azuremodels.ChatCompletionOptions, org string) (*azuremodels.ChatCompletionResponse, error) {
166166
reader := sse.NewMockEventReader([]azuremodels.ChatCompletion{
167167
{
168168
Choices: []azuremodels.ChatChoice{
@@ -228,7 +228,7 @@ evaluators:
228228
client := azuremodels.NewMockClient()
229229

230230
// Mock a simple response
231-
client.MockGetChatCompletionStream = func(ctx context.Context, req azuremodels.ChatCompletionOptions, org, httpLogFile string) (*azuremodels.ChatCompletionResponse, error) {
231+
client.MockGetChatCompletionStream = func(ctx context.Context, req azuremodels.ChatCompletionOptions, org string) (*azuremodels.ChatCompletionResponse, error) {
232232
// Create a mock reader that returns "test response"
233233
reader := sse.NewMockEventReader([]azuremodels.ChatCompletion{
234234
{
@@ -284,7 +284,7 @@ evaluators:
284284
client := azuremodels.NewMockClient()
285285

286286
// Mock a response that will fail the evaluator
287-
client.MockGetChatCompletionStream = func(ctx context.Context, req azuremodels.ChatCompletionOptions, org, httpLogFile string) (*azuremodels.ChatCompletionResponse, error) {
287+
client.MockGetChatCompletionStream = func(ctx context.Context, req azuremodels.ChatCompletionOptions, org string) (*azuremodels.ChatCompletionResponse, error) {
288288
reader := sse.NewMockEventReader([]azuremodels.ChatCompletion{
289289
{
290290
Choices: []azuremodels.ChatChoice{
@@ -347,7 +347,7 @@ evaluators:
347347

348348
// Mock responses for both test cases
349349
callCount := 0
350-
client.MockGetChatCompletionStream = func(ctx context.Context, req azuremodels.ChatCompletionOptions, org, httpLogFile string) (*azuremodels.ChatCompletionResponse, error) {
350+
client.MockGetChatCompletionStream = func(ctx context.Context, req azuremodels.ChatCompletionOptions, org string) (*azuremodels.ChatCompletionResponse, error) {
351351
callCount++
352352
var response string
353353
if callCount == 1 {
@@ -445,7 +445,7 @@ evaluators:
445445
require.NoError(t, err)
446446

447447
client := azuremodels.NewMockClient()
448-
client.MockGetChatCompletionStream = func(ctx context.Context, req azuremodels.ChatCompletionOptions, org, httpLogFile string) (*azuremodels.ChatCompletionResponse, error) {
448+
client.MockGetChatCompletionStream = func(ctx context.Context, req azuremodels.ChatCompletionOptions, org string) (*azuremodels.ChatCompletionResponse, error) {
449449
response := "hello world"
450450
reader := sse.NewMockEventReader([]azuremodels.ChatCompletion{
451451
{
@@ -528,7 +528,7 @@ evaluators:
528528
require.NoError(t, err)
529529

530530
client := azuremodels.NewMockClient()
531-
client.MockGetChatCompletionStream = func(ctx context.Context, req azuremodels.ChatCompletionOptions, org, httpLogFile string) (*azuremodels.ChatCompletionResponse, error) {
531+
client.MockGetChatCompletionStream = func(ctx context.Context, req azuremodels.ChatCompletionOptions, org string) (*azuremodels.ChatCompletionResponse, error) {
532532
response := "hello world"
533533
reader := sse.NewMockEventReader([]azuremodels.ChatCompletion{
534534
{
@@ -590,7 +590,7 @@ evaluators:
590590

591591
client := azuremodels.NewMockClient()
592592
var capturedRequest azuremodels.ChatCompletionOptions
593-
client.MockGetChatCompletionStream = func(ctx context.Context, req azuremodels.ChatCompletionOptions, org, httpLogFile string) (*azuremodels.ChatCompletionResponse, error) {
593+
client.MockGetChatCompletionStream = func(ctx context.Context, req azuremodels.ChatCompletionOptions, org string) (*azuremodels.ChatCompletionResponse, error) {
594594
capturedRequest = req
595595
response := `{"message": "hello world", "confidence": 0.95}`
596596
reader := sse.NewMockEventReader([]azuremodels.ChatCompletion{

cmd/generate/generate.go

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ type generateCommandHandler struct {
1717
client azuremodels.Client
1818
options *PromptPexOptions
1919
org string
20-
httpLog string
2120
}
2221

2322
// NewGenerateCommand returns a new command to generate tests using PromptPex.
@@ -54,14 +53,19 @@ func NewGenerateCommand(cfg *command.Config) *cobra.Command {
5453
// Get http-log flag
5554
httpLog, _ := cmd.Flags().GetString("http-log")
5655

56+
ctx := cmd.Context()
57+
// Add HTTP log filename to context if provided
58+
if httpLog != "" {
59+
ctx = azuremodels.WithHTTPLogFile(ctx, httpLog)
60+
}
61+
5762
// Create the command handler
5863
handler := &generateCommandHandler{
59-
ctx: cmd.Context(),
64+
ctx: ctx,
6065
cfg: cfg,
6166
client: cfg.Client,
6267
options: options,
6368
org: org,
64-
httpLog: httpLog,
6569
}
6670

6771
// Create PromptPex context

cmd/generate/generate_test.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ messages:
207207

208208
// Setup mock client to return error
209209
client := azuremodels.NewMockClient()
210-
client.MockGetChatCompletionStream = func(ctx context.Context, opt azuremodels.ChatCompletionOptions, org, httpLogFile string) (*azuremodels.ChatCompletionResponse, error) {
210+
client.MockGetChatCompletionStream = func(ctx context.Context, opt azuremodels.ChatCompletionOptions, org string) (*azuremodels.ChatCompletionResponse, error) {
211211
return nil, errors.New("Mock API error")
212212
}
213213

@@ -241,7 +241,7 @@ messages:
241241
// Setup mock client
242242
client := azuremodels.NewMockClient()
243243
callCount := 0
244-
client.MockGetChatCompletionStream = func(ctx context.Context, opt azuremodels.ChatCompletionOptions, org, httpLogFile string) (*azuremodels.ChatCompletionResponse, error) {
244+
client.MockGetChatCompletionStream = func(ctx context.Context, opt azuremodels.ChatCompletionOptions, org string) (*azuremodels.ChatCompletionResponse, error) {
245245
callCount++
246246
var response string
247247

@@ -314,7 +314,7 @@ messages:
314314

315315
// Setup mock client
316316
client := azuremodels.NewMockClient()
317-
client.MockGetChatCompletionStream = func(ctx context.Context, opt azuremodels.ChatCompletionOptions, org, httpLogFile string) (*azuremodels.ChatCompletionResponse, error) {
317+
client.MockGetChatCompletionStream = func(ctx context.Context, opt azuremodels.ChatCompletionOptions, org string) (*azuremodels.ChatCompletionResponse, error) {
318318
var response string
319319
if len(opt.Messages) > 0 && opt.Messages[0].Content != nil {
320320
content := *opt.Messages[0].Content
@@ -382,7 +382,7 @@ messages:
382382

383383
// Setup mock client
384384
client := azuremodels.NewMockClient()
385-
client.MockGetChatCompletionStream = func(ctx context.Context, opt azuremodels.ChatCompletionOptions, org, httpLogFile string) (*azuremodels.ChatCompletionResponse, error) {
385+
client.MockGetChatCompletionStream = func(ctx context.Context, opt azuremodels.ChatCompletionOptions, org string) (*azuremodels.ChatCompletionResponse, error) {
386386
var response string
387387
if len(opt.Messages) > 0 && opt.Messages[0].Content != nil {
388388
content := *opt.Messages[0].Content
@@ -451,7 +451,7 @@ messages:
451451

452452
// Setup mock client
453453
client := azuremodels.NewMockClient()
454-
client.MockGetChatCompletionStream = func(ctx context.Context, opt azuremodels.ChatCompletionOptions, org, httpLogFile string) (*azuremodels.ChatCompletionResponse, error) {
454+
client.MockGetChatCompletionStream = func(ctx context.Context, opt azuremodels.ChatCompletionOptions, org string) (*azuremodels.ChatCompletionResponse, error) {
455455
var response string
456456
if len(opt.Messages) > 0 && opt.Messages[0].Content != nil {
457457
content := *opt.Messages[0].Content

cmd/generate/llm.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ func (h *generateCommandHandler) callModelWithRetry(step string, req azuremodels
2424
//nolint:gocritic,revive // TODO
2525
defer sp.Stop()
2626

27-
resp, err := h.client.GetChatCompletionStream(ctx, req, h.org, h.httpLog)
27+
resp, err := h.client.GetChatCompletionStream(ctx, req, h.org)
2828
if err != nil {
2929
var rateLimitErr *azuremodels.RateLimitError
3030
if errors.As(err, &rateLimitErr) {

cmd/generate/pipeline.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -388,7 +388,7 @@ func (h *generateCommandHandler) runSingleTestWithContext(input, modelName strin
388388
Temperature: util.Ptr(0.0),
389389
}
390390

391-
response, err := h.client.GetChatCompletionStream(h.ctx, options, h.org, h.httpLog)
391+
response, err := h.client.GetChatCompletionStream(h.ctx, options, h.org)
392392
if err != nil {
393393
return "", err
394394
}
@@ -469,7 +469,7 @@ Compliance:`, rules, output)
469469
Temperature: util.Ptr(0.0),
470470
}
471471

472-
response, err := h.client.GetChatCompletionStream(h.ctx, options, h.org, h.httpLog)
472+
response, err := h.client.GetChatCompletionStream(h.ctx, options, h.org)
473473

474474
if err != nil {
475475
return EvalResultUnknown, err
@@ -510,7 +510,7 @@ Score (0-1):`, metric, output)
510510
Temperature: util.Ptr(0.0),
511511
}
512512

513-
response, err := h.client.GetChatCompletionStream(h.ctx, options, h.org, h.httpLog)
513+
response, err := h.client.GetChatCompletionStream(h.ctx, options, h.org)
514514

515515
if err != nil {
516516
return 0.0, err
@@ -615,7 +615,7 @@ Generate variations in JSON format as an array of objects with "scenario", "test
615615
Temperature: util.Ptr(0.5),
616616
}
617617

618-
response, err := h.client.GetChatCompletionStream(h.ctx, options, h.org, h.httpLog)
618+
response, err := h.client.GetChatCompletionStream(h.ctx, options, h.org)
619619

620620
if err != nil {
621621
return nil, err
@@ -676,7 +676,7 @@ Analysis:`, strings.Join(testSummary, "\n"))
676676
Temperature: util.Ptr(0.2),
677677
}
678678

679-
response, err := h.client.GetChatCompletionStream(h.ctx, options, h.org, h.httpLog)
679+
response, err := h.client.GetChatCompletionStream(h.ctx, options, h.org)
680680

681681
if err != nil {
682682
return err

cmd/run/http_log_test.go

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,12 @@ import (
1212
)
1313

1414
func TestHttpLogPassthrough(t *testing.T) {
15-
// Test that the httpLog parameter is correctly passed through the call chain
15+
// Test that the httpLog parameter is correctly passed through the call chain via context
1616
var capturedHttpLog string
1717

1818
client := azuremodels.NewMockClient()
19-
client.MockGetChatCompletionStream = func(ctx context.Context, opt azuremodels.ChatCompletionOptions, org, httpLogFile string) (*azuremodels.ChatCompletionResponse, error) {
20-
capturedHttpLog = httpLogFile
19+
client.MockGetChatCompletionStream = func(ctx context.Context, opt azuremodels.ChatCompletionOptions, org string) (*azuremodels.ChatCompletionResponse, error) {
20+
capturedHttpLog = azuremodels.HTTPLogFileFromContext(ctx)
2121
reader := sse.NewMockEventReader([]azuremodels.ChatCompletion{})
2222
return &azuremodels.ChatCompletionResponse{Reader: reader}, nil
2323
}
@@ -26,16 +26,17 @@ func TestHttpLogPassthrough(t *testing.T) {
2626

2727
// Create a command with the http-log flag
2828
cmd := &cobra.Command{}
29+
cmd.SetContext(context.Background()) // Set a context for the command
2930
cmd.Flags().String("http-log", "", "Path to log HTTP requests to (optional)")
3031
cmd.Flags().Set("http-log", "/tmp/test.log")
3132

3233
// Create handler
3334
handler := newRunCommandHandler(cmd, cfg, []string{})
3435

35-
// Test that httpLog is set correctly
36-
require.Equal(t, "/tmp/test.log", handler.httpLog)
36+
// Test that httpLog is correctly stored in context
37+
require.Equal(t, "/tmp/test.log", azuremodels.HTTPLogFileFromContext(handler.ctx))
3738

38-
// Test that it's passed to the client call
39+
// Test that it's passed to the client call via context
3940
req := azuremodels.ChatCompletionOptions{
4041
Model: "test-model",
4142
Messages: []azuremodels.ChatMessage{

cmd/run/run.go

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -466,16 +466,22 @@ func parseTemplateVariables(flags *pflag.FlagSet) (map[string]string, error) {
466466
}
467467

468468
type runCommandHandler struct {
469-
ctx context.Context
470-
cfg *command.Config
471-
client azuremodels.Client
472-
args []string
473-
httpLog string
469+
ctx context.Context
470+
cfg *command.Config
471+
client azuremodels.Client
472+
args []string
474473
}
475474

476475
func newRunCommandHandler(cmd *cobra.Command, cfg *command.Config, args []string) *runCommandHandler {
476+
ctx := cmd.Context()
477477
httpLog, _ := cmd.Flags().GetString("http-log")
478-
return &runCommandHandler{ctx: cmd.Context(), cfg: cfg, client: cfg.Client, args: args, httpLog: httpLog}
478+
479+
// Add HTTP log filename to context if provided
480+
if httpLog != "" {
481+
ctx = azuremodels.WithHTTPLogFile(ctx, httpLog)
482+
}
483+
484+
return &runCommandHandler{ctx: ctx, cfg: cfg, client: cfg.Client, args: args}
479485
}
480486

481487
func (h *runCommandHandler) loadModels() ([]*azuremodels.ModelSummary, error) {
@@ -554,7 +560,7 @@ func validateModelName(modelName string, models []*azuremodels.ModelSummary) (st
554560
}
555561

556562
func (h *runCommandHandler) getChatCompletionStreamReader(req azuremodels.ChatCompletionOptions, org string) (sse.Reader[azuremodels.ChatCompletion], error) {
557-
resp, err := h.client.GetChatCompletionStream(h.ctx, req, org, h.httpLog)
563+
resp, err := h.client.GetChatCompletionStream(h.ctx, req, org)
558564
if err != nil {
559565
return nil, err
560566
}

cmd/run/run_test.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ func TestRun(t *testing.T) {
4444
Reader: sse.NewMockEventReader([]azuremodels.ChatCompletion{chatCompletion}),
4545
}
4646
getChatCompletionCallCount := 0
47-
client.MockGetChatCompletionStream = func(ctx context.Context, opt azuremodels.ChatCompletionOptions, org, httpLogFile string) (*azuremodels.ChatCompletionResponse, error) {
47+
client.MockGetChatCompletionStream = func(ctx context.Context, opt azuremodels.ChatCompletionOptions, org string) (*azuremodels.ChatCompletionResponse, error) {
4848
getChatCompletionCallCount++
4949
return chatResp, nil
5050
}
@@ -122,7 +122,7 @@ messages:
122122
},
123123
}},
124124
}
125-
client.MockGetChatCompletionStream = func(ctx context.Context, opt azuremodels.ChatCompletionOptions, org, httpLogFile string) (*azuremodels.ChatCompletionResponse, error) {
125+
client.MockGetChatCompletionStream = func(ctx context.Context, opt azuremodels.ChatCompletionOptions, org string) (*azuremodels.ChatCompletionResponse, error) {
126126
capturedReq = opt
127127
return &azuremodels.ChatCompletionResponse{
128128
Reader: sse.NewMockEventReader([]azuremodels.ChatCompletion{chatCompletion}),
@@ -189,7 +189,7 @@ messages:
189189
},
190190
}},
191191
}
192-
client.MockGetChatCompletionStream = func(ctx context.Context, opt azuremodels.ChatCompletionOptions, org, httpLogFile string) (*azuremodels.ChatCompletionResponse, error) {
192+
client.MockGetChatCompletionStream = func(ctx context.Context, opt azuremodels.ChatCompletionOptions, org string) (*azuremodels.ChatCompletionResponse, error) {
193193
capturedReq = opt
194194
return &azuremodels.ChatCompletionResponse{
195195
Reader: sse.NewMockEventReader([]azuremodels.ChatCompletion{chatCompletion}),
@@ -281,7 +281,7 @@ messages:
281281
}},
282282
}
283283

284-
client.MockGetChatCompletionStream = func(ctx context.Context, opt azuremodels.ChatCompletionOptions, org, httpLogFile string) (*azuremodels.ChatCompletionResponse, error) {
284+
client.MockGetChatCompletionStream = func(ctx context.Context, opt azuremodels.ChatCompletionOptions, org string) (*azuremodels.ChatCompletionResponse, error) {
285285
capturedReq = opt
286286
return &azuremodels.ChatCompletionResponse{
287287
Reader: sse.NewMockEventReader([]azuremodels.ChatCompletion{chatCompletion}),
@@ -367,7 +367,7 @@ messages:
367367
}
368368

369369
var capturedRequest azuremodels.ChatCompletionOptions
370-
client.MockGetChatCompletionStream = func(ctx context.Context, req azuremodels.ChatCompletionOptions, org, httpLogFile string) (*azuremodels.ChatCompletionResponse, error) {
370+
client.MockGetChatCompletionStream = func(ctx context.Context, req azuremodels.ChatCompletionOptions, org string) (*azuremodels.ChatCompletionResponse, error) {
371371
capturedRequest = req
372372
reply := "hello this is a test response"
373373
reader := sse.NewMockEventReader([]azuremodels.ChatCompletion{

internal/azuremodels/azure_client.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ func NewAzureClient(httpClient *http.Client, authToken string, cfg *AzureClientC
4545
}
4646

4747
// GetChatCompletionStream returns a stream of chat completions using the given options.
48-
func (c *AzureClient) GetChatCompletionStream(ctx context.Context, req ChatCompletionOptions, org, httpLogFile string) (*ChatCompletionResponse, error) {
48+
func (c *AzureClient) GetChatCompletionStream(ctx context.Context, req ChatCompletionOptions, org string) (*ChatCompletionResponse, error) {
4949
// Check for o1 models, which don't support streaming
5050
if req.Model == "o1-mini" || req.Model == "o1-preview" || req.Model == "o1" {
5151
req.Stream = false
@@ -68,6 +68,7 @@ func (c *AzureClient) GetChatCompletionStream(ctx context.Context, req ChatCompl
6868
}
6969

7070
// Write request details to specified log file for debugging
71+
httpLogFile := HTTPLogFileFromContext(ctx)
7172
if httpLogFile != "" {
7273
logFile, err := os.OpenFile(httpLogFile, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
7374
if err == nil {

0 commit comments

Comments
 (0)