Skip to content

Commit 1d5ff3f

Browse files

Commit message: Use GitHub Models catalog
1 parent: 00de810 · commit: 1d5ff3f

File tree

13 files changed: +100 additions, -118 deletions

cmd/list/list.go

Lines changed: 1 addition & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -53,7 +53,7 @@ func NewListCommand(cfg *command.Config) *cobra.Command {
5353
printer.EndRow()
5454

5555
for _, model := range models {
56-
printer.AddField(azuremodels.FormatIdentifier(model.Publisher, model.Name))
56+
printer.AddField(model.ID)
5757
printer.AddField(model.FriendlyName)
5858
printer.EndRow()
5959
}

cmd/list/list_test.go

Lines changed: 2 additions & 3 deletions
Original file line number · Diff line number · Diff line change
@@ -14,14 +14,13 @@ func TestList(t *testing.T) {
1414
t.Run("NewListCommand happy path", func(t *testing.T) {
1515
client := azuremodels.NewMockClient()
1616
modelSummary := &azuremodels.ModelSummary{
17-
ID: "test-id-1",
17+
ID: "openai/test-id-1",
1818
Name: "test-model-1",
1919
FriendlyName: "Test Model 1",
2020
Task: "chat-completion",
2121
Publisher: "OpenAI",
2222
Summary: "This is a test model",
2323
Version: "1.0",
24-
RegistryName: "azure-openai",
2524
}
2625
listModelsCallCount := 0
2726
client.MockListModels = func(ctx context.Context) ([]*azuremodels.ModelSummary, error) {
@@ -41,7 +40,7 @@ func TestList(t *testing.T) {
4140
require.Contains(t, output, "DISPLAY NAME")
4241
require.Contains(t, output, "ID")
4342
require.Contains(t, output, modelSummary.FriendlyName)
44-
require.Contains(t, output, azuremodels.FormatIdentifier(modelSummary.Publisher, modelSummary.Name))
43+
require.Contains(t, output, modelSummary.ID)
4544
})
4645

4746
t.Run("--help prints usage info", func(t *testing.T) {

cmd/run/run.go

Lines changed: 3 additions & 2 deletions
Original file line number · Diff line number · Diff line change
@@ -500,7 +500,8 @@ func (h *runCommandHandler) getModelNameFromArgs(models []*azuremodels.ModelSumm
500500
if !model.IsChatModel() {
501501
continue
502502
}
503-
prompt.Options = append(prompt.Options, azuremodels.FormatIdentifier(model.Publisher, model.Name))
503+
504+
prompt.Options = append(prompt.Options, model.ID)
504505
}
505506

506507
err := survey.AskOne(prompt, &modelName, survey.WithPageSize(10))
@@ -533,7 +534,7 @@ func validateModelName(modelName string, models []*azuremodels.ModelSummary) (st
533534
}
534535

535536
// For non-custom providers, validate the model exists
536-
expectedModelID := azuremodels.FormatIdentifier(parsedModel.Publisher, parsedModel.ModelName)
537+
expectedModelID := parsedModel.String()
537538
foundMatch := false
538539
for _, model := range models {
539540
if model.HasName(expectedModelID) {

cmd/run/run_test.go

Lines changed: 10 additions & 5 deletions
Original file line number · Diff line number · Diff line change
@@ -19,14 +19,13 @@ func TestRun(t *testing.T) {
1919
t.Run("NewRunCommand happy path", func(t *testing.T) {
2020
client := azuremodels.NewMockClient()
2121
modelSummary := &azuremodels.ModelSummary{
22-
ID: "test-id-1",
22+
ID: "openai/test-model-1",
2323
Name: "test-model-1",
2424
FriendlyName: "Test Model 1",
2525
Task: "chat-completion",
2626
Publisher: "OpenAI",
2727
Summary: "This is a test model",
2828
Version: "1.0",
29-
RegistryName: "azure-openai",
3029
}
3130
listModelsCallCount := 0
3231
client.MockListModels = func(ctx context.Context) ([]*azuremodels.ModelSummary, error) {
@@ -52,7 +51,7 @@ func TestRun(t *testing.T) {
5251
buf := new(bytes.Buffer)
5352
cfg := command.NewConfig(buf, buf, client, true, 80)
5453
runCmd := NewRunCommand(cfg)
55-
runCmd.SetArgs([]string{azuremodels.FormatIdentifier(modelSummary.Publisher, modelSummary.Name), "this is my prompt"})
54+
runCmd.SetArgs([]string{modelSummary.ID, "this is my prompt"})
5655

5756
_, err := runCmd.ExecuteC()
5857

@@ -104,6 +103,7 @@ messages:
104103

105104
client := azuremodels.NewMockClient()
106105
modelSummary := &azuremodels.ModelSummary{
106+
ID: "openai/test-model",
107107
Name: "test-model",
108108
Publisher: "openai",
109109
Task: "chat-completion",
@@ -134,7 +134,7 @@ messages:
134134
runCmd := NewRunCommand(cfg)
135135
runCmd.SetArgs([]string{
136136
"--file", tmp.Name(),
137-
azuremodels.FormatIdentifier("openai", "test-model"),
137+
"openai/test-model",
138138
})
139139

140140
_, err = runCmd.ExecuteC()
@@ -170,6 +170,7 @@ messages:
170170

171171
client := azuremodels.NewMockClient()
172172
modelSummary := &azuremodels.ModelSummary{
173+
ID: "openai/test-model",
173174
Name: "test-model",
174175
Publisher: "openai",
175176
Task: "chat-completion",
@@ -214,7 +215,7 @@ messages:
214215
runCmd := NewRunCommand(cfg)
215216
runCmd.SetArgs([]string{
216217
"--file", tmp.Name(),
217-
azuremodels.FormatIdentifier("openai", "test-model"),
218+
"openai/test-model",
218219
initialPrompt,
219220
})
220221

@@ -252,11 +253,13 @@ messages:
252253

253254
client := azuremodels.NewMockClient()
254255
modelSummary := &azuremodels.ModelSummary{
256+
ID: "openai/example-model",
255257
Name: "example-model",
256258
Publisher: "openai",
257259
Task: "chat-completion",
258260
}
259261
modelSummary2 := &azuremodels.ModelSummary{
262+
ID: "openai/example-model-4o-mini-plus",
260263
Name: "example-model-4o-mini-plus",
261264
Publisher: "openai",
262265
Task: "chat-completion",
@@ -369,6 +372,7 @@ messages:
369372

370373
client := azuremodels.NewMockClient()
371374
modelSummary := &azuremodels.ModelSummary{
375+
ID: "openai/test-model",
372376
Name: "test-model",
373377
Publisher: "openai",
374378
Task: "chat-completion",
@@ -533,6 +537,7 @@ func TestValidateModelName(t *testing.T) {
533537

534538
// Create a mock model for testing
535539
mockModel := &azuremodels.ModelSummary{
540+
ID: "openai/gpt-4",
536541
Name: "gpt-4",
537542
Publisher: "openai",
538543
Task: "chat-completion",

cmd/view/view.go

Lines changed: 2 additions & 3 deletions
Original file line number · Diff line number · Diff line change
@@ -50,7 +50,7 @@ func NewViewCommand(cfg *command.Config) *cobra.Command {
5050
if !model.IsChatModel() {
5151
continue
5252
}
53-
prompt.Options = append(prompt.Options, azuremodels.FormatIdentifier(model.Publisher, model.Name))
53+
prompt.Options = append(prompt.Options, model.ID)
5454
}
5555

5656
err = survey.AskOne(prompt, &modelName, survey.WithPageSize(10))
@@ -61,13 +61,12 @@ func NewViewCommand(cfg *command.Config) *cobra.Command {
6161
case len(args) >= 1:
6262
modelName = args[0]
6363
}
64-
6564
modelSummary, err := getModelByName(modelName, models)
6665
if err != nil {
6766
return err
6867
}
6968

70-
modelDetails, err := client.GetModelDetails(ctx, modelSummary.RegistryName, modelSummary.Name, modelSummary.Version)
69+
modelDetails, err := client.GetModelDetails(ctx, modelSummary.Registry, modelSummary.Name, modelSummary.Version)
7170
if err != nil {
7271
return err
7372
}

cmd/view/view_test.go

Lines changed: 2 additions & 3 deletions
Original file line number · Diff line number · Diff line change
@@ -14,14 +14,13 @@ func TestView(t *testing.T) {
1414
t.Run("NewViewCommand happy path", func(t *testing.T) {
1515
client := azuremodels.NewMockClient()
1616
modelSummary := &azuremodels.ModelSummary{
17-
ID: "test-id-1",
17+
ID: "openai/test-model-1",
1818
Name: "test-model-1",
1919
FriendlyName: "Test Model 1",
2020
Task: "chat-completion",
2121
Publisher: "OpenAI",
2222
Summary: "This is a test model",
2323
Version: "1.0",
24-
RegistryName: "azure-openai",
2524
}
2625
listModelsCallCount := 0
2726
client.MockListModels = func(ctx context.Context) ([]*azuremodels.ModelSummary, error) {
@@ -49,7 +48,7 @@ func TestView(t *testing.T) {
4948
buf := new(bytes.Buffer)
5049
cfg := command.NewConfig(buf, buf, client, true, 80)
5150
viewCmd := NewViewCommand(cfg)
52-
viewCmd.SetArgs([]string{azuremodels.FormatIdentifier(modelSummary.Publisher, modelSummary.Name)})
51+
viewCmd.SetArgs([]string{modelSummary.ID})
5352

5453
_, err := viewCmd.ExecuteC()
5554

internal/azuremodels/azure_client.go

Lines changed: 22 additions & 26 deletions
Original file line number · Diff line number · Diff line change
@@ -9,9 +9,11 @@ import (
99
"fmt"
1010
"io"
1111
"net/http"
12+
"slices"
1213
"strings"
1314

1415
"github.com/cli/go-gh/v2/pkg/api"
16+
"github.com/github/gh-models/internal/modelkey"
1517
"github.com/github/gh-models/internal/sse"
1618
"golang.org/x/text/language"
1719
"golang.org/x/text/language/display"
@@ -185,19 +187,7 @@ func lowercaseStrings(input []string) []string {
185187

186188
// ListModels returns a list of available models.
187189
func (c *AzureClient) ListModels(ctx context.Context) ([]*ModelSummary, error) {
188-
body := bytes.NewReader([]byte(`
189-
{
190-
"filters": [
191-
{ "field": "freePlayground", "values": ["true"], "operator": "eq"},
192-
{ "field": "labels", "values": ["latest"], "operator": "eq"}
193-
],
194-
"order": [
195-
{ "field": "displayName", "direction": "asc" }
196-
]
197-
}
198-
`))
199-
200-
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.cfg.ModelsURL, body)
190+
httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, c.cfg.ModelsURL, nil)
201191
if err != nil {
202192
return nil, err
203193
}
@@ -218,28 +208,34 @@ func (c *AzureClient) ListModels(ctx context.Context) ([]*ModelSummary, error) {
218208
decoder := json.NewDecoder(resp.Body)
219209
decoder.UseNumber()
220210

221-
var searchResponse modelCatalogSearchResponse
222-
err = decoder.Decode(&searchResponse)
211+
var catalog githubModelCatalogResponse
212+
err = decoder.Decode(&catalog)
223213
if err != nil {
224214
return nil, err
225215
}
226216

227-
models := make([]*ModelSummary, 0, len(searchResponse.Summaries))
228-
for _, summary := range searchResponse.Summaries {
217+
models := make([]*ModelSummary, 0, len(catalog))
218+
for _, catalogModel := range catalog {
219+
// Determine task from supported modalities - if it supports text input/output, it's likely a chat model
229220
inferenceTask := ""
230-
if len(summary.InferenceTasks) > 0 {
231-
inferenceTask = summary.InferenceTasks[0]
221+
if slices.Contains(catalogModel.SupportedInputModalities, "text") && slices.Contains(catalogModel.SupportedOutputModalities, "text") {
222+
inferenceTask = "chat-completion"
223+
}
224+
225+
modelKey, err := modelkey.ParseModelKey(catalogModel.ID)
226+
if err != nil {
227+
return nil, fmt.Errorf("parsing model key %q: %w", catalogModel.ID, err)
232228
}
233229

234230
models = append(models, &ModelSummary{
235-
ID: summary.AssetID,
236-
Name: summary.Name,
237-
FriendlyName: summary.DisplayName,
231+
ID: catalogModel.ID,
232+
Name: modelKey.ModelName,
233+
Registry: catalogModel.Registry,
234+
FriendlyName: catalogModel.Name,
238235
Task: inferenceTask,
239-
Publisher: summary.Publisher,
240-
Summary: summary.Summary,
241-
Version: summary.Version,
242-
RegistryName: summary.RegistryName,
236+
Publisher: catalogModel.Publisher,
237+
Summary: catalogModel.Summary,
238+
Version: catalogModel.Version,
243239
})
244240
}
245241

internal/azuremodels/azure_client_test.go

Lines changed: 33 additions & 34 deletions
Original file line number · Diff line number · Diff line change
@@ -194,38 +194,39 @@ func TestAzureClient(t *testing.T) {
194194
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
195195
require.Equal(t, "application/json", r.Header.Get("Content-Type"))
196196
require.Equal(t, "/", r.URL.Path)
197-
require.Equal(t, http.MethodPost, r.Method)
197+
require.Equal(t, http.MethodGet, r.Method)
198198

199199
handlerFn(w, r)
200200
}))
201201
}
202202

203203
t.Run("happy path", func(t *testing.T) {
204-
summary1 := modelCatalogSearchSummary{
205-
AssetID: "test-id-1",
206-
Name: "test-model-1",
207-
DisplayName: "I Can't Believe It's Not a Real Model",
208-
InferenceTasks: []string{"this model has an inference task but the other model will not"},
209-
Publisher: "OpenAI",
210-
Summary: "This is a test model",
211-
Version: "1.0",
212-
RegistryName: "azure-openai",
213-
}
214-
summary2 := modelCatalogSearchSummary{
215-
AssetID: "test-id-2",
216-
Name: "test-model-2",
217-
DisplayName: "Down the Rabbit-Hole",
218-
Publisher: "Project Gutenberg",
219-
Summary: "The first chapter of Alice's Adventures in Wonderland by Lewis Carroll.",
220-
Version: "THE MILLENNIUM FULCRUM EDITION 3.0",
221-
RegistryName: "proj-gutenberg-website",
204+
summary1 := githubModelSummary{
205+
ID: "openai/gpt-4.1",
206+
Name: "OpenAI GPT-4.1",
207+
Publisher: "OpenAI",
208+
Summary: "gpt-4.1 outperforms gpt-4o across the board",
209+
Version: "1",
210+
RateLimitTier: "high",
211+
SupportedInputModalities: []string{"text", "image"},
212+
SupportedOutputModalities: []string{"text"},
213+
Tags: []string{"multipurpose", "multilingual", "multimodal"},
222214
}
223-
searchResponse := &modelCatalogSearchResponse{
224-
Summaries: []modelCatalogSearchSummary{summary1, summary2},
215+
summary2 := githubModelSummary{
216+
ID: "openai/gpt-4.1-mini",
217+
Name: "OpenAI GPT-4.1-mini",
218+
Publisher: "OpenAI",
219+
Summary: "gpt-4.1-mini outperform gpt-4o-mini across the board",
220+
Version: "2",
221+
RateLimitTier: "low",
222+
SupportedInputModalities: []string{"text", "image"},
223+
SupportedOutputModalities: []string{"text"},
224+
Tags: []string{"multipurpose", "multilingual", "multimodal"},
225225
}
226+
githubResponse := githubModelCatalogResponse{summary1, summary2}
226227
testServer := newTestServerForListModels(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
227228
w.WriteHeader(http.StatusOK)
228-
err := json.NewEncoder(w).Encode(searchResponse)
229+
err := json.NewEncoder(w).Encode(githubResponse)
229230
require.NoError(t, err)
230231
}))
231232
defer testServer.Close()
@@ -238,22 +239,20 @@ func TestAzureClient(t *testing.T) {
238239
require.NoError(t, err)
239240
require.NotNil(t, models)
240241
require.Equal(t, 2, len(models))
241-
require.Equal(t, summary1.AssetID, models[0].ID)
242-
require.Equal(t, summary2.AssetID, models[1].ID)
243-
require.Equal(t, summary1.Name, models[0].Name)
244-
require.Equal(t, summary2.Name, models[1].Name)
245-
require.Equal(t, summary1.DisplayName, models[0].FriendlyName)
246-
require.Equal(t, summary2.DisplayName, models[1].FriendlyName)
247-
require.Equal(t, summary1.InferenceTasks[0], models[0].Task)
248-
require.Empty(t, models[1].Task)
242+
require.Equal(t, summary1.ID, models[0].ID)
243+
require.Equal(t, summary2.ID, models[1].ID)
244+
require.Equal(t, "gpt-4.1", models[0].Name)
245+
require.Equal(t, "gpt-4.1-mini", models[1].Name)
246+
require.Equal(t, summary1.Name, models[0].FriendlyName)
247+
require.Equal(t, summary2.Name, models[1].FriendlyName)
248+
require.Equal(t, "chat-completion", models[0].Task)
249+
require.Equal(t, "chat-completion", models[1].Task)
249250
require.Equal(t, summary1.Publisher, models[0].Publisher)
250251
require.Equal(t, summary2.Publisher, models[1].Publisher)
251252
require.Equal(t, summary1.Summary, models[0].Summary)
252253
require.Equal(t, summary2.Summary, models[1].Summary)
253-
require.Equal(t, summary1.Version, models[0].Version)
254-
require.Equal(t, summary2.Version, models[1].Version)
255-
require.Equal(t, summary1.RegistryName, models[0].RegistryName)
256-
require.Equal(t, summary2.RegistryName, models[1].RegistryName)
254+
require.Equal(t, "1", models[0].Version)
255+
require.Equal(t, "2", models[1].Version)
257256
})
258257

259258
t.Run("handles non-OK status", func(t *testing.T) {

internal/azuremodels/model_details.go

Lines changed: 0 additions & 7 deletions
Original file line number · Diff line number · Diff line change
@@ -2,8 +2,6 @@ package azuremodels
22

33
import (
44
"fmt"
5-
6-
"github.com/github/gh-models/internal/modelkey"
75
)
86

97
// ModelDetails includes detailed information about a model.
@@ -26,8 +24,3 @@ type ModelDetails struct {
2624
func (m *ModelDetails) ContextLimits() string {
2725
return fmt.Sprintf("up to %d input tokens and %d output tokens", m.MaxInputTokens, m.MaxOutputTokens)
2826
}
29-
30-
// FormatIdentifier formats the model identifier based on the publisher and model name.
31-
func FormatIdentifier(publisher, name string) string {
32-
return modelkey.FormatIdentifier("azureml", publisher, name)
33-
}

internal/azuremodels/model_details_test.go

Lines changed: 0 additions & 8 deletions
Original file line number · Diff line number · Diff line change
@@ -12,12 +12,4 @@ func TestModelDetails(t *testing.T) {
1212
result := details.ContextLimits()
1313
require.Equal(t, "up to 123 input tokens and 456 output tokens", result)
1414
})
15-
16-
t.Run("FormatIdentifier", func(t *testing.T) {
17-
publisher := "Open AI"
18-
name := "GPT 3"
19-
expected := "open-ai/gpt-3"
20-
result := FormatIdentifier(publisher, name)
21-
require.Equal(t, expected, result)
22-
})
2315
}

0 commit comments

Comments (0)