Skip to content

Commit a80664d

Browse files
authored
Merge pull request #83 from import-ai/fix/multi_round
fix(chat): Fix multi round error
2 parents 60dafcd + 47378a7 commit a80664d

File tree

2 files changed

+6
-18
lines changed

2 files changed

+6
-18
lines changed

src/wizard/dto/agent-request.dto.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import { Message } from 'src/messages/entities/message.entity';
12
export type FloatPair = [number, number];
23

34
export interface ToolDto {
@@ -36,6 +37,5 @@ export interface AgentRequestDto extends BaseAgentRequestDto {
3637
}
3738

3839
export interface WizardAgentRequestDto extends BaseAgentRequestDto {
39-
messages: Record<string, any>[];
40-
current_cite_cnt: number;
40+
messages: Message[];
4141
}

src/wizard/stream.service.ts

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -164,23 +164,15 @@ export class StreamService {
164164
return message;
165165
}
166166

167-
getMessages(
168-
allMessages: Message[],
169-
parentMessageId: string,
170-
): { messages: Record<string, any>[]; currentCiteCnt: number } {
167+
getMessages(allMessages: Message[], parentMessageId: string): Message[] {
171168
const messages: Message[] = [];
172-
let currentCiteCnt: number = 0;
173169
let parentId: string | undefined = parentMessageId;
174170
while (parentId) {
175171
const message = this.findOneOrFail(allMessages, parentId);
176-
const attrs = message.attrs as { citations: Record<string, any>[] };
177-
if (attrs?.citations) {
178-
currentCiteCnt += attrs.citations.length;
179-
}
180172
messages.unshift(message);
181173
parentId = message.parentId;
182174
}
183-
return { messages: messages.map((m) => m.message), currentCiteCnt };
175+
return messages;
184176
}
185177

186178
streamError(subscriber: Subscriber<MessageEvent>, err: Error) {
@@ -199,17 +191,14 @@ export class StreamService {
199191
mode: 'ask' | 'write' = 'ask',
200192
): Promise<Observable<MessageEvent>> {
201193
let parentId: string | undefined = undefined;
202-
let messages: Record<string, any>[] = [];
203-
let currentCiteCnt: number = 0;
194+
let messages: Message[] = [];
204195
if (body.parent_message_id) {
205196
parentId = body.parent_message_id;
206197
const allMessages = await this.messagesService.findAll(
207198
user.id,
208199
body.conversation_id,
209200
);
210-
const buf = this.getMessages(allMessages, parentId);
211-
messages = buf.messages;
212-
currentCiteCnt = buf.currentCiteCnt;
201+
messages = this.getMessages(allMessages, parentId);
213202
}
214203

215204
if (body.tools) {
@@ -269,7 +258,6 @@ export class StreamService {
269258
messages,
270259
tools: body.tools,
271260
enable_thinking: body.enable_thinking,
272-
current_cite_cnt: currentCiteCnt,
273261
};
274262

275263
this.stream(`/api/v1/wizard/${mode}`, wizardRequestBody, async (data) => {

0 commit comments

Comments
 (0)