diff --git a/sdk/openai/openai/sources/customizations/api/deserializers.ts b/sdk/openai/openai/sources/customizations/api/deserializers.ts index 85cad9b8af45..6fcb3a4641c3 100644 --- a/sdk/openai/openai/sources/customizations/api/deserializers.ts +++ b/sdk/openai/openai/sources/customizations/api/deserializers.ts @@ -1,7 +1,12 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT license. -import { ChatMessage, ChatRole, Completions } from "../../generated/src/models/models.js"; +import { + ChatMessage, + ChatRole, + Completions, + PromptFilterResult, +} from "../../generated/src/models/models.js"; import { ChatChoiceOutput, ChatMessageOutput, @@ -11,22 +16,30 @@ import { } from "../../generated/src/rest/outputModels.js"; import { ChatCompletions } from "../models/models.js"; import { ContentFilterResults } from "./models.js"; + +function getPromptFilterResult(body: Record): { + promptFilterResults?: PromptFilterResult[]; +} { + const res = body["prompt_annotations"] ?? body["prompt_filter_results"]; + return !res + ? {} + : { + promptFilterResults: res.map((p: PromptFilterResultOutput) => ({ + promptIndex: p["prompt_index"], + ...(!p.content_filter_results + ? {} + : { + contentFilterResults: deserializeContentFilter(p.content_filter_results), + }), + })), + }; +} + export function getCompletionsResult(body: Record): Omit { return { id: body["id"], created: new Date(body["created"]), - ...(!body["prompt_annotations"] - ? {} - : { - promptFilterResults: body["prompt_annotations"].map((p: PromptFilterResultOutput) => ({ - promptIndex: p["prompt_index"], - ...(!p.content_filter_results - ? {} - : { - contentFilterResults: deserializeContentFilter(p.content_filter_results), - }), - })), - }), + ...getPromptFilterResult(body), choices: (body["choices"] ?? []).map((p: ChoiceOutput) => ({ text: p["text"], index: p["index"], @@ -62,18 +75,7 @@ export function getChatCompletionsResult(body: Record): ChatComplet ? {} : { contentFilterResults: deserializeContentFilter(p.content_filter_results) }), })), - ...(!body["prompt_annotations"] - ? {} - : { - promptFilterResults: body["prompt_annotations"].map((p: PromptFilterResultOutput) => ({ - promptIndex: p["prompt_index"], - ...(!p.content_filter_results - ? {} - : { - contentFilterResults: deserializeContentFilter(p.content_filter_results), - }), - })), - }), + ...getPromptFilterResult(body), ...(!body["usage"] ? {} : { diff --git a/sdk/openai/openai/src/api/deserializers.ts b/sdk/openai/openai/src/api/deserializers.ts index 2c5736f39ef1..6128f7195811 100644 --- a/sdk/openai/openai/src/api/deserializers.ts +++ b/sdk/openai/openai/src/api/deserializers.ts @@ -9,7 +9,7 @@ * If you need to make changes, please do so in the original source file, \{project-root\}/sources/custom */ -import { ChatMessage, ChatRole, Completions } from "../models/models.js"; +import { ChatMessage, ChatRole, Completions, PromptFilterResult } from "../models/models.js"; import { ChatChoiceOutput, ChatMessageOutput, @@ -20,22 +20,29 @@ import { import { ChatCompletions } from "../models/models.js"; import { ContentFilterResults } from "./models.js"; +function getPromptFilterResult(body: Record): { + promptFilterResults?: PromptFilterResult[]; +} { + const res = body["prompt_annotations"] ?? body["prompt_filter_results"]; + return !res + ? {} + : { + promptFilterResults: res.map((p: PromptFilterResultOutput) => ({ + promptIndex: p["prompt_index"], + ...(!p.content_filter_results + ? {} + : { + contentFilterResults: deserializeContentFilter(p.content_filter_results), + }), + })), + }; +} + export function getCompletionsResult(body: Record): Omit { return { id: body["id"], created: new Date(body["created"]), - ...(!body["prompt_annotations"] - ? {} - : { - promptFilterResults: body["prompt_annotations"].map((p: PromptFilterResultOutput) => ({ - promptIndex: p["prompt_index"], - ...(!p.content_filter_results - ? {} - : { - contentFilterResults: deserializeContentFilter(p.content_filter_results), - }), - })), - }), + ...getPromptFilterResult(body), choices: (body["choices"] ?? []).map((p: ChoiceOutput) => ({ text: p["text"], index: p["index"], @@ -71,18 +78,7 @@ export function getChatCompletionsResult(body: Record): ChatComplet ? {} : { contentFilterResults: deserializeContentFilter(p.content_filter_results) }), })), - ...(!body["prompt_annotations"] - ? {} - : { - promptFilterResults: body["prompt_annotations"].map((p: PromptFilterResultOutput) => ({ - promptIndex: p["prompt_index"], - ...(!p.content_filter_results - ? {} - : { - contentFilterResults: deserializeContentFilter(p.content_filter_results), - }), - })), - }), + ...getPromptFilterResult(body), ...(!body["usage"] ? {} : { diff --git a/sdk/openai/openai/test/internal/deserializer.spec.ts b/sdk/openai/openai/test/internal/deserializer.spec.ts index 8372e6e90a97..4fd92b4738ee 100644 --- a/sdk/openai/openai/test/internal/deserializer.spec.ts +++ b/sdk/openai/openai/test/internal/deserializer.spec.ts @@ -7,6 +7,136 @@ import { getChatCompletionsResult, getCompletionsResult } from "../../src/api/de describe("deserializers", () => { describe("getCompletionsResult", () => { it("should deserialize completions response", () => { + const body = { + id: "123", + created: "2022-01-01T00:00:00.000Z", + prompt_filter_results: [ + { + prompt_index: 0, + content_filter_results: { + sexual: { + severity: "low", + filtered: false, + }, + violence: { + severity: "low", + filtered: false, + }, + hate: { + severity: "low", + filtered: false, + }, + self_harm: { + severity: "low", + filtered: false, + }, + }, + }, + ], + choices: [ + { + text: "Hello", + index: 0, + content_filter_results: { + sexual: { + severity: "low", + filtered: false, + }, + violence: { + severity: "low", + filtered: false, + }, + hate: { + severity: "low", + filtered: false, + }, + self_harm: { + severity: "low", + filtered: false, + }, + }, + logprobs: { + tokens: ["Hello", "there", "!"], + token_logprobs: [-0.1, -0.2, -0.3], + top_logprobs: [ + { + "1": -0.1, + }, + ], + text_offset: [0, 6, 11], + }, + finish_reason: "stop", + }, + ], + }; + + const result = getCompletionsResult(body); + + assert.deepStrictEqual(result, { + id: "123", + created: new Date("2022-01-01T00:00:00.000Z"), + promptFilterResults: [ + { + promptIndex: 0, + contentFilterResults: { + sexual: { + severity: "low", + filtered: false, + }, + violence: { + severity: "low", + filtered: false, + }, + hate: { + severity: "low", + filtered: false, + }, + selfHarm: { + severity: "low", + filtered: false, + }, + }, + }, + ], + choices: [ + { + text: "Hello", + index: 0, + contentFilterResults: { + sexual: { + severity: "low", + filtered: false, + }, + violence: { + severity: "low", + filtered: false, + }, + hate: { + severity: "low", + filtered: false, + }, + selfHarm: { + severity: "low", + filtered: false, + }, + }, + logprobs: { + tokens: ["Hello", "there", "!"], + tokenLogprobs: [-0.1, -0.2, -0.3], + topLogprobs: [ + { + "1": -0.1, + }, + ], + textOffset: [0, 6, 11], + }, + finishReason: "stop", + }, + ], + }); + }); + + it("should deserialize completions response with old name for prompt filter results", () => { const body = { id: "123", created: "2022-01-01T00:00:00.000Z", @@ -136,8 +266,127 @@ describe("deserializers", () => { }); }); }); + describe("getChatCompletionsResult", () => { it("should deserialize chat completions result", () => { + const body = { + id: "123", + created: "2022-01-01T00:00:00.000Z", + prompt_filter_results: [ + { + prompt_index: 0, + content_filter_results: { + sexual: { + severity: "low", + filtered: false, + }, + violence: { + severity: "low", + filtered: false, + }, + hate: { + severity: "low", + filtered: false, + }, + self_harm: { + severity: "low", + filtered: false, + }, + }, + }, + ], + choices: [ + { + message: { + role: "bot", + content: "Hello", + }, + index: 0, + finish_reason: "stop", + content_filter_results: { + sexual: { + severity: "low", + filtered: false, + }, + violence: { + severity: "low", + filtered: false, + }, + hate: { + severity: "low", + filtered: false, + }, + self_harm: { + severity: "low", + filtered: false, + }, + }, + }, + ], + usage: { completion_tokens: 135, prompt_tokens: 68, total_tokens: 203 }, + }; + + const result = getChatCompletionsResult(body); + + assert.deepStrictEqual(result, { + id: "123", + created: new Date("2022-01-01T00:00:00.000Z"), + promptFilterResults: [ + { + promptIndex: 0, + contentFilterResults: { + sexual: { + severity: "low", + filtered: false, + }, + violence: { + severity: "low", + filtered: false, + }, + hate: { + severity: "low", + filtered: false, + }, + selfHarm: { + severity: "low", + filtered: false, + }, + }, + }, + ], + choices: [ + { + message: { + role: "bot", + content: "Hello", + }, + index: 0, + finishReason: "stop", + contentFilterResults: { + sexual: { + severity: "low", + filtered: false, + }, + violence: { + severity: "low", + filtered: false, + }, + hate: { + severity: "low", + filtered: false, + }, + selfHarm: { + severity: "low", + filtered: false, + }, + }, + }, + ], + usage: { completionTokens: 135, promptTokens: 68, totalTokens: 203 }, + }); + }); + + it("should deserialize chat completions result with old name for prompt filter results", () => { const body = { id: "123", created: "2022-01-01T00:00:00.000Z", @@ -259,7 +508,7 @@ describe("deserializers", () => { const body = { id: "123", created: "2022-01-01T00:00:00.000Z", - prompt_annotations: [ + prompt_filter_results: [ { prompt_index: 0, content_filter_results: {