Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions src/resources/resources.service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ import { Task } from 'src/tasks/tasks.entity';
import { User } from 'src/user/user.entity';
import { MinioService } from 'src/resources/minio/minio.service';
import { WizardTask } from 'src/resources/wizard.task.service';
import { PermissionLevel } from 'src/permissions/permission-level.enum';
import { SpaceType } from 'src/namespaces/entities/namespace.entity';
import { PermissionsService } from 'src/permissions/permissions.service';

Expand Down Expand Up @@ -265,7 +264,11 @@ export class ResourcesService {
);
}

async listUserResources(namespaceId: string, userId: string) {
async listUserResources(
namespaceId: string,
userId: string,
includeRoot?: boolean,
) {
const resources = await this.resourceRepository.find({
where: { namespace: { id: namespaceId }, deletedAt: IsNull() },
});
Expand All @@ -276,7 +279,7 @@ export class ResourcesService {
resource.id,
userId,
);
if (hasPermission) {
if (hasPermission && (resource.parentId !== null || includeRoot)) {
filtered.push(resource);
}
}
Expand Down
22 changes: 21 additions & 1 deletion src/wizard/dto/agent-request.dto.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,27 @@
export type FloatPair = [number, number];

export class ToolDto {
name: 'knowledge_search' | 'web_search';
}

export class KnowledgeSearchToolDto extends ToolDto {
declare name: 'knowledge_search';
namespace_id: string;
resource_ids?: Array<string>;
parent_ids?: Array<string>;
created_at?: FloatPair;
updated_at?: FloatPair;
}

export class WebSearchToolDto extends ToolDto {
declare name: 'web_search';
updated_at?: FloatPair;
}

export class AgentRequestDto {
conversation_id: string;
query: string;
parent_message_id?: string;
tools?: Array<Record<string, any>>;
tools?: Array<KnowledgeSearchToolDto | WebSearchToolDto>;
enable_thinking?: boolean = true;
}
20 changes: 20 additions & 0 deletions src/wizard/stream.service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,14 @@ import { Observable, Subscriber } from 'rxjs';
import { MessageEvent } from '@nestjs/common';
import { Message } from 'src/messages/entities/message.entity';
import { AgentRequestDto } from 'src/wizard/dto/agent-request.dto';
import { ResourcesService } from 'src/resources/resources.service';
import { Resource } from '../resources/resources.entity';

export class StreamService {
constructor(
private readonly wizardBaseUrl: string,
private readonly messagesService: MessagesService,
private readonly resourcesService: ResourcesService,
) {}

async stream(
Expand Down Expand Up @@ -160,6 +163,23 @@ export class StreamService {
currentCiteCnt = buf.currentCiteCnt;
}

if (body.tools) {
for (const tool of body.tools) {
if (
tool.name === 'knowledge_search' &&
tool.resource_ids === undefined &&
tool.parent_ids === undefined
) {
const resources: Resource[] =
await this.resourcesService.listUserResources(
tool.namespace_id,
user.id,
);
tool.resource_ids = resources.map((r) => r.id);
}
}
}

return new Observable<MessageEvent>((subscriber) => {
const handler = this.agentHandler(body.conversation_id, user, subscriber);
this.stream(
Expand Down
6 changes: 5 additions & 1 deletion src/wizard/wizard.service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,11 @@ export class WizardService {
if (!baseUrl) {
throw new Error('Environment variable OBB_WIZARD_BASE_URL is required');
}
this.streamService = new StreamService(baseUrl, this.messagesService);
this.streamService = new StreamService(
baseUrl,
this.messagesService,
this.resourcesService,
);
}

async create(partialTask: Partial<Task>) {
Expand Down