Skip to content

Commit

Permalink
Merge pull request #25 from refly-ai/feat/summarize-conversation
Browse files Browse the repository at this point in the history
feat(reflyd): auto summary of conversation into titles
  • Loading branch information
mrcfps authored Apr 24, 2024
2 parents 5f8d991 + 96c14ee commit a9e0a09
Show file tree
Hide file tree
Showing 9 changed files with 172 additions and 151 deletions.
3 changes: 3 additions & 0 deletions reflyd/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,9 @@
"node_modules",
"src"
],
"moduleNameMapper": {
"^src/(.*)$": "<rootDir>/$1"
},
"testTimeout": 40000,
"testRegex": ".*\\.spec\\.ts$",
"transform": {
Expand Down
8 changes: 4 additions & 4 deletions reflyd/src/aigc/aigc.service.ts
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import { Injectable, Logger } from '@nestjs/common';
import { Injectable } from '@nestjs/common';
import { Document } from '@langchain/core/documents';
import omit from 'lodash.omit';

import { LlmService } from '../llm/llm.service';
import { PrismaService } from '../common/prisma.service';
import { AigcContent, User, UserWeblink, Weblink } from '@prisma/client';
import { ContentMeta } from 'src/llm/dto';
import { WebLinkDTO } from 'src/weblink/dto';
import { ContentMeta } from '../llm/dto';
import { WebLinkDTO } from '../weblink/dto';
import { DigestFilter } from './aigc.dto';
import { categoryList } from '../prompts/utils/category';
import { LoggerService } from 'src/common/logger.service';
import { LoggerService } from '../common/logger.service';

@Injectable()
export class AigcService {
Expand Down
121 changes: 26 additions & 95 deletions reflyd/src/conversation/conversation.controller.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,15 @@ import {
} from '@nestjs/common';
import { Response } from 'express';
import {
CreateChatMessageInput,
CreateConversationParam,
CreateConversationResponse,
ListConversationResponse,
} from './dto';
import { ApiParam, ApiResponse } from '@nestjs/swagger';
import { JwtAuthGuard } from '../auth/guard/jwt-auth.guard';
import { ConversationService } from './conversation.service';
import { TASK_TYPE, TaskResponse, type Task } from '../types/task';
import { TASK_TYPE, type Task } from '../types/task';
import { AigcService } from '../aigc/aigc.service';
import { LoggerService } from '../common/logger.service';

Expand Down Expand Up @@ -47,7 +48,7 @@ export class ConversationController {
const content = await this.aigcService.getContent({
contentId: body.contentId,
});
await this.conversationService.addChatMessages([
const messages: CreateChatMessageInput[] = [
{
type: 'human',
content: content.title,
Expand All @@ -64,11 +65,14 @@ export class ConversationController {
conversationId: res.id,
locale: body.locale,
},
];
await Promise.all([
this.conversationService.addChatMessages(messages),
this.conversationService.updateConversation(res.id, messages, {
messageCount: { increment: 2 },
lastMessage: content.content,
}),
]);
await this.conversationService.updateConversation(res.id, {
messageCount: { increment: 2 },
lastMessage: content.content,
});
}

return {
Expand All @@ -84,97 +88,24 @@ export class ConversationController {
@Body() body: { task: Task },
@Res() res: Response,
) {
try {
if (!conversationId || !Number(conversationId)) {
throw new BadRequestException('invalid conversation id');
}
const convId = Number(conversationId);

const { taskType, data = {} } = body?.task;
if (taskType === TASK_TYPE.CHAT && !data?.question) {
throw new BadRequestException('query cannot be empty');
}

const userId: number = req.user.id;
const query = data?.question || '';
const weblinkList = body?.task?.data?.filter?.weblinkList || [];

await this.conversationService.addChatMessage({
type: 'human',
userId,
conversationId: convId,
content: query,
sources: '',
locale: body.task.locale,
// 每次提问完在 human message 上加一个提问的 filter,这样之后追问时可以 follow 这个 filter 规则
selectedWeblinkConfig: JSON.stringify({
searchTarget: weblinkList?.length > 0 ? 'selectedPages' : 'all',
filter: weblinkList,
}),
});

res.setHeader('Content-Type', 'text/event-stream');
res.setHeader('Cache-Control', 'no-cache');
res.setHeader('Connection', 'keep-alive');
res.status(200);

// 获取聊天历史
const chatHistory = await this.conversationService.getMessages(convId);

let taskRes: TaskResponse;
if (taskType === TASK_TYPE.QUICK_ACTION) {
taskRes = await this.conversationService.handleQuickActionTask(
req,
res,
body?.task,
chatHistory,
);
} else if (taskType === TASK_TYPE.SEARCH_ENHANCE_ASK) {
taskRes = await this.conversationService.handleSearchEnhanceTask(
req,
res,
body?.task,
chatHistory,
);
} else {
taskRes = await this.conversationService.handleChatTask(
req,
res,
body?.task,
chatHistory,
);
}
res.end(``);

await this.conversationService.addChatMessage({
type: 'ai',
userId,
conversationId: convId,
content: taskRes.answer,
locale: body.task.locale,
sources: JSON.stringify(taskRes.sources),
relatedQuestions: JSON.stringify(taskRes.relatedQuestions),
});
if (!conversationId || !Number(conversationId)) {
throw new BadRequestException('invalid conversation id');
}
const convId = Number(conversationId);
const { task } = body;
if (!task) {
throw new BadRequestException('task cannot be empty');
}
if (task.taskType === TASK_TYPE.CHAT && !task.data?.question) {
throw new BadRequestException('query cannot be empty for chat task');
}

// update conversation last answer and message count
const updated = await this.conversationService.updateConversation(
convId,
{
lastMessage: taskRes.answer,
messageCount: chatHistory.length + 1,
},
);
this.logger.log(
`update conversation ${convId}, after updated: ${JSON.stringify(
updated,
)}`,
);
} catch (err) {
this.logger.error(`chat error: ${err}`);
res.setHeader('Content-Type', 'text/event-stream');
res.setHeader('Cache-Control', 'no-cache');
res.setHeader('Connection', 'keep-alive');
res.status(200);

// 结束流式输出
res.end(``);
}
await this.conversationService.chat(res, convId, req.user.id, body.task);
}

@UseGuards(JwtAuthGuard)
Expand Down
106 changes: 69 additions & 37 deletions reflyd/src/conversation/conversation.service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@ import { Injectable } from '@nestjs/common';
import { Response } from 'express';

import { PrismaService } from '../common/prisma.service';
import { CreateConversationParam } from './dto';
import { MessageType, Prisma, ChatMessage } from '@prisma/client';
import { CreateChatMessageInput, CreateConversationParam } from './dto';
import { Prisma, ChatMessage } from '@prisma/client';
import {
LOCALE,
QUICK_ACTION_TASK_PAYLOAD,
QUICK_ACTION_TYPE,
TASK_TYPE,
Task,
TaskResponse,
} from '../types/task';
Expand Down Expand Up @@ -46,40 +47,21 @@ export class ConversationService {

async updateConversation(
conversationId: number,
messages: { type: string; content: string }[],
data: Prisma.ConversationUpdateInput,
) {
const summarizedTitle = await this.llmService.summarizeConversation(
messages,
);
this.logger.log(`Summarized title: ${summarizedTitle}`);

return this.prisma.conversation.update({
where: { id: conversationId },
data,
});
}

async addChatMessage(msg: {
type: MessageType;
sources: string;
content: string;
userId: number;
conversationId: number;
locale?: string;
relatedQuestions?: string;
selectedWeblinkConfig?: string;
}) {
return this.prisma.chatMessage.create({
data: { ...msg },
data: { ...data, title: summarizedTitle },
});
}

async addChatMessages(
msgList: {
type: MessageType;
sources: string;
content: string;
userId: number;
conversationId: number;
locale?: string;
selectedWeblinkConfig?: string;
}[],
) {
async addChatMessages(msgList: CreateChatMessageInput[]) {
return this.prisma.chatMessage.createMany({
data: msgList,
});
Expand Down Expand Up @@ -111,13 +93,64 @@ export class ConversationService {
});
}

async chat(res: Response, convId: number, userId: number, task: Task) {
const { taskType, data = {} } = task;

const query = data?.question || '';
const weblinkList = data?.filter?.weblinkList || [];

// 获取聊天历史
const chatHistory = await this.getMessages(convId);

let taskRes: TaskResponse;
if (taskType === TASK_TYPE.QUICK_ACTION) {
taskRes = await this.handleQuickActionTask(res, userId, task);
} else if (taskType === TASK_TYPE.SEARCH_ENHANCE_ASK) {
taskRes = await this.handleSearchEnhanceTask(res, task, chatHistory);
} else {
taskRes = await this.handleChatTask(res, userId, task, chatHistory);
}
res.end(``);

const newMessages: CreateChatMessageInput[] = [
{
type: 'human',
userId,
conversationId: convId,
content: query,
sources: '',
// 每次提问完在 human message 上加一个提问的 filter,这样之后追问时可以 follow 这个 filter 规则
selectedWeblinkConfig: JSON.stringify({
searchTarget: weblinkList?.length > 0 ? 'selectedPages' : 'all',
filter: weblinkList,
}),
},
{
type: 'ai',
userId,
conversationId: convId,
content: taskRes.answer,
sources: JSON.stringify(taskRes.sources),
relatedQuestions: JSON.stringify(taskRes.relatedQuestions),
},
];

// post chat logic
await Promise.all([
this.addChatMessages(newMessages),
this.updateConversation(convId, [...chatHistory, ...newMessages], {
lastMessage: taskRes.answer,
messageCount: chatHistory.length + 2,
}),
]);
}

async handleChatTask(
req: any,
res: Response,
userId: number,
task: Task,
chatHistory: ChatMessage[],
): Promise<TaskResponse> {
const userId: number = req.user?.id;
const locale = task?.locale || LOCALE.EN;

const filter: any = {
Expand Down Expand Up @@ -196,7 +229,9 @@ export class ConversationService {
this.logger.log('relatedQuestions', relatedQuestions);

res.write(RELATED_SPLIT);
res.write(JSON.stringify(relatedQuestions));
if (relatedQuestions) {
res.write(JSON.stringify(relatedQuestions));
}

return {
sources,
Expand All @@ -206,7 +241,6 @@ export class ConversationService {
}

async handleSearchEnhanceTask(
req: any,
res: Response,
task: Task,
chatHistory: ChatMessage[],
Expand Down Expand Up @@ -265,10 +299,9 @@ export class ConversationService {
}

async handleQuickActionTask(
req: any,
res: Response,
userId: number,
task: Task,
chatHistory: ChatMessage[],
): Promise<TaskResponse> {
const data = task?.data as QUICK_ACTION_TASK_PAYLOAD;
const locale = task?.locale || LOCALE.EN;
Expand Down Expand Up @@ -296,9 +329,8 @@ export class ConversationService {

// save user mark for each weblink in a non-blocking style
this.weblinkService.saveWeblinkUserMarks({
userId: req.user.id,
userId,
weblinkList,
extensionVersion: req.header('x-refly-ext-version'),
});

// 基于一组网页做总结,先获取网页内容
Expand Down
11 changes: 11 additions & 0 deletions reflyd/src/conversation/dto.ts
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,14 @@ export class ChatParam {
@ApiPropertyOptional()
conversationId: number;
}

export interface CreateChatMessageInput {
type: MessageType;
sources: string;
content: string;
userId: number;
locale?: string;
conversationId: number;
relatedQuestions?: string;
selectedWeblinkConfig?: string;
}
Loading

0 comments on commit a9e0a09

Please sign in to comment.