Skip to content

Commit

Permalink
feat(blocks): support real abort for copilot (#6530)
Browse files Browse the repository at this point in the history
  • Loading branch information
Saul-Mirone committed Mar 21, 2024
1 parent 3054548 commit 8f7221a
Show file tree
Hide file tree
Showing 8 changed files with 141 additions and 119 deletions.
13 changes: 8 additions & 5 deletions packages/blocks/src/_common/copilot/model/chat-history.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,19 @@ import { html, type TemplateResult } from 'lit';
import { customElement, property } from 'lit/decorators.js';
import { repeat } from 'lit/directives/repeat.js';

import type { CopilotServiceResult } from '../service/service-base.js';
import type {
ApiData,
ChatMessage,
MessageContent,
MessageContext,
MessageSchema,
UserChatMessage,
} from './message-schema.js';
import { MessageSchemas } from './message-type/index.js';

export type CopilotAction<Result> = {
type: string;
run: (context: MessageContext) => AsyncIterable<Result>;
run: CopilotServiceResult<Result>;
};

export interface HistoryItem {
Expand Down Expand Up @@ -77,9 +77,12 @@ export class AssistantHistoryItem<Result = unknown, Data = unknown>
this.stop();
const abortController = new AbortController();
this.abortController = abortController;
const result = this.action.run({
history: this.history.flatMap(v => v.toContext()),
});
const result = this.action.run(
{
history: this.history.flatMap(v => v.toContext()),
},
abortController.signal
);
const process = async () => {
let lastValue: Result | undefined;
for await (const value of result) {
Expand Down
14 changes: 8 additions & 6 deletions packages/blocks/src/_common/copilot/model/message-schema.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import type { EditorHost } from '@blocksuite/block-std';
import type { TemplateResult } from 'lit';

import type { CopilotServiceResult } from '../service/service-base.js';
import type { CopilotAction } from './chat-history.js';

export type MessageContent =
Expand Down Expand Up @@ -28,12 +29,13 @@ export type AssistantChatMessage = {
content: string;
sources: BackgroundSource[];
};
export type SystemChatMessage = {
role: 'system';
content: string;
};
export type ChatMessage =
| UserChatMessage
| {
role: 'system';
content: string;
}
| SystemChatMessage
| AssistantChatMessage;

export type ApiData<T> =
Expand Down Expand Up @@ -64,15 +66,15 @@ export const createMessageSchema = <Result, Data = unknown>(
config: MessageSchema<Result, Data>
): MessageSchema<Result, Data> & {
createActionBuilder: <Arg>(
fn: (arg: Arg, context: MessageContext) => AsyncIterable<Result>
fn: (arg: Arg) => CopilotServiceResult<Result>
) => (arg: Arg) => CopilotAction<Result>;
} => {
return {
...config,
createActionBuilder: fn => arg => {
return {
type: config.type,
run: context => fn(arg, context),
run: fn(arg),
};
},
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,9 +123,8 @@ When sent new wireframes, respond ONLY with the contents of the html file.`,
);

export const createHTMLFromTextAction = HTMLMessageSchema.createActionBuilder(
(text: string, context) => {
(text: string) => {
return chatService().chat([
...context.history,
userText(
`You are a professional web developer who specializes in building working website prototypes from product requirement descriptions.
Your job is to take a product requirement description, then create a working prototype using HTML, CSS, and JavaScript, and finally send the result back.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,8 @@ export const MindMapMessageSchema = createMessageSchema<Markdown>({
});

export const createMindMapAction = MindMapMessageSchema.createActionBuilder(
(text: string, context) => {
(text: string) => {
return chatService().chat([
...context.history,
userText(
`Use the nested unordered list syntax in Markdown to create a structure similar to a mind map.
Analyze the following questions:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ import { chatService, userText } from '../utils.js';
import { TextMessageSchema } from './index.js';

export const createCommonTextAction = TextMessageSchema.createActionBuilder(
(text: string, context) => {
return chatService().chat([...context.history, userText(text)]);
(text: string) => {
return chatService().chat([userText(text)]);
}
);
export const createChangeToneAction = TextMessageSchema.createActionBuilder(
Expand Down
82 changes: 43 additions & 39 deletions packages/blocks/src/_common/copilot/service/llama2.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,67 +35,71 @@ export const llama2Vendor = createVendor<{
TextServiceKind.implService({
name: 'llama2',
method: data => ({
generateText: async messages => {
const result: {
message: {
role: string;
content: string;
};
} = await fetch(`${data.host}/api/chat`, {
method: 'POST',
body: JSON.stringify({
model: 'llama2',
messages: messages,
stream: false,
}),
}).then(res => res.json());
return result.message.content;
},
generateText: messages =>
async function* (context, signal) {
const result: {
message: {
role: string;
content: string;
};
} = await fetch(`${data.host}/api/chat`, {
method: 'POST',
signal,
body: JSON.stringify({
model: 'llama2',
messages: [...context.history, ...messages],
stream: false,
}),
}).then(res => res.json());
yield result.message.content;
},
}),
vendor: llama2Vendor,
});

ChatServiceKind.implService({
name: 'llama2',
method: data => ({
chat: messages => {
const llama2Messages = messages.map(message => {
if (message.role === 'user') {
let text = '';
const imgs: string[] = [];
message.content.forEach(v => {
if (v.type === 'text') {
text += `${v.text}\n`;
}
if (v.type === 'image_url') {
imgs.push(v.image_url.url.split(',')[1]);
chat: messages =>
async function* (context, signal) {
const llama2Messages = [...context.history, ...messages].map(
message => {
if (message.role === 'user') {
let text = '';
const imgs: string[] = [];
message.content.forEach(v => {
if (v.type === 'text') {
text += `${v.text}\n`;
}
if (v.type === 'image_url') {
imgs.push(v.image_url.url.split(',')[1]);
}
});
return {
role: message.role,
content: text,
images: imgs,
};
}
});
return {
role: message.role,
content: text,
images: imgs,
};
}
return message;
});
return (async function* () {
return message;
}
);
const result: {
message: {
role: string;
content: string;
};
} = await fetch(`${data.host}/api/chat`, {
method: 'POST',
signal,
body: JSON.stringify({
model: 'llama2',
messages: llama2Messages,
stream: false,
}),
}).then(res => res.json());
yield result.message.content;
})();
},
},
}),
vendor: llama2Vendor,
});
Loading

0 comments on commit 8f7221a

Please sign in to comment.