diff --git a/src/functions/describe-model.ts b/src/functions/describe-model.ts index 0481227..ed02e43 100644 --- a/src/functions/describe-model.ts +++ b/src/functions/describe-model.ts @@ -10,8 +10,11 @@ export class describeModel extends Tool { properties: { model: { type: "string", - description: - 'The model to describe. Looks like "registry/model-name". For example, `azureml/Phi-3-medium-128k-instruct` or `azure-openai/gpt-4o', + description: [ + 'The model to describe. Looks like "model-name". For example, `Phi-3-medium-128k-instruct` or `gpt-4o`.', + 'The list of models is available in the context window of the chat, in the `<-- LIST OF MODELS -->` section.', + 'If the model name is not found in the list of models, pick the closest matching model from the list.', + ].join("\n"), }, }, required: ["model"], @@ -30,12 +33,11 @@ export class describeModel extends Tool { const systemMessage = [ "The user is asking about the AI model with the following details:", `\tModel Name: ${model.name}`, - `\tModel Version: ${model.model_version}`, + `\tModel Version: ${model.version}`, `\tPublisher: ${model.publisher}`, - `\tModel Family: ${model.model_family}`, - `\tModel Registry: ${model.model_registry}`, + `\tModel Registry: ${model.registryName}`, `\tLicense: ${model.license}`, - `\tTask: ${model.task}`, + `\tTask: ${model.inferenceTasks.join(", ")}`, `\tDescription: ${model.description}`, `\tSummary: ${model.summary}`, "\n", diff --git a/src/functions/execute-model.ts b/src/functions/execute-model.ts index 405a618..fdd13a1 100644 --- a/src/functions/execute-model.ts +++ b/src/functions/execute-model.ts @@ -30,6 +30,7 @@ Example Queries (IMPORTANT: Phrasing doesn't have to match): "The name of the model to execute. It is ONLY the name of the model, not the publisher or registry.", "For example: `gpt-4o`, or `cohere-command-r-plus`.", "The list of models is available in the context window of the chat, in the `<-- LIST OF MODELS -->` section.", + "If the model name is not found in the list of models, pick the closest matching model from the list.", ].join("\n"), }, instruction: { diff --git a/src/functions/list-models.ts b/src/functions/list-models.ts index 3a47715..c139a8e 100644 --- a/src/functions/list-models.ts +++ b/src/functions/list-models.ts @@ -27,7 +27,7 @@ export class listModels extends Tool { "That list of models is as follows:", JSON.stringify( models.map((model) => ({ - name: model.friendly_name, + name: model.displayName, publisher: model.publisher, description: model.summary, })) diff --git a/src/index.ts b/src/index.ts index ddb7820..f075562 100644 --- a/src/index.ts +++ b/src/index.ts @@ -55,6 +55,7 @@ const server = createServer(async (request, response) => { // List of functions that are available to be called const modelsAPI = new ModelsAPI(apiKey); + const functions = [listModels, describeModel, executeModel, recommendModel]; // Use the Copilot API to determine which function to execute @@ -66,6 +67,7 @@ const server = createServer(async (request, response) => { // Prepend a system message that includes the list of models, so that // tool calls can better select the right model to use. const models = await modelsAPI.listModels(); + const toolCallMessages = [ { role: "system" as const, @@ -75,13 +77,28 @@ const server = createServer(async (request, response) => { "Here is a list of some of the models available to the user:", "<-- LIST OF MODELS -->", JSON.stringify( - models.map((model) => ({ - friendly_name: model.friendly_name, + [...models.map((model) => ({ + friendly_name: model.displayName, name: model.name, publisher: model.publisher, - registry: model.model_registry, + registry: model.registryName, description: model.summary, - })) + })), + { + friendly_name: "OpenAI o1-mini", + name: "o1-mini", + publisher: "openai", + model_registry: "azure-openai", + description: "Smaller, faster, and 80% cheaper than o1-preview, performs well at code generation and small context operations." + }, + { + friendly_name: "OpenAI o1-preview", + name: "o1-preview", + publisher: "openai", + model_registry: "azure-openai", + description: "Focused on advanced reasoning and solving complex problems, including math and science tasks. Ideal for applications that require deep contextual understanding and agentic workflows." + }, + ] ), "<-- END OF LIST OF MODELS -->", ].join("\n"), @@ -148,13 +165,11 @@ const server = createServer(async (request, response) => { console.timeEnd("function-exec"); try { + // We should keep all optional parameters out of this call, so it can work for any model. const stream = await modelsAPI.inference.chat.completions.create({ model: functionCallRes.model, messages: functionCallRes.messages, stream: true, - stream_options: { - include_usage: false, - }, }); console.time("streaming"); diff --git a/src/models-api.ts b/src/models-api.ts index 1db5b68..7ab5d4a 100644 --- a/src/models-api.ts +++ b/src/models-api.ts @@ -2,16 +2,14 @@ import OpenAI from "openai"; // Model is the structure of a model in the model catalog. export interface Model { - id: string; name: string; - friendly_name: string; - model_version: number; + displayName: string; + version: number; publisher: string; - model_family: string; - model_registry: string; + registryName: string; license: string; - task: string; - description: string; + inferenceTasks: string[]; + description?: string; summary: string; } @@ -44,19 +42,23 @@ export class ModelsAPI { } async getModel(modelName: string): Promise { + const modelFromIndex = await this.getModelFromIndex(modelName); + const modelRes = await fetch( - "https://modelcatalog.azure-api.net/v1/model/" + modelName + `https://eastus.api.azureml.ms/asset-gallery/v1.0/${modelFromIndex.registryName}/models/${modelFromIndex.name}/version/${modelFromIndex.version}`, ); if (!modelRes.ok) { - throw new Error(`Failed to fetch ${modelName} from the model catalog.`); + throw new Error(`Failed to fetch ${modelName} details from the model catalog.`); } const model = (await modelRes.json()) as Model; return model; } async getModelSchema(modelName: string): Promise { + const modelFromIndex = await this.getModelFromIndex(modelName); + const modelSchemaRes = await fetch( - `https://modelcatalogcachev2-ebendjczf0c5dzca.b02.azurefd.net/widgets/en/Serverless/${modelName.toLowerCase()}.json` + `https://modelcatalogcachev2-ebendjczf0c5dzca.b02.azurefd.net/widgets/en/Serverless/${modelFromIndex.registryName.toLowerCase()}/${modelFromIndex.name.toLowerCase()}.json` ); if (!modelSchemaRes.ok) { throw new Error( @@ -73,14 +75,36 @@ export class ModelsAPI { } const modelsRes = await fetch( - "https://modelcatalog.azure-api.net/v1/models" + "https://eastus.api.azureml.ms/asset-gallery/v1.0/models", + { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify({ + filters: [ + { field: "freePlayground", values: ["true"], operator: "eq" }, + { field: "labels", values: ["latest"], operator: "eq" }, + ], + order: [{ field: "displayName", direction: "Asc" }], + }), + } ); if (!modelsRes.ok) { throw new Error("Failed to fetch models from the model catalog"); } - const models = (await modelsRes.json()) as Model[]; + const models = (await modelsRes.json()).summaries as Model[]; this._models = models; return models; } + + async getModelFromIndex(modelName: string): Promise { + this._models = this._models || (await this.listModels()); + const modelFromIndex = this._models.find((model) => model.name === modelName); + if (!modelFromIndex) { + throw new Error(`Failed to fetch ${modelName} from the model catalog.`); + } + return modelFromIndex; + } }