From 8f7221a2e2f5424783b97e33fe93218e5a167cea Mon Sep 17 00:00:00 2001 From: Saul-Mirone <10047788+Saul-Mirone@users.noreply.github.com> Date: Thu, 21 Mar 2024 07:50:40 +0000 Subject: [PATCH] feat(blocks): support real abort for copilot (#6530) --- .../src/_common/copilot/model/chat-history.ts | 13 +- .../_common/copilot/model/message-schema.ts | 14 +- .../copilot/model/message-type/html/index.ts | 3 +- .../model/message-type/mind-map/index.ts | 3 +- .../model/message-type/text/actions.ts | 4 +- .../src/_common/copilot/service/llama2.ts | 82 +++++------ .../src/_common/copilot/service/open-ai.ts | 129 ++++++++++-------- .../_common/copilot/service/service-base.ts | 12 +- 8 files changed, 141 insertions(+), 119 deletions(-) diff --git a/packages/blocks/src/_common/copilot/model/chat-history.ts b/packages/blocks/src/_common/copilot/model/chat-history.ts index 0e4ccbf62593..e9a1e91023d2 100644 --- a/packages/blocks/src/_common/copilot/model/chat-history.ts +++ b/packages/blocks/src/_common/copilot/model/chat-history.ts @@ -7,11 +7,11 @@ import { html, type TemplateResult } from 'lit'; import { customElement, property } from 'lit/decorators.js'; import { repeat } from 'lit/directives/repeat.js'; +import type { CopilotServiceResult } from '../service/service-base.js'; import type { ApiData, ChatMessage, MessageContent, - MessageContext, MessageSchema, UserChatMessage, } from './message-schema.js'; @@ -19,7 +19,7 @@ import { MessageSchemas } from './message-type/index.js'; export type CopilotAction = { type: string; - run: (context: MessageContext) => AsyncIterable; + run: CopilotServiceResult; }; export interface HistoryItem { @@ -77,9 +77,12 @@ export class AssistantHistoryItem this.stop(); const abortController = new AbortController(); this.abortController = abortController; - const result = this.action.run({ - history: this.history.flatMap(v => v.toContext()), - }); + const result = this.action.run( + { + history: this.history.flatMap(v => v.toContext()), + }, + abortController.signal + ); const process = async () => { let lastValue: Result | undefined; for await (const value of result) { diff --git a/packages/blocks/src/_common/copilot/model/message-schema.ts b/packages/blocks/src/_common/copilot/model/message-schema.ts index 76ed09b9de4c..339ddd8ef5a1 100644 --- a/packages/blocks/src/_common/copilot/model/message-schema.ts +++ b/packages/blocks/src/_common/copilot/model/message-schema.ts @@ -1,6 +1,7 @@ import type { EditorHost } from '@blocksuite/block-std'; import type { TemplateResult } from 'lit'; +import type { CopilotServiceResult } from '../service/service-base.js'; import type { CopilotAction } from './chat-history.js'; export type MessageContent = @@ -28,12 +29,13 @@ export type AssistantChatMessage = { content: string; sources: BackgroundSource[]; }; +export type SystemChatMessage = { + role: 'system'; + content: string; +}; export type ChatMessage = | UserChatMessage - | { - role: 'system'; - content: string; - } + | SystemChatMessage | AssistantChatMessage; export type ApiData = @@ -64,7 +66,7 @@ export const createMessageSchema = ( config: MessageSchema ): MessageSchema & { createActionBuilder: ( - fn: (arg: Arg, context: MessageContext) => AsyncIterable + fn: (arg: Arg) => CopilotServiceResult ) => (arg: Arg) => CopilotAction; } => { return { @@ -72,7 +74,7 @@ export const createMessageSchema = ( createActionBuilder: fn => arg => { return { type: config.type, - run: context => fn(arg, context), + run: fn(arg), }; }, }; diff --git a/packages/blocks/src/_common/copilot/model/message-type/html/index.ts b/packages/blocks/src/_common/copilot/model/message-type/html/index.ts index 254f7e0d7c7c..4ede445fa54e 100644 --- a/packages/blocks/src/_common/copilot/model/message-type/html/index.ts +++ b/packages/blocks/src/_common/copilot/model/message-type/html/index.ts @@ -123,9 +123,8 @@ When sent new wireframes, respond ONLY with the contents of the html file.`, ); export const createHTMLFromTextAction = HTMLMessageSchema.createActionBuilder( - (text: string, context) => { + (text: string) => { return chatService().chat([ - ...context.history, userText( `You are a professional web developer who specializes in building working website prototypes from product requirement descriptions. Your job is to take a product requirement description, then create a working prototype using HTML, CSS, and JavaScript, and finally send the result back. diff --git a/packages/blocks/src/_common/copilot/model/message-type/mind-map/index.ts b/packages/blocks/src/_common/copilot/model/message-type/mind-map/index.ts index baccb988eaa4..0d0939ca03ae 100644 --- a/packages/blocks/src/_common/copilot/model/message-type/mind-map/index.ts +++ b/packages/blocks/src/_common/copilot/model/message-type/mind-map/index.ts @@ -27,9 +27,8 @@ export const MindMapMessageSchema = createMessageSchema({ }); export const createMindMapAction = MindMapMessageSchema.createActionBuilder( - (text: string, context) => { + (text: string) => { return chatService().chat([ - ...context.history, userText( `Use the nested unordered list syntax in Markdown to create a structure similar to a mind map. Analyze the following questions: diff --git a/packages/blocks/src/_common/copilot/model/message-type/text/actions.ts b/packages/blocks/src/_common/copilot/model/message-type/text/actions.ts index 1ce5cb1bdf8f..06513a734f3a 100644 --- a/packages/blocks/src/_common/copilot/model/message-type/text/actions.ts +++ b/packages/blocks/src/_common/copilot/model/message-type/text/actions.ts @@ -2,8 +2,8 @@ import { chatService, userText } from '../utils.js'; import { TextMessageSchema } from './index.js'; export const createCommonTextAction = TextMessageSchema.createActionBuilder( - (text: string, context) => { - return chatService().chat([...context.history, userText(text)]); + (text: string) => { + return chatService().chat([userText(text)]); } ); export const createChangeToneAction = TextMessageSchema.createActionBuilder( diff --git a/packages/blocks/src/_common/copilot/service/llama2.ts b/packages/blocks/src/_common/copilot/service/llama2.ts index c6ba2bcdf1f3..fa062c223a01 100644 --- a/packages/blocks/src/_common/copilot/service/llama2.ts +++ b/packages/blocks/src/_common/copilot/service/llama2.ts @@ -35,22 +35,24 @@ export const llama2Vendor = createVendor<{ TextServiceKind.implService({ name: 'llama2', method: data => ({ - generateText: async messages => { - const result: { - message: { - role: string; - content: string; - }; - } = await fetch(`${data.host}/api/chat`, { - method: 'POST', - body: JSON.stringify({ - model: 'llama2', - messages: messages, - stream: false, - }), - }).then(res => res.json()); - return result.message.content; - }, + generateText: messages => + async function* (context, signal) { + const result: { + message: { + role: string; + content: string; + }; + } = await fetch(`${data.host}/api/chat`, { + method: 'POST', + signal, + body: JSON.stringify({ + model: 'llama2', + messages: [...context.history, ...messages], + stream: false, + }), + }).then(res => res.json()); + yield result.message.content; + }, }), vendor: llama2Vendor, }); @@ -58,28 +60,30 @@ TextServiceKind.implService({ ChatServiceKind.implService({ name: 'llama2', method: data => ({ - chat: messages => { - const llama2Messages = messages.map(message => { - if (message.role === 'user') { - let text = ''; - const imgs: string[] = []; - message.content.forEach(v => { - if (v.type === 'text') { - text += `${v.text}\n`; - } - if (v.type === 'image_url') { - imgs.push(v.image_url.url.split(',')[1]); + chat: messages => + async function* (context, signal) { + const llama2Messages = [...context.history, ...messages].map( + message => { + if (message.role === 'user') { + let text = ''; + const imgs: string[] = []; + message.content.forEach(v => { + if (v.type === 'text') { + text += `${v.text}\n`; + } + if (v.type === 'image_url') { + imgs.push(v.image_url.url.split(',')[1]); + } + }); + return { + role: message.role, + content: text, + images: imgs, + }; } - }); - return { - role: message.role, - content: text, - images: imgs, - }; - } - return message; - }); - return (async function* () { + return message; + } + ); const result: { message: { role: string; @@ -87,6 +91,7 @@ ChatServiceKind.implService({ }; } = await fetch(`${data.host}/api/chat`, { method: 'POST', + signal, body: JSON.stringify({ model: 'llama2', messages: llama2Messages, @@ -94,8 +99,7 @@ ChatServiceKind.implService({ }), }).then(res => res.json()); yield result.message.content; - })(); - }, + }, }), vendor: llama2Vendor, }); diff --git a/packages/blocks/src/_common/copilot/service/open-ai.ts b/packages/blocks/src/_common/copilot/service/open-ai.ts index e091e1b774b8..6e1715dc28aa 100644 --- a/packages/blocks/src/_common/copilot/service/open-ai.ts +++ b/packages/blocks/src/_common/copilot/service/open-ai.ts @@ -58,27 +58,6 @@ const toGPTMessages = ( }); }; -const askGPT = async ( - apiKey: string, - model: - | 'gpt-4' - | 'gpt-3.5-turbo-1106' - | 'gpt-4-vision-preview' - | 'gpt-4-turbo', - messages: Array -) => { - const openai = new OpenAI({ - apiKey: apiKey, - dangerouslyAllowBrowser: true, - }); - const result = await openai.chat.completions.create({ - messages: toGPTMessages(messages), - model: model, - temperature: 0, - max_tokens: 4096, - }); - return result.choices[0].message; -}; const askGPTStream = async function* ( apiKey: string, model: @@ -86,19 +65,23 @@ const askGPTStream = async function* ( | 'gpt-3.5-turbo-1106' | 'gpt-4-vision-preview' | 'gpt-4-turbo', - messages: Array + messages: Array, + signal: AbortSignal ): AsyncIterable { const openai = new OpenAI({ apiKey: apiKey, dangerouslyAllowBrowser: true, }); - const result = await openai.chat.completions.create({ - stream: true, - messages: toGPTMessages(messages), - model: model, - temperature: 0, - max_tokens: 4096, - }); + const result = await openai.chat.completions.create( + { + stream: true, + messages: toGPTMessages(messages), + model: model, + temperature: 0, + max_tokens: 4096, + }, + { signal } + ); let text = ''; for await (const message of result) { text += message.choices[0].delta.content ?? ''; @@ -109,9 +92,13 @@ const askGPTStream = async function* ( TextServiceKind.implService({ name: 'GPT3.5 Turbo', method: data => ({ - generateText: async messages => { - const result = await askGPT(data.apiKey, 'gpt-3.5-turbo-1106', messages); - return result.content ?? ''; + generateText: messages => (context, signal) => { + return askGPTStream( + data.apiKey, + 'gpt-3.5-turbo-1106', + [...context.history, ...messages], + signal + ); }, }), vendor: openaiVendor, @@ -119,9 +106,13 @@ TextServiceKind.implService({ TextServiceKind.implService({ name: 'GPT4', method: data => ({ - generateText: async messages => { - const result = await askGPT(data.apiKey, 'gpt-4', messages); - return result.content ?? ''; + generateText: messages => (context, signal) => { + return askGPTStream( + data.apiKey, + 'gpt-4', + [...context.history, ...messages], + signal + ); }, }), vendor: openaiVendor, @@ -130,8 +121,13 @@ TextServiceKind.implService({ ChatServiceKind.implService({ name: 'GPT3.5 Turbo', method: data => ({ - chat: messages => { - return askGPTStream(data.apiKey, 'gpt-3.5-turbo-1106', messages); + chat: messages => (context, signal) => { + return askGPTStream( + data.apiKey, + 'gpt-3.5-turbo-1106', + [...context.history, ...messages], + signal + ); }, }), vendor: openaiVendor, @@ -140,15 +136,26 @@ ChatServiceKind.implService({ ChatServiceKind.implService({ name: 'GPT4', method: data => ({ - chat: messages => askGPTStream(data.apiKey, 'gpt-4', messages), + chat: messages => (context, signal) => + askGPTStream( + data.apiKey, + 'gpt-4', + [...context.history, ...messages], + signal + ), }), vendor: openaiVendor, }); ChatServiceKind.implService({ name: 'GPT4-Vision', method: data => ({ - chat: messages => - askGPTStream(data.apiKey, 'gpt-4-vision-preview', messages), + chat: messages => (context, signal) => + askGPTStream( + data.apiKey, + 'gpt-4-vision-preview', + [...context.history, ...messages], + signal + ), }), vendor: openaiVendor, }); @@ -200,25 +207,29 @@ EmbeddingServiceKind.implService({ Image2TextServiceKind.implService({ name: 'GPT4 Vision', method: data => ({ - async *generateText(messages) { - const apiKey = data.apiKey; - const openai = new OpenAI({ - apiKey: apiKey, - dangerouslyAllowBrowser: true, - }); - const result = await openai.chat.completions.create({ - messages, - model: 'gpt-4-vision-preview', - temperature: 0, - max_tokens: 4096, - stream: true, - }); - let text = ''; - for await (const message of result) { - text += message.choices[0].delta.content ?? ''; - yield text; - } - }, + generateText: messages => + async function* (context, signal) { + const apiKey = data.apiKey; + const openai = new OpenAI({ + apiKey: apiKey, + dangerouslyAllowBrowser: true, + }); + const result = await openai.chat.completions.create( + { + messages: [...toGPTMessages(context.history), ...messages], + model: 'gpt-4-vision-preview', + temperature: 0, + max_tokens: 4096, + stream: true, + }, + { signal } + ); + let text = ''; + for await (const message of result) { + text += message.choices[0].delta.content ?? ''; + yield text; + } + }, }), vendor: openaiVendor, }); diff --git a/packages/blocks/src/_common/copilot/service/service-base.ts b/packages/blocks/src/_common/copilot/service/service-base.ts index bd37932065c2..4c9e4a775018 100644 --- a/packages/blocks/src/_common/copilot/service/service-base.ts +++ b/packages/blocks/src/_common/copilot/service/service-base.ts @@ -1,7 +1,7 @@ import type { TemplateResult } from 'lit'; import type { OpenAI } from 'openai'; -import type { ChatMessage } from '../model/message-schema.js'; +import type { ChatMessage, MessageContext } from '../model/message-schema.js'; export type Vendor = { key: string; @@ -21,6 +21,10 @@ export type ServiceKind = { implList: ServiceImpl[]; implService: (impl: ServiceImpl) => void; }; +export type CopilotServiceResult = ( + context: MessageContext, + signal: AbortSignal +) => AsyncIterable; export const createVendor = ( config: Vendor ): Vendor => { @@ -45,13 +49,13 @@ const createServiceKind = (config: { }; export const TextServiceKind = createServiceKind<{ - generateText(messages: ChatMessage[]): Promise; + generateText(messages: ChatMessage[]): CopilotServiceResult; }>({ type: 'text-service', title: 'Text service', }); export const ChatServiceKind = createServiceKind<{ - chat(messages: Array): AsyncIterable; + chat(messages: Array): CopilotServiceResult; }>({ type: 'chat-service', title: 'Chat service', @@ -75,7 +79,7 @@ export const EmbeddingServiceKind = createServiceKind<{ export const Image2TextServiceKind = createServiceKind<{ generateText( messages: Array - ): AsyncIterable; + ): CopilotServiceResult; }>({ type: 'image-to-text-service', title: 'Image to text service',