From 2b7a68067820c08e9204342c35604d4a8b501c77 Mon Sep 17 00:00:00 2001 From: Deyaaeldeen Almahallawi Date: Thu, 21 Sep 2023 10:42:54 -0700 Subject: [PATCH] [OpenAI] Rename prompt_annotations to prompt_filter_results (#27186) ### Packages impacted by this PR @azure/openai ### Issues associated with this PR https://portal.microsofticm.com/imp/v3/incidents/details/424851016/home ### Describe the problem that is addressed by this PR The API renamed this property last minute ### What are the possible designs available to address the problem? If there are more than one possible design, why was the one in this PR chosen? N/A ### Are there test cases added in this PR? _(If not, why?)_ Yes ### Provide a list of related PRs _(if any)_ N/A ### Command used to generate this PR:**_(Applicable only to SDK release request PRs)_ ### Checklists - [x] Added impacted package name to the issue description - [ ] Does this PR needs any fixes in the SDK Generator?** _(If so, create an Issue in the [Autorest/typescript](https://github.com/Azure/autorest.typescript) repository and link it here)_ - [ ] Added a changelog (if necessary) --- .../customizations/api/deserializers.ts | 52 ++-- sdk/openai/openai/src/api/deserializers.ts | 46 ++-- .../openai/test/internal/deserializer.spec.ts | 251 +++++++++++++++++- 3 files changed, 298 insertions(+), 51 deletions(-) diff --git a/sdk/openai/openai/sources/customizations/api/deserializers.ts b/sdk/openai/openai/sources/customizations/api/deserializers.ts index 85cad9b8af45..6fcb3a4641c3 100644 --- a/sdk/openai/openai/sources/customizations/api/deserializers.ts +++ b/sdk/openai/openai/sources/customizations/api/deserializers.ts @@ -1,7 +1,12 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT license. -import { ChatMessage, ChatRole, Completions } from "../../generated/src/models/models.js"; +import { + ChatMessage, + ChatRole, + Completions, + PromptFilterResult, +} from "../../generated/src/models/models.js"; import { ChatChoiceOutput, ChatMessageOutput, @@ -11,22 +16,30 @@ import { } from "../../generated/src/rest/outputModels.js"; import { ChatCompletions } from "../models/models.js"; import { ContentFilterResults } from "./models.js"; + +function getPromptFilterResult(body: Record): { + promptFilterResults?: PromptFilterResult[]; +} { + const res = body["prompt_annotations"] ?? body["prompt_filter_results"]; + return !res + ? {} + : { + promptFilterResults: res.map((p: PromptFilterResultOutput) => ({ + promptIndex: p["prompt_index"], + ...(!p.content_filter_results + ? {} + : { + contentFilterResults: deserializeContentFilter(p.content_filter_results), + }), + })), + }; +} + export function getCompletionsResult(body: Record): Omit { return { id: body["id"], created: new Date(body["created"]), - ...(!body["prompt_annotations"] - ? {} - : { - promptFilterResults: body["prompt_annotations"].map((p: PromptFilterResultOutput) => ({ - promptIndex: p["prompt_index"], - ...(!p.content_filter_results - ? {} - : { - contentFilterResults: deserializeContentFilter(p.content_filter_results), - }), - })), - }), + ...getPromptFilterResult(body), choices: (body["choices"] ?? []).map((p: ChoiceOutput) => ({ text: p["text"], index: p["index"], @@ -62,18 +75,7 @@ export function getChatCompletionsResult(body: Record): ChatComplet ? {} : { contentFilterResults: deserializeContentFilter(p.content_filter_results) }), })), - ...(!body["prompt_annotations"] - ? {} - : { - promptFilterResults: body["prompt_annotations"].map((p: PromptFilterResultOutput) => ({ - promptIndex: p["prompt_index"], - ...(!p.content_filter_results - ? {} - : { - contentFilterResults: deserializeContentFilter(p.content_filter_results), - }), - })), - }), + ...getPromptFilterResult(body), ...(!body["usage"] ? {} : { diff --git a/sdk/openai/openai/src/api/deserializers.ts b/sdk/openai/openai/src/api/deserializers.ts index 2c5736f39ef1..6128f7195811 100644 --- a/sdk/openai/openai/src/api/deserializers.ts +++ b/sdk/openai/openai/src/api/deserializers.ts @@ -9,7 +9,7 @@ * If you need to make changes, please do so in the original source file, \{project-root\}/sources/custom */ -import { ChatMessage, ChatRole, Completions } from "../models/models.js"; +import { ChatMessage, ChatRole, Completions, PromptFilterResult } from "../models/models.js"; import { ChatChoiceOutput, ChatMessageOutput, @@ -20,22 +20,29 @@ import { import { ChatCompletions } from "../models/models.js"; import { ContentFilterResults } from "./models.js"; +function getPromptFilterResult(body: Record): { + promptFilterResults?: PromptFilterResult[]; +} { + const res = body["prompt_annotations"] ?? body["prompt_filter_results"]; + return !res + ? {} + : { + promptFilterResults: res.map((p: PromptFilterResultOutput) => ({ + promptIndex: p["prompt_index"], + ...(!p.content_filter_results + ? {} + : { + contentFilterResults: deserializeContentFilter(p.content_filter_results), + }), + })), + }; +} + export function getCompletionsResult(body: Record): Omit { return { id: body["id"], created: new Date(body["created"]), - ...(!body["prompt_annotations"] - ? {} - : { - promptFilterResults: body["prompt_annotations"].map((p: PromptFilterResultOutput) => ({ - promptIndex: p["prompt_index"], - ...(!p.content_filter_results - ? {} - : { - contentFilterResults: deserializeContentFilter(p.content_filter_results), - }), - })), - }), + ...getPromptFilterResult(body), choices: (body["choices"] ?? []).map((p: ChoiceOutput) => ({ text: p["text"], index: p["index"], @@ -71,18 +78,7 @@ export function getChatCompletionsResult(body: Record): ChatComplet ? {} : { contentFilterResults: deserializeContentFilter(p.content_filter_results) }), })), - ...(!body["prompt_annotations"] - ? {} - : { - promptFilterResults: body["prompt_annotations"].map((p: PromptFilterResultOutput) => ({ - promptIndex: p["prompt_index"], - ...(!p.content_filter_results - ? {} - : { - contentFilterResults: deserializeContentFilter(p.content_filter_results), - }), - })), - }), + ...getPromptFilterResult(body), ...(!body["usage"] ? {} : { diff --git a/sdk/openai/openai/test/internal/deserializer.spec.ts b/sdk/openai/openai/test/internal/deserializer.spec.ts index 8372e6e90a97..4fd92b4738ee 100644 --- a/sdk/openai/openai/test/internal/deserializer.spec.ts +++ b/sdk/openai/openai/test/internal/deserializer.spec.ts @@ -7,6 +7,136 @@ import { getChatCompletionsResult, getCompletionsResult } from "../../src/api/de describe("deserializers", () => { describe("getCompletionsResult", () => { it("should deserialize completions response", () => { + const body = { + id: "123", + created: "2022-01-01T00:00:00.000Z", + prompt_filter_results: [ + { + prompt_index: 0, + content_filter_results: { + sexual: { + severity: "low", + filtered: false, + }, + violence: { + severity: "low", + filtered: false, + }, + hate: { + severity: "low", + filtered: false, + }, + self_harm: { + severity: "low", + filtered: false, + }, + }, + }, + ], + choices: [ + { + text: "Hello", + index: 0, + content_filter_results: { + sexual: { + severity: "low", + filtered: false, + }, + violence: { + severity: "low", + filtered: false, + }, + hate: { + severity: "low", + filtered: false, + }, + self_harm: { + severity: "low", + filtered: false, + }, + }, + logprobs: { + tokens: ["Hello", "there", "!"], + token_logprobs: [-0.1, -0.2, -0.3], + top_logprobs: [ + { + "1": -0.1, + }, + ], + text_offset: [0, 6, 11], + }, + finish_reason: "stop", + }, + ], + }; + + const result = getCompletionsResult(body); + + assert.deepStrictEqual(result, { + id: "123", + created: new Date("2022-01-01T00:00:00.000Z"), + promptFilterResults: [ + { + promptIndex: 0, + contentFilterResults: { + sexual: { + severity: "low", + filtered: false, + }, + violence: { + severity: "low", + filtered: false, + }, + hate: { + severity: "low", + filtered: false, + }, + selfHarm: { + severity: "low", + filtered: false, + }, + }, + }, + ], + choices: [ + { + text: "Hello", + index: 0, + contentFilterResults: { + sexual: { + severity: "low", + filtered: false, + }, + violence: { + severity: "low", + filtered: false, + }, + hate: { + severity: "low", + filtered: false, + }, + selfHarm: { + severity: "low", + filtered: false, + }, + }, + logprobs: { + tokens: ["Hello", "there", "!"], + tokenLogprobs: [-0.1, -0.2, -0.3], + topLogprobs: [ + { + "1": -0.1, + }, + ], + textOffset: [0, 6, 11], + }, + finishReason: "stop", + }, + ], + }); + }); + + it("should deserialize completions response with old name for prompt filter results", () => { const body = { id: "123", created: "2022-01-01T00:00:00.000Z", @@ -136,8 +266,127 @@ describe("deserializers", () => { }); }); }); + describe("getChatCompletionsResult", () => { it("should deserialize chat completions result", () => { + const body = { + id: "123", + created: "2022-01-01T00:00:00.000Z", + prompt_filter_results: [ + { + prompt_index: 0, + content_filter_results: { + sexual: { + severity: "low", + filtered: false, + }, + violence: { + severity: "low", + filtered: false, + }, + hate: { + severity: "low", + filtered: false, + }, + self_harm: { + severity: "low", + filtered: false, + }, + }, + }, + ], + choices: [ + { + message: { + role: "bot", + content: "Hello", + }, + index: 0, + finish_reason: "stop", + content_filter_results: { + sexual: { + severity: "low", + filtered: false, + }, + violence: { + severity: "low", + filtered: false, + }, + hate: { + severity: "low", + filtered: false, + }, + self_harm: { + severity: "low", + filtered: false, + }, + }, + }, + ], + usage: { completion_tokens: 135, prompt_tokens: 68, total_tokens: 203 }, + }; + + const result = getChatCompletionsResult(body); + + assert.deepStrictEqual(result, { + id: "123", + created: new Date("2022-01-01T00:00:00.000Z"), + promptFilterResults: [ + { + promptIndex: 0, + contentFilterResults: { + sexual: { + severity: "low", + filtered: false, + }, + violence: { + severity: "low", + filtered: false, + }, + hate: { + severity: "low", + filtered: false, + }, + selfHarm: { + severity: "low", + filtered: false, + }, + }, + }, + ], + choices: [ + { + message: { + role: "bot", + content: "Hello", + }, + index: 0, + finishReason: "stop", + contentFilterResults: { + sexual: { + severity: "low", + filtered: false, + }, + violence: { + severity: "low", + filtered: false, + }, + hate: { + severity: "low", + filtered: false, + }, + selfHarm: { + severity: "low", + filtered: false, + }, + }, + }, + ], + usage: { completionTokens: 135, promptTokens: 68, totalTokens: 203 }, + }); + }); + + it("should deserialize chat completions result with old name for prompt filter results", () => { const body = { id: "123", created: "2022-01-01T00:00:00.000Z", @@ -259,7 +508,7 @@ describe("deserializers", () => { const body = { id: "123", created: "2022-01-01T00:00:00.000Z", - prompt_annotations: [ + prompt_filter_results: [ { prompt_index: 0, content_filter_results: {