Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions cmd/run/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"github.com/MakeNowJust/heredoc"
"github.com/briandowns/spinner"
"github.com/github/gh-models/internal/azuremodels"
"github.com/github/gh-models/internal/modelkey"
"github.com/github/gh-models/internal/sse"
"github.com/github/gh-models/pkg/command"
"github.com/github/gh-models/pkg/prompt"
Expand Down Expand Up @@ -513,9 +514,21 @@ func validateModelName(modelName string, models []*azuremodels.ModelSummary) (st
return "", errors.New(noMatchErrorMessage)
}

parsedModel, err := modelkey.ParseModelKey(modelName)
if err != nil {
return "", fmt.Errorf("invalid model format: %w", err)
}

if parsedModel.Provider == "custom" {
// Skip validation for custom provider
return parsedModel.String(), nil
}

// For non-custom providers, validate the model exists
expectedModelID := azuremodels.FormatIdentifier(parsedModel.Publisher, parsedModel.ModelName)
foundMatch := false
for _, model := range models {
if model.HasName(modelName) {
if model.HasName(expectedModelID) {
foundMatch = true
break
}
Expand All @@ -525,7 +538,7 @@ func validateModelName(modelName string, models []*azuremodels.ModelSummary) (st
return "", errors.New(noMatchErrorMessage)
}

return modelName, nil
return expectedModelID, nil
}

func (h *runCommandHandler) getChatCompletionStreamReader(req azuremodels.ChatCompletionOptions, org string) (sse.Reader[azuremodels.ChatCompletion], error) {
Expand Down
53 changes: 53 additions & 0 deletions cmd/run/run_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -403,3 +403,56 @@ func TestParseTemplateVariables(t *testing.T) {
})
}
}

func TestValidateModelName(t *testing.T) {
tests := []struct {
name string
modelName string
expectedModel string
expectError bool
}{
{
name: "custom provider skips validation",
modelName: "custom/mycompany/custom-model",
expectedModel: "custom/mycompany/custom-model",
expectError: false,
},
{
name: "azureml provider requires validation",
modelName: "openai/gpt-4",
expectedModel: "openai/gpt-4",
expectError: false,
},
{
name: "invalid model format",
modelName: "invalid-format",
expectError: true,
},
{
name: "nonexistent azureml model",
modelName: "nonexistent/model",
expectError: true,
},
}

// Create a mock model for testing
mockModel := &azuremodels.ModelSummary{
Name: "gpt-4",
Publisher: "openai",
Task: "chat-completion",
}
models := []*azuremodels.ModelSummary{mockModel}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := validateModelName(tt.modelName, models)

if tt.expectError {
require.Error(t, err)
} else {
require.NoError(t, err)
require.Equal(t, tt.expectedModel, result)
}
})
}
}
12 changes: 3 additions & 9 deletions internal/azuremodels/model_details.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ package azuremodels

import (
"fmt"
"strings"

"github.com/github/gh-models/internal/modelkey"
)

// ModelDetails includes detailed information about a model.
Expand All @@ -28,12 +29,5 @@ func (m *ModelDetails) ContextLimits() string {

// FormatIdentifier formats the model identifier based on the publisher and model name.
func FormatIdentifier(publisher, name string) string {
formatPart := func(s string) string {
// Replace spaces with dashes and convert to lowercase
result := strings.ToLower(s)
result = strings.ReplaceAll(result, " ", "-")
return result
}

return fmt.Sprintf("%s/%s", formatPart(publisher), formatPart(name))
return modelkey.FormatIdentifier("azureml", publisher, name)
}
76 changes: 76 additions & 0 deletions internal/modelkey/modelkey.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
package modelkey

import (
"fmt"
"strings"
)

type ModelKey struct {
Provider string
Publisher string
ModelName string
}

func ParseModelKey(modelKey string) (*ModelKey, error) {
if modelKey == "" {
return nil, fmt.Errorf("invalid model key format: %s", modelKey)
}

parts := strings.Split(modelKey, "/")

// Check for empty parts
for _, part := range parts {
if part == "" {
return nil, fmt.Errorf("invalid model key format: %s", modelKey)
}
}

switch len(parts) {
case 2:
// Format: publisher/model-name (provider defaults to "azureml")
return &ModelKey{
Provider: "azureml",
Publisher: parts[0],
ModelName: parts[1],
}, nil
case 3:
// Format: provider/publisher/model-name
return &ModelKey{
Provider: parts[0],
Publisher: parts[1],
ModelName: parts[2],
}, nil
default:
return nil, fmt.Errorf("invalid model key format: %s", modelKey)
}
}

// String returns the string representation of the ModelKey.
func (mk *ModelKey) String() string {
provider := formatPart(mk.Provider)
publisher := formatPart(mk.Publisher)
modelName := formatPart(mk.ModelName)

if provider == "azureml" {
return fmt.Sprintf("%s/%s", publisher, modelName)
}

return fmt.Sprintf("%s/%s/%s", provider, publisher, modelName)
}

func formatPart(s string) string {
s = strings.ToLower(s)
s = strings.ReplaceAll(s, " ", "-")

return s
}

func FormatIdentifier(provider, publisher, name string) string {
mk := &ModelKey{
Provider: provider,
Publisher: publisher,
ModelName: name,
}

return mk.String()
}
Loading
Loading