Skip to content

Commit

Permalink
Use updated models metadata API, include o1 models, and make more res…
Browse files Browse the repository at this point in the history
…ilient to wrong model names
  • Loading branch information
sgoedecke committed Sep 16, 2024
1 parent 96014d6 commit a3dca08
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 26 deletions.
14 changes: 8 additions & 6 deletions src/functions/describe-model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,11 @@ export class describeModel extends Tool {
properties: {
model: {
type: "string",
description:
'The model to describe. Looks like "registry/model-name". For example, `azureml/Phi-3-medium-128k-instruct` or `azure-openai/gpt-4o',
description: [
'The model to describe. Looks like "model-name". For example, `Phi-3-medium-128k-instruct` or `gpt-4o`.',
'The list of models is available in the context window of the chat, in the `<-- LIST OF MODELS -->` section.',
'If the model name is not found in the list of models, pick the closest matching model from the list.',
].join("\n"),
},
},
required: ["model"],
Expand All @@ -30,12 +33,11 @@ export class describeModel extends Tool {
const systemMessage = [
"The user is asking about the AI model with the following details:",
`\tModel Name: ${model.name}`,
`\tModel Version: ${model.model_version}`,
`\tModel Version: ${model.version}`,
`\tPublisher: ${model.publisher}`,
`\tModel Family: ${model.model_family}`,
`\tModel Registry: ${model.model_registry}`,
`\tModel Registry: ${model.registryName}`,
`\tLicense: ${model.license}`,
`\tTask: ${model.task}`,
`\tTask: ${model.inferenceTasks.join(", ")}`,
`\tDescription: ${model.description}`,
`\tSummary: ${model.summary}`,
"\n",
Expand Down
1 change: 1 addition & 0 deletions src/functions/execute-model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ Example Queries (IMPORTANT: Phrasing doesn't have to match):
"The name of the model to execute. It is ONLY the name of the model, not the publisher or registry.",
"For example: `gpt-4o`, or `cohere-command-r-plus`.",
"The list of models is available in the context window of the chat, in the `<-- LIST OF MODELS -->` section.",
"If the model name is not found in the list of models, pick the closest matching model from the list.",
].join("\n"),
},
instruction: {
Expand Down
2 changes: 1 addition & 1 deletion src/functions/list-models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ export class listModels extends Tool {
"That list of models is as follows:",
JSON.stringify(
models.map((model) => ({
name: model.friendly_name,
name: model.displayName,
publisher: model.publisher,
description: model.summary,
}))
Expand Down
29 changes: 22 additions & 7 deletions src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ const server = createServer(async (request, response) => {

// List of functions that are available to be called
const modelsAPI = new ModelsAPI(apiKey);

const functions = [listModels, describeModel, executeModel, recommendModel];

// Use the Copilot API to determine which function to execute
Expand All @@ -66,6 +67,7 @@ const server = createServer(async (request, response) => {
// Prepend a system message that includes the list of models, so that
// tool calls can better select the right model to use.
const models = await modelsAPI.listModels();

const toolCallMessages = [
{
role: "system" as const,
Expand All @@ -75,13 +77,28 @@ const server = createServer(async (request, response) => {
"Here is a list of some of the models available to the user:",
"<-- LIST OF MODELS -->",
JSON.stringify(
models.map((model) => ({
friendly_name: model.friendly_name,
[...models.map((model) => ({
friendly_name: model.displayName,
name: model.name,
publisher: model.publisher,
registry: model.model_registry,
registry: model.registryName,
description: model.summary,
}))
})),
{
friendly_name: "OpenAI o1-mini",
name: "o1-mini",
publisher: "openai",
model_registry: "azure-openai",
description: "Smaller, faster, and 80% cheaper than o1-preview, performs well at code generation and small context operations."
},
{
friendly_name: "OpenAI o1-preview",
name: "o1-preview",
publisher: "openai",
model_registry: "azure-openai",
description: "Focused on advanced reasoning and solving complex problems, including math and science tasks. Ideal for applications that require deep contextual understanding and agentic workflows."
},
]
),
"<-- END OF LIST OF MODELS -->",
].join("\n"),
Expand Down Expand Up @@ -148,13 +165,11 @@ const server = createServer(async (request, response) => {
console.timeEnd("function-exec");

try {
// We should keep all optional parameters out of this call, so it can work for any model.
const stream = await modelsAPI.inference.chat.completions.create({
model: functionCallRes.model,
messages: functionCallRes.messages,
stream: true,
stream_options: {
include_usage: false,
},
});

console.time("streaming");
Expand Down
48 changes: 36 additions & 12 deletions src/models-api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,14 @@ import OpenAI from "openai";

// Model is the structure of a model in the model catalog.
export interface Model {
id: string;
name: string;
friendly_name: string;
model_version: number;
displayName: string;
version: number;
publisher: string;
model_family: string;
model_registry: string;
registryName: string;
license: string;
task: string;
description: string;
inferenceTasks: string[];
description?: string;
summary: string;
}

Expand Down Expand Up @@ -44,19 +42,23 @@ export class ModelsAPI {
}

async getModel(modelName: string): Promise<Model> {
const modelFromIndex = await this.getModelFromIndex(modelName);

const modelRes = await fetch(
"https://modelcatalog.azure-api.net/v1/model/" + modelName
`https://eastus.api.azureml.ms/asset-gallery/v1.0/${modelFromIndex.registryName}/models/${modelFromIndex.name}/version/${modelFromIndex.version}`,
);
if (!modelRes.ok) {
throw new Error(`Failed to fetch ${modelName} from the model catalog.`);
throw new Error(`Failed to fetch ${modelName} details from the model catalog.`);
}
const model = (await modelRes.json()) as Model;
return model;
}

async getModelSchema(modelName: string): Promise<ModelSchema> {
const modelFromIndex = await this.getModelFromIndex(modelName);

const modelSchemaRes = await fetch(
`https://modelcatalogcachev2-ebendjczf0c5dzca.b02.azurefd.net/widgets/en/Serverless/${modelName.toLowerCase()}.json`
`https://modelcatalogcachev2-ebendjczf0c5dzca.b02.azurefd.net/widgets/en/Serverless/${modelFromIndex.registryName.toLowerCase()}/${modelFromIndex.name.toLowerCase()}.json`
);
if (!modelSchemaRes.ok) {
throw new Error(
Expand All @@ -73,14 +75,36 @@ export class ModelsAPI {
}

const modelsRes = await fetch(
"https://modelcatalog.azure-api.net/v1/models"
"https://eastus.api.azureml.ms/asset-gallery/v1.0/models",
{
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
filters: [
{ field: "freePlayground", values: ["true"], operator: "eq" },
{ field: "labels", values: ["latest"], operator: "eq" },
],
order: [{ field: "displayName", direction: "Asc" }],
}),
}
);
if (!modelsRes.ok) {
throw new Error("Failed to fetch models from the model catalog");
}

const models = (await modelsRes.json()) as Model[];
const models = (await modelsRes.json()).summaries as Model[];
this._models = models;
return models;
}

/**
 * Looks up a model's entry in the cached model index by its `name`.
 *
 * The index is populated through `listModels()` the first time it is needed.
 * An exact match on `name` always wins. If there is none, the lookup falls
 * back to a case-insensitive comparison, so a caller-supplied name that
 * differs only in casing (e.g. `phi-3-medium-128k-instruct`) still resolves
 * instead of throwing.
 *
 * @param modelName - The model's `name`, for example `gpt-4o`.
 * @returns The matching entry from the model index.
 * @throws Error when no entry matches `modelName`, even ignoring case.
 */
async getModelFromIndex(modelName: string): Promise<Model> {
  this._models = this._models || (await this.listModels());
  const models = this._models;

  // Prefer the exact name so existing callers resolve exactly as before.
  const exactMatch = models.find((model) => model.name === modelName);
  if (exactMatch) {
    return exactMatch;
  }

  // Fallback: tolerate casing differences in names produced by the tool-calling model.
  const wanted = modelName.toLowerCase();
  const modelFromIndex = models.find(
    (model) => model.name.toLowerCase() === wanted
  );
  if (!modelFromIndex) {
    throw new Error(`Failed to fetch ${modelName} from the model catalog.`);
  }
  return modelFromIndex;
}
}

0 comments on commit a3dca08

Please sign in to comment.