Skip to content

Commit fb28e2e

Browse files
authored
Merge pull request #80 from import-ai/feature/resources_info
Add resources info into agent
2 parents 13ea0a0 + a0472ad commit fb28e2e

File tree

3 files changed

+65
-69
lines changed

3 files changed

+65
-69
lines changed

src/messages/messages.controller.ts

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,4 @@
1-
import {
2-
Body,
3-
Controller,
4-
Delete,
5-
Get,
6-
Param,
7-
Post,
8-
Req,
9-
} from '@nestjs/common';
1+
import { Body, Controller, Delete, Param, Post, Req } from '@nestjs/common';
102
import { MessagesService } from './messages.service';
113
import { CreateMessageDto } from './dto/create-message.dto';
124

@@ -16,11 +8,6 @@ import { CreateMessageDto } from './dto/create-message.dto';
168
export class MessagesController {
179
constructor(private readonly messagesService: MessagesService) {}
1810

19-
@Get()
20-
async list(@Req() req, @Param('conversationId') conversationId: string) {
21-
return await this.messagesService.findAll(req.user.id, conversationId);
22-
}
23-
2411
@Post()
2512
async create(
2613
@Req() req,
Lines changed: 28 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,41 @@
11
export type FloatPair = [number, number];
22

3-
export class ToolDto {
4-
name: 'knowledge_search' | 'web_search';
3+
export interface ToolDto {
4+
name: 'private_search' | 'web_search';
55
}
66

7-
export class KnowledgeSearchToolDto extends ToolDto {
8-
declare name: 'knowledge_search';
7+
export interface PrivateSearchResourceDto {
8+
name: string;
9+
id: string;
10+
type: 'resource' | 'folder';
11+
child_ids?: string[];
12+
}
13+
14+
export interface PrivateSearchToolDto extends ToolDto {
15+
name: 'private_search';
916
namespace_id: string;
10-
resource_ids?: Array<string>;
11-
parent_ids?: Array<string>;
12-
created_at?: FloatPair;
13-
updated_at?: FloatPair;
17+
resources?: PrivateSearchResourceDto[];
18+
visible_resource_ids?: string[];
1419
}
1520

16-
export class WebSearchToolDto extends ToolDto {
17-
declare name: 'web_search';
21+
export interface WebSearchToolDto extends ToolDto {
22+
name: 'web_search';
1823
updated_at?: FloatPair;
1924
}
2025

21-
export class AgentRequestDto {
22-
namespace_id: string;
23-
conversation_id: string;
26+
export interface BaseAgentRequestDto {
2427
query: string;
28+
conversation_id: string;
29+
tools: Array<PrivateSearchToolDto | WebSearchToolDto>;
30+
enable_thinking: boolean;
31+
}
32+
33+
export interface AgentRequestDto extends BaseAgentRequestDto {
34+
namespace_id: string;
2535
parent_message_id?: string;
26-
tools?: Array<KnowledgeSearchToolDto | WebSearchToolDto>;
27-
enable_thinking?: boolean = true;
36+
}
37+
38+
export interface WizardAgentRequestDto extends BaseAgentRequestDto {
39+
messages: Record<string, any>[];
40+
current_cite_cnt: number;
2841
}

src/wizard/stream.service.ts

Lines changed: 36 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,10 @@ import {
88
OpenAIMessage,
99
OpenAIMessageRole,
1010
} from 'src/messages/entities/message.entity';
11-
import { AgentRequestDto } from 'src/wizard/dto/agent-request.dto';
11+
import {
12+
AgentRequestDto,
13+
WizardAgentRequestDto,
14+
} from 'src/wizard/dto/agent-request.dto';
1215
import { ResourcesService } from 'src/resources/resources.service';
1316
import { Resource } from 'src/resources/resources.entity';
1417
import { ChatResponse } from 'src/wizard/dto/chat-response.dto';
@@ -196,7 +199,7 @@ export class StreamService {
196199
mode: 'ask' | 'write' = 'ask',
197200
): Promise<Observable<MessageEvent>> {
198201
let parentId: string | undefined = undefined;
199-
let messages: Record<string, any> = [];
202+
let messages: Record<string, any>[] = [];
200203
let currentCiteCnt: number = 0;
201204
if (body.parent_message_id) {
202205
parentId = body.parent_message_id;
@@ -211,43 +214,37 @@ export class StreamService {
211214

212215
if (body.tools) {
213216
for (const tool of body.tools) {
214-
if (tool.name === 'knowledge_search') {
215-
// for knowledge_search, pass the resource with permission
216-
if (
217-
tool.resource_ids === undefined &&
218-
tool.parent_ids === undefined
219-
) {
217+
if (tool.name === 'private_search') {
218+
// for private_search, pass the resource with permission
219+
if (!tool.resources || tool.resources.length === 0) {
220220
const resources: Resource[] =
221221
await this.resourcesService.listAllUserAccessibleResources(
222222
tool.namespace_id,
223223
user.id,
224224
);
225-
tool.resource_ids = resources.map((r) => r.id);
225+
tool.visible_resource_ids = resources.map((r) => r.id);
226226
} else {
227-
const resourceIds: string[] = [];
228-
if (tool.resource_ids) {
229-
resourceIds.push(
230-
...(await this.resourcesService.permissionFilter<string>(
231-
tool.namespace_id,
232-
user.id,
233-
tool.resource_ids,
234-
)),
235-
);
236-
}
237-
if (tool.parent_ids) {
238-
for (const parentId of tool.parent_ids) {
227+
tool.visible_resource_ids = [];
228+
tool.visible_resource_ids.push(
229+
...(await this.resourcesService.permissionFilter<string>(
230+
tool.namespace_id,
231+
user.id,
232+
tool.resources.map((r) => r.id),
233+
)),
234+
);
235+
for (const resource of tool.resources) {
236+
if (resource.type === 'folder') {
239237
const resources: Resource[] =
240238
await this.resourcesService.getAllSubResources(
241239
tool.namespace_id,
242-
parentId,
240+
resource.id,
243241
user.id,
244-
true,
242+
false,
245243
);
246-
resourceIds.push(...resources.map((res) => res.id));
244+
resource.child_ids = resources.map((r) => r.id);
245+
tool.visible_resource_ids.push(...resource.child_ids);
247246
}
248-
tool.parent_ids = undefined;
249247
}
250-
tool.resource_ids = resourceIds;
251248
}
252249
}
253250
}
@@ -265,20 +262,19 @@ export class StreamService {
265262
user,
266263
subscriber,
267264
);
268-
this.stream(
269-
`/api/v1/wizard/${mode}`,
270-
{
271-
conversation_id: body.conversation_id,
272-
query: body.query,
273-
messages,
274-
tools: body.tools,
275-
enable_thinking: body.enable_thinking,
276-
current_cite_cnt: currentCiteCnt,
277-
},
278-
async (data) => {
279-
await handler(data, handlerContext);
280-
},
281-
)
265+
266+
const wizardRequestBody: WizardAgentRequestDto = {
267+
conversation_id: body.conversation_id,
268+
query: body.query,
269+
messages,
270+
tools: body.tools,
271+
enable_thinking: body.enable_thinking,
272+
current_cite_cnt: currentCiteCnt,
273+
};
274+
275+
this.stream(`/api/v1/wizard/${mode}`, wizardRequestBody, async (data) => {
276+
await handler(data, handlerContext);
277+
})
282278
.then(() => subscriber.complete())
283279
.catch((err: Error) => this.streamError(subscriber, err));
284280
});

0 commit comments

Comments
 (0)