Skip to content

Commit 09081b0

Browse files
authored
Merge pull request #66 from import-ai/fix/knowledge_search
Fix/knowledge search
2 parents 7e90810 + 41fd86a commit 09081b0

File tree

5 files changed

+121
-38
lines changed

5 files changed

+121
-38
lines changed

src/messages/entities/message.entity.ts

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,15 @@ import { Conversation } from 'src/conversations/entities/conversation.entity';
99
import { User } from 'src/user/user.entity';
1010
import { Base } from 'src/common/base.entity';
1111

12+
/**
13+
* Every message has a `parentId` that points to its preceding message.
14+
* This structure supports two main scenarios:
15+
*
16+
* 1. **Regenerating the LLM’s response**
17+
* - Retrying a failed or incomplete response
18+
* - Replacing a response that was inaccurate or irrelevant
19+
* 2. **Editing the user’s query message**
20+
*/
1221
@Entity('messages')
1322
export class Message extends Base {
1423
@PrimaryGeneratedColumn('uuid')

src/resources/resources.service.ts

Lines changed: 63 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,33 @@ export class ResourcesService {
108108
return await this.create(user, newResource);
109109
}
110110

111-
async query({ namespaceId, parentId, spaceType, tags, userId }: IQuery) {
111+
async permissionFilter(
112+
namespaceId: string,
113+
userId: string,
114+
resources: Resource[],
115+
): Promise<Resource[]> {
116+
const filteredResources: Resource[] = [];
117+
for (const res of resources) {
118+
const hasPermission: boolean =
119+
await this.permissionsService.userHasPermission(
120+
namespaceId,
121+
res.id,
122+
userId,
123+
);
124+
if (hasPermission) {
125+
filteredResources.push(res);
126+
}
127+
}
128+
return filteredResources;
129+
}
130+
131+
// get resources under parentId
132+
async queryV2(
133+
namespaceId: string,
134+
parentId: string,
135+
userId?: string, // if is undefined, would skip the permission filter
136+
tags?: string, // separated by `,`
137+
): Promise<Resource[]> {
112138
const where: FindOptionsWhere<Resource> = {
113139
namespace: { id: namespaceId },
114140
parentId,
@@ -124,22 +150,40 @@ export class ResourcesService {
124150
where,
125151
relations: ['namespace'],
126152
});
127-
const filteredResources: Resource[] = [];
128-
for (const resource of resources) {
129-
const hasPermission = await this.permissionsService.userHasPermission(
130-
namespaceId,
131-
resource.id,
132-
userId,
133-
);
134-
if (hasPermission) {
135-
filteredResources.push(resource);
136-
}
137-
}
153+
return userId
154+
? await this.permissionFilter(namespaceId, userId, resources)
155+
: resources;
156+
}
157+
158+
async query({ namespaceId, parentId, spaceType, tags, userId }: IQuery) {
159+
const filteredResources: Resource[] = await this.queryV2(
160+
namespaceId,
161+
parentId,
162+
userId,
163+
tags,
164+
);
138165
return filteredResources.map((res) => {
139166
return { ...res, spaceType };
140167
});
141168
}
142169

170+
async getAllSubResources(
171+
namespaceId: string,
172+
parentId: string,
173+
userId?: string,
174+
includeRoot: boolean = false,
175+
): Promise<Resource[]> {
176+
let resources: Resource[] = [await this.get(parentId)];
177+
for (const res of resources) {
178+
const subResources: Resource[] = await this.queryV2(namespaceId, res.id);
179+
resources.push(...subResources);
180+
}
181+
resources = includeRoot ? resources : resources.slice(1);
182+
return userId
183+
? await this.permissionFilter(namespaceId, userId, resources)
184+
: resources;
185+
}
186+
143187
async getSpaceType(resource: Resource): Promise<SpaceType> {
144188
while (resource.parentId !== null) {
145189
resource = (await this.resourceRepository.findOne({
@@ -280,25 +324,18 @@ export class ResourcesService {
280324
);
281325
}
282326

283-
async listUserAccessibleResources(
327+
async listAllUserAccessibleResources(
284328
namespaceId: string,
285329
userId: string,
286-
includeRoot?: boolean,
330+
includeRoot: boolean = false,
287331
) {
288332
const resources = await this.resourceRepository.find({
289333
where: { namespace: { id: namespaceId }, deletedAt: IsNull() },
290334
});
291-
const filtered: Resource[] = [];
292-
for (const resource of resources) {
293-
const hasPermission = await this.permissionsService.userHasPermission(
294-
namespaceId,
295-
resource.id,
296-
userId,
297-
);
298-
if (hasPermission && (resource.parentId !== null || includeRoot)) {
299-
filtered.push(resource);
300-
}
301-
}
302-
return filtered;
335+
return await this.permissionFilter(
336+
namespaceId,
337+
userId,
338+
resources.filter((res) => res.parentId !== null || includeRoot),
339+
);
303340
}
304341
}

src/wizard/stream.service.ts

Lines changed: 43 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,14 @@ import { Message } from 'src/messages/entities/message.entity';
66
import { AgentRequestDto } from 'src/wizard/dto/agent-request.dto';
77
import { ResourcesService } from 'src/resources/resources.service';
88
import { Resource } from 'src/resources/resources.entity';
9+
import { PermissionsService } from 'src/permissions/permissions.service';
910

1011
export class StreamService {
1112
constructor(
1213
private readonly wizardBaseUrl: string,
1314
private readonly messagesService: MessagesService,
1415
private readonly resourcesService: ResourcesService,
16+
private readonly permissionsService: PermissionsService,
1517
) {}
1618

1719
async stream(
@@ -165,18 +167,48 @@ export class StreamService {
165167

166168
if (body.tools) {
167169
for (const tool of body.tools) {
168-
if (
169-
tool.name === 'knowledge_search' &&
170-
tool.resource_ids === undefined &&
171-
tool.parent_ids === undefined
172-
) {
170+
if (tool.name === 'knowledge_search') {
173171
// for knowledge_search, pass the resource with permission
174-
const resources: Resource[] =
175-
await this.resourcesService.listUserAccessibleResources(
176-
tool.namespace_id,
177-
user.id,
178-
);
179-
tool.resource_ids = resources.map((r) => r.id);
172+
if (
173+
tool.resource_ids === undefined &&
174+
tool.parent_ids === undefined
175+
) {
176+
const resources: Resource[] =
177+
await this.resourcesService.listAllUserAccessibleResources(
178+
tool.namespace_id,
179+
user.id,
180+
);
181+
tool.resource_ids = resources.map((r) => r.id);
182+
} else {
183+
const resourceIds: string[] = [];
184+
if (tool.resource_ids) {
185+
for (const resourceId of tool.resource_ids) {
186+
const hasPermission: boolean =
187+
await this.permissionsService.userHasPermission(
188+
tool.namespace_id,
189+
resourceId,
190+
user.id,
191+
);
192+
if (hasPermission) {
193+
resourceIds.push(resourceId);
194+
}
195+
}
196+
}
197+
if (tool.parent_ids) {
198+
for (const parentId of tool.parent_ids) {
199+
const resources: Resource[] =
200+
await this.resourcesService.getAllSubResources(
201+
tool.namespace_id,
202+
parentId,
203+
user.id,
204+
true,
205+
);
206+
resourceIds.push(...resources.map((res) => res.id));
207+
}
208+
tool.parent_ids = undefined;
209+
}
210+
tool.resource_ids = resourceIds;
211+
}
180212
}
181213
}
182214
}

src/wizard/wizard.module.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,15 @@ import { ResourcesModule } from 'src/resources/resources.module';
77
import { TypeOrmModule } from '@nestjs/typeorm';
88
import { Task } from 'src/tasks/tasks.entity';
99
import { MessagesModule } from 'src/messages/messages.module';
10+
import { PermissionsModule } from '../permissions/permissions.module';
1011

1112
@Module({
1213
providers: [WizardService],
1314
imports: [
1415
NamespacesModule,
1516
ResourcesModule,
1617
MessagesModule,
18+
PermissionsModule,
1719
TypeOrmModule.forFeature([Task]),
1820
],
1921
controllers: [WizardController, InternalWizardController],

src/wizard/wizard.service.ts

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@ import { ReaderProcessor } from 'src/wizard/processors/reader.processor';
1515
import { Processor } from 'src/wizard/processors/processor.abstract';
1616
import { MessagesService } from 'src/messages/messages.service';
1717
import { StreamService } from 'src/wizard/stream.service';
18-
import { WizardAPIService } from './api.wizard.service';
18+
import { WizardAPIService } from 'src/wizard/api.wizard.service';
19+
import { PermissionsService } from 'src/permissions/permissions.service';
1920

2021
@Injectable()
2122
export class WizardService {
@@ -29,6 +30,7 @@ export class WizardService {
2930
private readonly resourcesService: ResourcesService,
3031
private readonly messagesService: MessagesService,
3132
private readonly configService: ConfigService,
33+
private readonly permissionsService: PermissionsService,
3234
) {
3335
this.processors = {
3436
collect: new CollectProcessor(resourcesService),
@@ -42,6 +44,7 @@ export class WizardService {
4244
baseUrl,
4345
this.messagesService,
4446
this.resourcesService,
47+
this.permissionsService,
4548
);
4649
this.wizardApiService = new WizardAPIService(baseUrl);
4750
}

0 commit comments

Comments
 (0)