Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 14 additions & 5 deletions cmd/eval/eval.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ func NewEvalCommand(cfg *command.Config) *cobra.Command {

Example prompt.yml structure:
name: My Evaluation
model: gpt-4o
model: openai/gpt-4o
testData:
- input: "Hello world"
expected: "Hello there"
Expand All @@ -83,8 +83,11 @@ func NewEvalCommand(cfg *command.Config) *cobra.Command {

See https://docs.github.com/github-models/use-github-models/storing-prompts-in-github-repositories#supported-file-format for more information.
`),
Example: "gh models eval my_prompt.prompt.yml",
Args: cobra.ExactArgs(1),
Example: heredoc.Doc(`
gh models eval my_prompt.prompt.yml
gh models eval --org my-org my_prompt.prompt.yml
`),
Args: cobra.ExactArgs(1),
RunE: func(cmd *cobra.Command, args []string) error {
promptFilePath := args[0]

Expand All @@ -94,6 +97,9 @@ func NewEvalCommand(cfg *command.Config) *cobra.Command {
return err
}

// Get the org flag
org, _ := cmd.Flags().GetString("org")

// Load the evaluation prompt file
evalFile, err := loadEvaluationPromptFile(promptFilePath)
if err != nil {
Expand All @@ -106,6 +112,7 @@ func NewEvalCommand(cfg *command.Config) *cobra.Command {
client: cfg.Client,
evalFile: evalFile,
jsonOutput: jsonOutput,
org: org,
}

err = handler.runEvaluation(cmd.Context())
Expand All @@ -120,6 +127,7 @@ func NewEvalCommand(cfg *command.Config) *cobra.Command {
}

cmd.Flags().Bool("json", false, "Output results in JSON format")
cmd.Flags().String("org", "", "Organization to attribute usage to (omitting will attribute usage to the current actor")
return cmd
}

Expand All @@ -128,6 +136,7 @@ type evalCommandHandler struct {
client azuremodels.Client
evalFile *prompt.File
jsonOutput bool
org string
}

func loadEvaluationPromptFile(filePath string) (*prompt.File, error) {
Expand Down Expand Up @@ -321,7 +330,7 @@ func (h *evalCommandHandler) templateString(templateStr string, data map[string]
func (h *evalCommandHandler) callModel(ctx context.Context, messages []azuremodels.ChatMessage) (string, error) {
req := h.evalFile.BuildChatCompletionOptions(messages)

resp, err := h.client.GetChatCompletionStream(ctx, req)
resp, err := h.client.GetChatCompletionStream(ctx, req, h.org)
if err != nil {
return "", err
}
Expand Down Expand Up @@ -460,7 +469,7 @@ func (h *evalCommandHandler) runLLMEvaluator(ctx context.Context, name string, e
Stream: false,
}

resp, err := h.client.GetChatCompletionStream(ctx, req)
resp, err := h.client.GetChatCompletionStream(ctx, req, h.org)
if err != nil {
return EvaluationResult{}, fmt.Errorf("failed to call evaluation model: %w", err)
}
Expand Down
12 changes: 6 additions & 6 deletions cmd/eval/eval_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ evaluators:
cfg := command.NewConfig(out, out, client, true, 100)

// Mock a response that returns "4" for the LLM evaluator
client.MockGetChatCompletionStream = func(ctx context.Context, req azuremodels.ChatCompletionOptions) (*azuremodels.ChatCompletionResponse, error) {
client.MockGetChatCompletionStream = func(ctx context.Context, req azuremodels.ChatCompletionOptions, org string) (*azuremodels.ChatCompletionResponse, error) {
reader := sse.NewMockEventReader([]azuremodels.ChatCompletion{
{
Choices: []azuremodels.ChatChoice{
Expand Down Expand Up @@ -228,7 +228,7 @@ evaluators:
client := azuremodels.NewMockClient()

// Mock a simple response
client.MockGetChatCompletionStream = func(ctx context.Context, req azuremodels.ChatCompletionOptions) (*azuremodels.ChatCompletionResponse, error) {
client.MockGetChatCompletionStream = func(ctx context.Context, req azuremodels.ChatCompletionOptions, org string) (*azuremodels.ChatCompletionResponse, error) {
// Create a mock reader that returns "test response"
reader := sse.NewMockEventReader([]azuremodels.ChatCompletion{
{
Expand Down Expand Up @@ -284,7 +284,7 @@ evaluators:
client := azuremodels.NewMockClient()

// Mock a response that will fail the evaluator
client.MockGetChatCompletionStream = func(ctx context.Context, req azuremodels.ChatCompletionOptions) (*azuremodels.ChatCompletionResponse, error) {
client.MockGetChatCompletionStream = func(ctx context.Context, req azuremodels.ChatCompletionOptions, org string) (*azuremodels.ChatCompletionResponse, error) {
reader := sse.NewMockEventReader([]azuremodels.ChatCompletion{
{
Choices: []azuremodels.ChatChoice{
Expand Down Expand Up @@ -346,7 +346,7 @@ evaluators:

// Mock responses for both test cases
callCount := 0
client.MockGetChatCompletionStream = func(ctx context.Context, req azuremodels.ChatCompletionOptions) (*azuremodels.ChatCompletionResponse, error) {
client.MockGetChatCompletionStream = func(ctx context.Context, req azuremodels.ChatCompletionOptions, org string) (*azuremodels.ChatCompletionResponse, error) {
callCount++
var response string
if callCount == 1 {
Expand Down Expand Up @@ -444,7 +444,7 @@ evaluators:
require.NoError(t, err)

client := azuremodels.NewMockClient()
client.MockGetChatCompletionStream = func(ctx context.Context, req azuremodels.ChatCompletionOptions) (*azuremodels.ChatCompletionResponse, error) {
client.MockGetChatCompletionStream = func(ctx context.Context, req azuremodels.ChatCompletionOptions, org string) (*azuremodels.ChatCompletionResponse, error) {
response := "hello world"
reader := sse.NewMockEventReader([]azuremodels.ChatCompletion{
{
Expand Down Expand Up @@ -526,7 +526,7 @@ evaluators:
require.NoError(t, err)

client := azuremodels.NewMockClient()
client.MockGetChatCompletionStream = func(ctx context.Context, req azuremodels.ChatCompletionOptions) (*azuremodels.ChatCompletionResponse, error) {
client.MockGetChatCompletionStream = func(ctx context.Context, req azuremodels.ChatCompletionOptions, org string) (*azuremodels.ChatCompletionResponse, error) {
response := "hello world"
reader := sse.NewMockEventReader([]azuremodels.ChatCompletion{
{
Expand Down
12 changes: 9 additions & 3 deletions cmd/run/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,15 +207,20 @@ func NewRunCommand(cfg *command.Config) *cobra.Command {
When using prompt files, you can pass template variables using the %[1]s--var%[1]s flag:
%[1]sgh models run --file prompt.yml --var name=Alice --var topic=AI%[1]s

When running inference against an organization, pass the organization name using the %[1]s--org%[1]s flag:
%[1]sgh models run --org my-org openai/gpt-4o-mini "What is AI?"%[1]s

The return value will be the response to your prompt from the selected model.
`, "`"),
Example: heredoc.Doc(`
gh models run openai/gpt-4o-mini "how many types of hyena are there?"
gh models run --org my-org openai/gpt-4o-mini "how many types of hyena are there?"
gh models run --file prompt.yml --var name=Alice --var topic="machine learning"
`),
Args: cobra.ArbitraryArgs,
RunE: func(cmd *cobra.Command, args []string) error {
filePath, _ := cmd.Flags().GetString("file")
org, _ := cmd.Flags().GetString("org")
var pf *prompt.File
if filePath != "" {
var err error
Expand Down Expand Up @@ -357,7 +362,7 @@ func NewRunCommand(cfg *command.Config) *cobra.Command {
//nolint:gocritic,revive // TODO
defer sp.Stop()

reader, err := cmdHandler.getChatCompletionStreamReader(req)
reader, err := cmdHandler.getChatCompletionStreamReader(req, org)
if err != nil {
return err
}
Expand Down Expand Up @@ -408,6 +413,7 @@ func NewRunCommand(cfg *command.Config) *cobra.Command {
cmd.Flags().String("temperature", "", "Controls randomness in the response, use lower to be more deterministic.")
cmd.Flags().String("top-p", "", "Controls text diversity by selecting the most probable words until a set probability is reached.")
cmd.Flags().String("system-prompt", "", "Prompt the system.")
cmd.Flags().String("org", "", "Organization to attribute usage to (omitting will attribute usage to the current actor")

return cmd
}
Expand Down Expand Up @@ -522,8 +528,8 @@ func validateModelName(modelName string, models []*azuremodels.ModelSummary) (st
return modelName, nil
}

func (h *runCommandHandler) getChatCompletionStreamReader(req azuremodels.ChatCompletionOptions) (sse.Reader[azuremodels.ChatCompletion], error) {
resp, err := h.client.GetChatCompletionStream(h.ctx, req)
func (h *runCommandHandler) getChatCompletionStreamReader(req azuremodels.ChatCompletionOptions, org string) (sse.Reader[azuremodels.ChatCompletion], error) {
resp, err := h.client.GetChatCompletionStream(h.ctx, req, org)
if err != nil {
return nil, err
}
Expand Down
8 changes: 4 additions & 4 deletions cmd/run/run_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ func TestRun(t *testing.T) {
Reader: sse.NewMockEventReader([]azuremodels.ChatCompletion{chatCompletion}),
}
getChatCompletionCallCount := 0
client.MockGetChatCompletionStream = func(ctx context.Context, opt azuremodels.ChatCompletionOptions) (*azuremodels.ChatCompletionResponse, error) {
client.MockGetChatCompletionStream = func(ctx context.Context, opt azuremodels.ChatCompletionOptions, org string) (*azuremodels.ChatCompletionResponse, error) {
getChatCompletionCallCount++
return chatResp, nil
}
Expand Down Expand Up @@ -122,7 +122,7 @@ messages:
},
}},
}
client.MockGetChatCompletionStream = func(ctx context.Context, opt azuremodels.ChatCompletionOptions) (*azuremodels.ChatCompletionResponse, error) {
client.MockGetChatCompletionStream = func(ctx context.Context, opt azuremodels.ChatCompletionOptions, org string) (*azuremodels.ChatCompletionResponse, error) {
capturedReq = opt
return &azuremodels.ChatCompletionResponse{
Reader: sse.NewMockEventReader([]azuremodels.ChatCompletion{chatCompletion}),
Expand Down Expand Up @@ -188,7 +188,7 @@ messages:
},
}},
}
client.MockGetChatCompletionStream = func(ctx context.Context, opt azuremodels.ChatCompletionOptions) (*azuremodels.ChatCompletionResponse, error) {
client.MockGetChatCompletionStream = func(ctx context.Context, opt azuremodels.ChatCompletionOptions, org string) (*azuremodels.ChatCompletionResponse, error) {
capturedReq = opt
return &azuremodels.ChatCompletionResponse{
Reader: sse.NewMockEventReader([]azuremodels.ChatCompletion{chatCompletion}),
Expand Down Expand Up @@ -278,7 +278,7 @@ messages:
}},
}

client.MockGetChatCompletionStream = func(ctx context.Context, opt azuremodels.ChatCompletionOptions) (*azuremodels.ChatCompletionResponse, error) {
client.MockGetChatCompletionStream = func(ctx context.Context, opt azuremodels.ChatCompletionOptions, org string) (*azuremodels.ChatCompletionResponse, error) {
capturedReq = opt
return &azuremodels.ChatCompletionResponse{
Reader: sse.NewMockEventReader([]azuremodels.ChatCompletion{chatCompletion}),
Expand Down
11 changes: 9 additions & 2 deletions internal/azuremodels/azure_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func NewAzureClient(httpClient *http.Client, authToken string, cfg *AzureClientC
}

// GetChatCompletionStream returns a stream of chat completions using the given options.
func (c *AzureClient) GetChatCompletionStream(ctx context.Context, req ChatCompletionOptions) (*ChatCompletionResponse, error) {
func (c *AzureClient) GetChatCompletionStream(ctx context.Context, req ChatCompletionOptions, org string) (*ChatCompletionResponse, error) {
// Check for o1 models, which don't support streaming
if req.Model == "o1-mini" || req.Model == "o1-preview" || req.Model == "o1" {
req.Stream = false
Expand All @@ -55,7 +55,14 @@ func (c *AzureClient) GetChatCompletionStream(ctx context.Context, req ChatCompl

body := bytes.NewReader(bodyBytes)

httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.cfg.InferenceURL, body)
var inferenceURL string
if org != "" {
inferenceURL = fmt.Sprintf("%s/orgs/%s/%s", c.cfg.InferenceRoot, org, c.cfg.InferencePath)
} else {
inferenceURL = c.cfg.InferenceRoot + "/" + c.cfg.InferencePath
}

httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, inferenceURL, body)
if err != nil {
return nil, err
}
Expand Down
9 changes: 6 additions & 3 deletions internal/azuremodels/azure_client_config.go
Original file line number Diff line number Diff line change
@@ -1,22 +1,25 @@
package azuremodels

const (
defaultInferenceURL = "https://models.github.ai/inference/chat/completions"
defaultInferenceRoot = "https://models.github.ai"
defaultInferencePath = "inference/chat/completions"
defaultAzureAiStudioURL = "https://api.catalog.azureml.ms"
defaultModelsURL = defaultAzureAiStudioURL + "/asset-gallery/v1.0/models"
)

// AzureClientConfig represents configurable settings for the Azure client.
type AzureClientConfig struct {
InferenceURL string
InferenceRoot string
InferencePath string
AzureAiStudioURL string
ModelsURL string
}

// NewDefaultAzureClientConfig returns a new AzureClientConfig with default values for API URLs.
func NewDefaultAzureClientConfig() *AzureClientConfig {
return &AzureClientConfig{
InferenceURL: defaultInferenceURL,
InferenceRoot: defaultInferenceRoot,
InferencePath: defaultInferencePath,
AzureAiStudioURL: defaultAzureAiStudioURL,
ModelsURL: defaultModelsURL,
}
Expand Down
12 changes: 6 additions & 6 deletions internal/azuremodels/azure_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func TestAzureClient(t *testing.T) {
require.NoError(t, err)
}))
defer testServer.Close()
cfg := &AzureClientConfig{InferenceURL: testServer.URL}
cfg := &AzureClientConfig{InferenceRoot: testServer.URL}
httpClient := testServer.Client()
client := NewAzureClient(httpClient, authToken, cfg)
opts := ChatCompletionOptions{
Expand All @@ -63,7 +63,7 @@ func TestAzureClient(t *testing.T) {
},
}

chatCompletionStreamResp, err := client.GetChatCompletionStream(ctx, opts)
chatCompletionStreamResp, err := client.GetChatCompletionStream(ctx, opts, "")

require.NoError(t, err)
require.NotNil(t, chatCompletionStreamResp)
Expand Down Expand Up @@ -125,7 +125,7 @@ func TestAzureClient(t *testing.T) {
require.NoError(t, err)
}))
defer testServer.Close()
cfg := &AzureClientConfig{InferenceURL: testServer.URL}
cfg := &AzureClientConfig{InferenceRoot: testServer.URL}
httpClient := testServer.Client()
client := NewAzureClient(httpClient, authToken, cfg)
opts := ChatCompletionOptions{
Expand All @@ -139,7 +139,7 @@ func TestAzureClient(t *testing.T) {
},
}

chatCompletionStreamResp, err := client.GetChatCompletionStream(ctx, opts)
chatCompletionStreamResp, err := client.GetChatCompletionStream(ctx, opts, "")

require.NoError(t, err)
require.NotNil(t, chatCompletionStreamResp)
Expand Down Expand Up @@ -173,15 +173,15 @@ func TestAzureClient(t *testing.T) {
require.NoError(t, err)
}))
defer testServer.Close()
cfg := &AzureClientConfig{InferenceURL: testServer.URL}
cfg := &AzureClientConfig{InferenceRoot: testServer.URL}
httpClient := testServer.Client()
client := NewAzureClient(httpClient, "fake-token-123abc", cfg)
opts := ChatCompletionOptions{
Model: "some-test-model",
Messages: []ChatMessage{{Role: "user", Content: util.Ptr("Tell me a story, test model.")}},
}

chatCompletionResp, err := client.GetChatCompletionStream(ctx, opts)
chatCompletionResp, err := client.GetChatCompletionStream(ctx, opts, "")

require.Error(t, err)
require.Nil(t, chatCompletionResp)
Expand Down
2 changes: 1 addition & 1 deletion internal/azuremodels/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import "context"
// Client represents a client for interacting with an API about models.
type Client interface {
// GetChatCompletionStream returns a stream of chat completions using the given options.
GetChatCompletionStream(context.Context, ChatCompletionOptions) (*ChatCompletionResponse, error)
GetChatCompletionStream(context.Context, ChatCompletionOptions, string) (*ChatCompletionResponse, error)
// GetModelDetails returns the details of the specified model in a particular registry.
GetModelDetails(ctx context.Context, registry, modelName, version string) (*ModelDetails, error)
// ListModels returns a list of available models.
Expand Down
8 changes: 4 additions & 4 deletions internal/azuremodels/mock_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,15 @@ import (

// MockClient provides a client for interacting with the Azure models API in tests.
type MockClient struct {
MockGetChatCompletionStream func(context.Context, ChatCompletionOptions) (*ChatCompletionResponse, error)
MockGetChatCompletionStream func(context.Context, ChatCompletionOptions, string) (*ChatCompletionResponse, error)
MockGetModelDetails func(context.Context, string, string, string) (*ModelDetails, error)
MockListModels func(context.Context) ([]*ModelSummary, error)
}

// NewMockClient returns a new mock client for stubbing out interactions with the models API.
func NewMockClient() *MockClient {
return &MockClient{
MockGetChatCompletionStream: func(context.Context, ChatCompletionOptions) (*ChatCompletionResponse, error) {
MockGetChatCompletionStream: func(context.Context, ChatCompletionOptions, string) (*ChatCompletionResponse, error) {
return nil, errors.New("GetChatCompletionStream not implemented")
},
MockGetModelDetails: func(context.Context, string, string, string) (*ModelDetails, error) {
Expand All @@ -28,8 +28,8 @@ func NewMockClient() *MockClient {
}

// GetChatCompletionStream calls the mocked function for getting a stream of chat completions for the given request.
func (c *MockClient) GetChatCompletionStream(ctx context.Context, opt ChatCompletionOptions) (*ChatCompletionResponse, error) {
return c.MockGetChatCompletionStream(ctx, opt)
func (c *MockClient) GetChatCompletionStream(ctx context.Context, opt ChatCompletionOptions, org string) (*ChatCompletionResponse, error) {
return c.MockGetChatCompletionStream(ctx, opt, org)
}

// GetModelDetails calls the mocked function for getting the details of the specified model in a particular registry.
Expand Down
2 changes: 1 addition & 1 deletion internal/azuremodels/unauthenticated_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ func NewUnauthenticatedClient() *UnauthenticatedClient {
}

// GetChatCompletionStream returns an error because this functionality requires authentication.
func (c *UnauthenticatedClient) GetChatCompletionStream(ctx context.Context, opt ChatCompletionOptions) (*ChatCompletionResponse, error) {
func (c *UnauthenticatedClient) GetChatCompletionStream(ctx context.Context, opt ChatCompletionOptions, org string) (*ChatCompletionResponse, error) {
return nil, errors.New("not authenticated")
}

Expand Down
Loading