Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 1 addition & 14 deletions src/messages/messages.controller.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,4 @@
import {
Body,
Controller,
Delete,
Get,
Param,
Post,
Req,
} from '@nestjs/common';
import { Body, Controller, Delete, Param, Post, Req } from '@nestjs/common';
import { MessagesService } from './messages.service';
import { CreateMessageDto } from './dto/create-message.dto';

Expand All @@ -16,11 +8,6 @@ import { CreateMessageDto } from './dto/create-message.dto';
export class MessagesController {
constructor(private readonly messagesService: MessagesService) {}

@Get()
async list(@Req() req, @Param('conversationId') conversationId: string) {
return await this.messagesService.findAll(req.user.id, conversationId);
}

@Post()
async create(
@Req() req,
Expand Down
43 changes: 28 additions & 15 deletions src/wizard/dto/agent-request.dto.ts
Original file line number Diff line number Diff line change
@@ -1,28 +1,41 @@
export type FloatPair = [number, number];

export class ToolDto {
name: 'knowledge_search' | 'web_search';
export interface ToolDto {
name: 'private_search' | 'web_search';
}

export class KnowledgeSearchToolDto extends ToolDto {
declare name: 'knowledge_search';
export interface PrivateSearchResourceDto {
name: string;
id: string;
type: 'resource' | 'folder';
child_ids?: string[];
}

export interface PrivateSearchToolDto extends ToolDto {
name: 'private_search';
namespace_id: string;
resource_ids?: Array<string>;
parent_ids?: Array<string>;
created_at?: FloatPair;
updated_at?: FloatPair;
resources?: PrivateSearchResourceDto[];
visible_resource_ids?: string[];
}

export class WebSearchToolDto extends ToolDto {
declare name: 'web_search';
export interface WebSearchToolDto extends ToolDto {
name: 'web_search';
updated_at?: FloatPair;
}

export class AgentRequestDto {
namespace_id: string;
conversation_id: string;
export interface BaseAgentRequestDto {
query: string;
conversation_id: string;
tools: Array<PrivateSearchToolDto | WebSearchToolDto>;
enable_thinking: boolean;
}

export interface AgentRequestDto extends BaseAgentRequestDto {
namespace_id: string;
parent_message_id?: string;
tools?: Array<KnowledgeSearchToolDto | WebSearchToolDto>;
enable_thinking?: boolean = true;
}

export interface WizardAgentRequestDto extends BaseAgentRequestDto {
messages: Record<string, any>[];
current_cite_cnt: number;
}
76 changes: 36 additions & 40 deletions src/wizard/stream.service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@ import {
OpenAIMessage,
OpenAIMessageRole,
} from 'src/messages/entities/message.entity';
import { AgentRequestDto } from 'src/wizard/dto/agent-request.dto';
import {
AgentRequestDto,
WizardAgentRequestDto,
} from 'src/wizard/dto/agent-request.dto';
import { ResourcesService } from 'src/resources/resources.service';
import { Resource } from 'src/resources/resources.entity';
import { ChatResponse } from 'src/wizard/dto/chat-response.dto';
Expand Down Expand Up @@ -196,7 +199,7 @@ export class StreamService {
mode: 'ask' | 'write' = 'ask',
): Promise<Observable<MessageEvent>> {
let parentId: string | undefined = undefined;
let messages: Record<string, any> = [];
let messages: Record<string, any>[] = [];
let currentCiteCnt: number = 0;
if (body.parent_message_id) {
parentId = body.parent_message_id;
Expand All @@ -211,43 +214,37 @@ export class StreamService {

if (body.tools) {
for (const tool of body.tools) {
if (tool.name === 'knowledge_search') {
// for knowledge_search, pass the resource with permission
if (
tool.resource_ids === undefined &&
tool.parent_ids === undefined
) {
if (tool.name === 'private_search') {
// for private_search, pass the resource with permission
if (!tool.resources || tool.resources.length === 0) {
const resources: Resource[] =
await this.resourcesService.listAllUserAccessibleResources(
tool.namespace_id,
user.id,
);
tool.resource_ids = resources.map((r) => r.id);
tool.visible_resource_ids = resources.map((r) => r.id);
} else {
const resourceIds: string[] = [];
if (tool.resource_ids) {
resourceIds.push(
...(await this.resourcesService.permissionFilter<string>(
tool.namespace_id,
user.id,
tool.resource_ids,
)),
);
}
if (tool.parent_ids) {
for (const parentId of tool.parent_ids) {
tool.visible_resource_ids = [];
tool.visible_resource_ids.push(
...(await this.resourcesService.permissionFilter<string>(
tool.namespace_id,
user.id,
tool.resources.map((r) => r.id),
)),
);
for (const resource of tool.resources) {
if (resource.type === 'folder') {
const resources: Resource[] =
await this.resourcesService.getAllSubResources(
tool.namespace_id,
parentId,
resource.id,
user.id,
true,
false,
);
resourceIds.push(...resources.map((res) => res.id));
resource.child_ids = resources.map((r) => r.id);
tool.visible_resource_ids.push(...resource.child_ids);
}
tool.parent_ids = undefined;
}
tool.resource_ids = resourceIds;
}
}
}
Expand All @@ -265,20 +262,19 @@ export class StreamService {
user,
subscriber,
);
this.stream(
`/api/v1/wizard/${mode}`,
{
conversation_id: body.conversation_id,
query: body.query,
messages,
tools: body.tools,
enable_thinking: body.enable_thinking,
current_cite_cnt: currentCiteCnt,
},
async (data) => {
await handler(data, handlerContext);
},
)

const wizardRequestBody: WizardAgentRequestDto = {
conversation_id: body.conversation_id,
query: body.query,
messages,
tools: body.tools,
enable_thinking: body.enable_thinking,
current_cite_cnt: currentCiteCnt,
};

this.stream(`/api/v1/wizard/${mode}`, wizardRequestBody, async (data) => {
await handler(data, handlerContext);
})
.then(() => subscriber.complete())
.catch((err: Error) => this.streamError(subscriber, err));
});
Expand Down