From c22b9828a6551ea5205a025f1c5ace4046bdc66c Mon Sep 17 00:00:00 2001 From: Pierre Gayvallet Date: Fri, 24 Jan 2025 13:50:19 +0100 Subject: [PATCH] [inference] add support for `auto` function calling mode (#208144) ## Summary Fix https://github.com/elastic/kibana/issues/208143 Add a new value for the `functionCalling` parameter, `auto`, which is the new default. When `functionCalling=auto`, the system will detect if the underlying model/provider supports native function calling, and otherwise automatically fallback to simulated function calling. --- .../inference-common/src/chat_complete/api.ts | 9 +-- .../inference-common/src/output/api.ts | 2 +- .../inference/inference_adapter.test.mocks.ts | 16 +++++ .../inference/inference_adapter.test.ts | 50 +++++++++++---- .../adapters/inference/inference_adapter.ts | 12 ++-- .../openai/openai_adapter.test.mocks.ts | 16 +++++ .../adapters/openai/openai_adapter.test.ts | 35 ++++++++-- .../adapters/openai/openai_adapter.ts | 12 ++-- .../chat_complete/adapters/openai/types.ts | 8 +++ .../utils/function_calling_support.test.ts | 64 +++++++++++++++++++ .../utils/function_calling_support.ts | 26 ++++++++ .../server/chat_complete/utils/index.ts | 1 + .../inference/server/routes/chat_complete.ts | 2 +- 13 files changed, 221 insertions(+), 32 deletions(-) create mode 100644 x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/inference/inference_adapter.test.mocks.ts create mode 100644 x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/openai/openai_adapter.test.mocks.ts create mode 100644 x-pack/platform/plugins/shared/inference/server/chat_complete/utils/function_calling_support.test.ts create mode 100644 x-pack/platform/plugins/shared/inference/server/chat_complete/utils/function_calling_support.ts diff --git a/x-pack/platform/packages/shared/ai-infra/inference-common/src/chat_complete/api.ts b/x-pack/platform/packages/shared/ai-infra/inference-common/src/chat_complete/api.ts index 15596c32066e1..155de9b286c9b 100644 --- a/x-pack/platform/packages/shared/ai-infra/inference-common/src/chat_complete/api.ts +++ b/x-pack/platform/packages/shared/ai-infra/inference-common/src/chat_complete/api.ts @@ -102,7 +102,7 @@ export type ChatCompleteOptions< */ modelName?: string; /** - * Function calling mode, defaults to "native". + * Function calling mode, defaults to "auto". */ functionCalling?: FunctionCallingMode; /** @@ -152,7 +152,8 @@ export interface ChatCompleteResponse { + const actual = jest.requireActual('../../utils/function_calling_support'); + return { + ...actual, + isNativeFunctionCallingSupported: isNativeFunctionCallingSupportedMock, + }; +}); diff --git a/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/inference/inference_adapter.test.ts b/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/inference/inference_adapter.test.ts index 68733c8f4dae2..65558d68bc90c 100644 --- a/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/inference/inference_adapter.test.ts +++ b/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/inference/inference_adapter.test.ts @@ -5,13 +5,14 @@ * 2.0. 
*/ +import { isNativeFunctionCallingSupportedMock } from './inference_adapter.test.mocks'; import OpenAI from 'openai'; import { v4 } from 'uuid'; import { PassThrough } from 'stream'; import { lastValueFrom, Subject, toArray, filter } from 'rxjs'; -import type { Logger } from '@kbn/logging'; import { loggerMock } from '@kbn/logging-mocks'; import { + ToolChoiceType, ChatCompletionEventType, MessageRole, isChatCompletionChunkEvent, @@ -48,21 +49,23 @@ function createOpenAIChunk({ describe('inferenceAdapter', () => { const executorMock = { + getConnector: jest.fn(), invoke: jest.fn(), - } as InferenceExecutor & { invoke: jest.MockedFn }; + } as InferenceExecutor & { + invoke: jest.MockedFn; + getConnector: jest.MockedFn; + }; - const logger = { - debug: jest.fn(), - error: jest.fn(), - } as unknown as Logger; + const logger = loggerMock.create(); beforeEach(() => { executorMock.invoke.mockReset(); + isNativeFunctionCallingSupportedMock.mockReset().mockReturnValue(true); }); const defaultArgs = { executor: executorMock, - logger: loggerMock.create(), + logger, }; describe('when creating the request', () => { @@ -232,6 +235,25 @@ describe('inferenceAdapter', () => { ]); }); + it('propagates the temperature parameter', () => { + inferenceAdapter.chatComplete({ + logger, + executor: executorMock, + messages: [{ role: MessageRole.User, content: 'question' }], + temperature: 0.4, + }); + + expect(executorMock.invoke).toHaveBeenCalledTimes(1); + expect(executorMock.invoke).toHaveBeenCalledWith({ + subAction: 'unified_completion_stream', + subActionParams: expect.objectContaining({ + body: expect.objectContaining({ + temperature: 0.4, + }), + }), + }); + }); + it('propagates the abort signal when provided', () => { const abortController = new AbortController(); @@ -251,20 +273,26 @@ describe('inferenceAdapter', () => { }); }); - it('propagates the temperature parameter', () => { + it('uses the right value for functionCalling=auto', () => { + isNativeFunctionCallingSupportedMock.mockReturnValue(false); + inferenceAdapter.chatComplete({ logger, executor: executorMock, messages: [{ role: MessageRole.User, content: 'question' }], - temperature: 0.4, + tools: { + foo: { description: 'my tool' }, + }, + toolChoice: ToolChoiceType.auto, + functionCalling: 'auto', }); expect(executorMock.invoke).toHaveBeenCalledTimes(1); expect(executorMock.invoke).toHaveBeenCalledWith({ subAction: 'unified_completion_stream', subActionParams: expect.objectContaining({ - body: expect.objectContaining({ - temperature: 0.4, + body: expect.not.objectContaining({ + tools: expect.any(Array), }), }), }); diff --git a/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/inference/inference_adapter.ts b/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/inference/inference_adapter.ts index 168b0f9cf2fb4..e220e3ebf5f8e 100644 --- a/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/inference/inference_adapter.ts +++ b/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/inference/inference_adapter.ts @@ -15,6 +15,7 @@ import { parseInlineFunctionCalls, wrapWithSimulatedFunctionCalling, } from '../../simulated_function_calling'; +import { isNativeFunctionCallingSupported } from '../../utils/function_calling_support'; import { toolsToOpenAI, toolChoiceToOpenAI, @@ -30,16 +31,19 @@ export const inferenceAdapter: InferenceConnectorAdapter = { messages, toolChoice, tools, - functionCalling, + functionCalling = 'auto', temperature = 0, modelName, logger, 
abortSignal, }) => { - const simulatedFunctionCalling = functionCalling === 'simulated'; + const useSimulatedFunctionCalling = + functionCalling === 'auto' + ? !isNativeFunctionCallingSupported(executor.getConnector()) + : functionCalling === 'simulated'; let request: Omit & { model?: string }; - if (simulatedFunctionCalling) { + if (useSimulatedFunctionCalling) { const wrapped = wrapWithSimulatedFunctionCalling({ system, messages, @@ -87,7 +91,7 @@ export const inferenceAdapter: InferenceConnectorAdapter = { }), processOpenAIStream(), emitTokenCountEstimateIfMissing({ request }), - simulatedFunctionCalling ? parseInlineFunctionCalls({ logger }) : identity + useSimulatedFunctionCalling ? parseInlineFunctionCalls({ logger }) : identity ); }, }; diff --git a/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/openai/openai_adapter.test.mocks.ts b/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/openai/openai_adapter.test.mocks.ts new file mode 100644 index 0000000000000..625a06ecf3515 --- /dev/null +++ b/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/openai/openai_adapter.test.mocks.ts @@ -0,0 +1,16 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +export const isNativeFunctionCallingSupportedMock = jest.fn(); + +jest.doMock('../../utils/function_calling_support', () => { + const actual = jest.requireActual('../../utils/function_calling_support'); + return { + ...actual, + isNativeFunctionCallingSupported: isNativeFunctionCallingSupportedMock, + }; +}); diff --git a/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/openai/openai_adapter.test.ts b/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/openai/openai_adapter.test.ts index c1ef52a3bc241..c620f1d01bf7f 100644 --- a/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/openai/openai_adapter.test.ts +++ b/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/openai/openai_adapter.test.ts @@ -5,14 +5,15 @@ * 2.0. 
*/ +import { isNativeFunctionCallingSupportedMock } from './openai_adapter.test.mocks'; import OpenAI from 'openai'; import { v4 } from 'uuid'; import { PassThrough } from 'stream'; import { pick } from 'lodash'; import { lastValueFrom, Subject, toArray, filter } from 'rxjs'; -import type { Logger } from '@kbn/logging'; import { loggerMock } from '@kbn/logging-mocks'; import { + ToolChoiceType, ChatCompletionEventType, isChatCompletionChunkEvent, MessageRole, @@ -48,21 +49,23 @@ function createOpenAIChunk({ describe('openAIAdapter', () => { const executorMock = { + getConnector: jest.fn(), invoke: jest.fn(), - } as InferenceExecutor & { invoke: jest.MockedFn }; + } as InferenceExecutor & { + invoke: jest.MockedFn; + getConnector: jest.MockedFn; + }; - const logger = { - debug: jest.fn(), - error: jest.fn(), - } as unknown as Logger; + const logger = loggerMock.create(); beforeEach(() => { executorMock.invoke.mockReset(); + isNativeFunctionCallingSupportedMock.mockReset().mockReturnValue(true); }); const defaultArgs = { executor: executorMock, - logger: loggerMock.create(), + logger, }; describe('when creating the request', () => { @@ -359,6 +362,24 @@ describe('openAIAdapter', () => { }); }); + it('uses the right value for functionCalling=auto', () => { + isNativeFunctionCallingSupportedMock.mockReturnValue(false); + + openAIAdapter.chatComplete({ + logger, + executor: executorMock, + messages: [{ role: MessageRole.User, content: 'question' }], + tools: { + foo: { description: 'my tool' }, + }, + toolChoice: ToolChoiceType.auto, + functionCalling: 'auto', + }); + + expect(executorMock.invoke).toHaveBeenCalledTimes(1); + expect(getRequest().body.tools).toBeUndefined(); + }); + it('propagates the temperature parameter', () => { openAIAdapter.chatComplete({ logger, diff --git a/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/openai/openai_adapter.ts b/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/openai/openai_adapter.ts index a8abec5f43204..83b1a47131bbd 100644 --- a/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/openai/openai_adapter.ts +++ b/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/openai/openai_adapter.ts @@ -14,6 +14,7 @@ import { parseInlineFunctionCalls, wrapWithSimulatedFunctionCalling, } from '../../simulated_function_calling'; +import { isNativeFunctionCallingSupported } from '../../utils/function_calling_support'; import type { OpenAIRequest } from './types'; import { messagesToOpenAI, toolsToOpenAI, toolChoiceToOpenAI } from './to_openai'; import { processOpenAIStream } from './process_openai_stream'; @@ -27,15 +28,18 @@ export const openAIAdapter: InferenceConnectorAdapter = { toolChoice, tools, temperature = 0, - functionCalling, + functionCalling = 'auto', modelName, logger, abortSignal, }) => { - const simulatedFunctionCalling = functionCalling === 'simulated'; + const useSimulatedFunctionCalling = + functionCalling === 'auto' + ? !isNativeFunctionCallingSupported(executor.getConnector()) + : functionCalling === 'simulated'; let request: OpenAIRequest; - if (simulatedFunctionCalling) { + if (useSimulatedFunctionCalling) { const wrapped = wrapWithSimulatedFunctionCalling({ system, messages, @@ -86,7 +90,7 @@ export const openAIAdapter: InferenceConnectorAdapter = { }), processOpenAIStream(), emitTokenCountEstimateIfMissing({ request }), - simulatedFunctionCalling ? parseInlineFunctionCalls({ logger }) : identity + useSimulatedFunctionCalling ? 
parseInlineFunctionCalls({ logger }) : identity ); }, }; diff --git a/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/openai/types.ts b/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/openai/types.ts index 15e5dbc3684b1..956c0366ee1fd 100644 --- a/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/openai/types.ts +++ b/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/openai/types.ts @@ -8,3 +8,11 @@ import type OpenAI from 'openai'; export type OpenAIRequest = Omit & { model?: string }; + +// duplicated from x-pack/platform/plugins/shared/stack_connectors/common/openai/constants.ts +// because depending on stack_connectors from the inference plugin creates a cyclic dependency... +export enum OpenAiProviderType { + OpenAi = 'OpenAI', + AzureAi = 'Azure OpenAI', + Other = 'Other', +} diff --git a/x-pack/platform/plugins/shared/inference/server/chat_complete/utils/function_calling_support.test.ts b/x-pack/platform/plugins/shared/inference/server/chat_complete/utils/function_calling_support.test.ts new file mode 100644 index 0000000000000..a3723309a242c --- /dev/null +++ b/x-pack/platform/plugins/shared/inference/server/chat_complete/utils/function_calling_support.test.ts @@ -0,0 +1,64 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { OpenAiProviderType } from '../adapters/openai/types'; +import { InferenceConnector, InferenceConnectorType } from '@kbn/inference-common'; +import { isNativeFunctionCallingSupported } from './function_calling_support'; + +const createConnector = ( + parts: Partial & Pick +): InferenceConnector => { + return { + connectorId: 'connector-id', + name: 'my connector', + config: {}, + ...parts, + }; +}; + +describe('isNativeFunctionCallingSupported', () => { + it('returns true for gemini connector', () => { + const connector = createConnector({ type: InferenceConnectorType.Gemini }); + expect(isNativeFunctionCallingSupported(connector)).toBe(true); + }); + + it('returns true for bedrock connector', () => { + const connector = createConnector({ type: InferenceConnectorType.Bedrock }); + expect(isNativeFunctionCallingSupported(connector)).toBe(true); + }); + + it('returns true for inference connector', () => { + const connector = createConnector({ type: InferenceConnectorType.Inference }); + expect(isNativeFunctionCallingSupported(connector)).toBe(true); + }); + + describe('openAI connector', () => { + it('returns true for "OpenAI" provider', () => { + const connector = createConnector({ + type: InferenceConnectorType.OpenAI, + config: { apiProvider: OpenAiProviderType.OpenAi }, + }); + expect(isNativeFunctionCallingSupported(connector)).toBe(true); + }); + + it('returns true for "Azure" provider', () => { + const connector = createConnector({ + type: InferenceConnectorType.OpenAI, + config: { apiProvider: OpenAiProviderType.AzureAi }, + }); + expect(isNativeFunctionCallingSupported(connector)).toBe(true); + }); + + it('returns false for "Other" provider', () => { + const connector = createConnector({ + type: InferenceConnectorType.OpenAI, + config: { apiProvider: OpenAiProviderType.Other }, + }); + expect(isNativeFunctionCallingSupported(connector)).toBe(false); + }); + }); +}); diff --git 
a/x-pack/platform/plugins/shared/inference/server/chat_complete/utils/function_calling_support.ts b/x-pack/platform/plugins/shared/inference/server/chat_complete/utils/function_calling_support.ts new file mode 100644 index 0000000000000..7e70d417a2996 --- /dev/null +++ b/x-pack/platform/plugins/shared/inference/server/chat_complete/utils/function_calling_support.ts @@ -0,0 +1,26 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { InferenceConnector, InferenceConnectorType } from '@kbn/inference-common'; +import { OpenAiProviderType } from '../adapters/openai/types'; + +export const isNativeFunctionCallingSupported = (connector: InferenceConnector): boolean => { + switch (connector.type) { + case InferenceConnectorType.OpenAI: + const apiProvider = + (connector.config.apiProvider as OpenAiProviderType) ?? OpenAiProviderType.Other; + return apiProvider !== OpenAiProviderType.Other; + case InferenceConnectorType.Inference: + // note: later we might need to check the provider type, for now let's assume support + // will be handled by ES and that all providers will support native FC. + return true; + case InferenceConnectorType.Bedrock: + return true; + case InferenceConnectorType.Gemini: + return true; + } +}; diff --git a/x-pack/platform/plugins/shared/inference/server/chat_complete/utils/index.ts b/x-pack/platform/plugins/shared/inference/server/chat_complete/utils/index.ts index 12256630bd741..70322d0af00bd 100644 --- a/x-pack/platform/plugins/shared/inference/server/chat_complete/utils/index.ts +++ b/x-pack/platform/plugins/shared/inference/server/chat_complete/utils/index.ts @@ -15,3 +15,4 @@ export { chunksIntoMessage } from './chunks_into_message'; export { streamToResponse } from './stream_to_response'; export { handleCancellation } from './handle_cancellation'; export { mergeChunks } from './merge_chunks'; +export { isNativeFunctionCallingSupported } from './function_calling_support'; diff --git a/x-pack/platform/plugins/shared/inference/server/routes/chat_complete.ts b/x-pack/platform/plugins/shared/inference/server/routes/chat_complete.ts index 87ace0b2b7cc6..e4292c0af89da 100644 --- a/x-pack/platform/plugins/shared/inference/server/routes/chat_complete.ts +++ b/x-pack/platform/plugins/shared/inference/server/routes/chat_complete.ts @@ -85,7 +85,7 @@ const chatCompleteBodySchema: Type = schema.object({ ]) ), functionCalling: schema.maybe( - schema.oneOf([schema.literal('native'), schema.literal('simulated')]) + schema.oneOf([schema.literal('native'), schema.literal('simulated'), schema.literal('auto')]) ), temperature: schema.maybe(schema.number()), modelName: schema.maybe(schema.string()),
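For reviewers, a minimal usage sketch of the new default mode from a consumer's perspective (not part of this diff). The `inferenceClient` handle and the connector id are illustrative assumptions — the client would normally be obtained from the inference plugin's start contract — while the option names (`functionCalling`, `messages`, `tools`, `toolChoice`) match the `ChatCompleteOptions` touched above:

```ts
import { MessageRole, ToolChoiceType } from '@kbn/inference-common';

// Hypothetical client and connector id, for illustration only.
const response = await inferenceClient.chatComplete({
  connectorId: 'my-openai-connector',
  // 'auto' is now the default, so this line could be omitted entirely. The adapter
  // checks the connector via isNativeFunctionCallingSupported() and only falls back
  // to simulated function calling when the provider lacks native tool support.
  functionCalling: 'auto',
  messages: [{ role: MessageRole.User, content: 'What is the weather in Paris?' }],
  tools: {
    get_weather: { description: 'Returns the current weather for a given city' },
  },
  toolChoice: ToolChoiceType.auto,
});

console.log(response.toolCalls);
```

Passing `'native'` or `'simulated'` explicitly still forces the corresponding mode, so callers that need to pin a behavior are unaffected by the new default.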