Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 14 additions & 6 deletions src/messages/messages.service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,18 @@ export class MessagesService {
private readonly searchService: SearchService,
) {}

index(index: boolean, namespaceId: string, message: Message) {
index(
index: boolean,
namespaceId: string,
conversationId: string,
message: Message,
) {
if (index) {
this.searchService.addMessage(namespaceId, message).catch((err) => {
console.error('Failed to index message:', err);
});
this.searchService
.addMessage(namespaceId, conversationId, message)
.catch((err) => {
console.error('Failed to index message:', err);
});
}
}

Expand All @@ -41,13 +48,14 @@ export class MessagesService {
attrs: dto.attrs,
});
const savedMsg = await this.messageRepository.save(message);
this.index(index, namespaceId, savedMsg);
this.index(index, namespaceId, conversationId, savedMsg);
return savedMsg;
}

async update(
id: string,
namespaceId: string,
conversationId: string,
dto: Partial<CreateMessageDto>,
index: boolean = true,
): Promise<Message> {
Expand All @@ -58,7 +66,7 @@ export class MessagesService {
const message = await this.messageRepository.findOneOrFail(condition);
Object.assign(message, dto);
const updatedMsg = await this.messageRepository.save(message);
this.index(index, namespaceId, message);
this.index(index, namespaceId, conversationId, message);
return updatedMsg;
}

Expand Down
5 changes: 4 additions & 1 deletion src/permissions/permissions.service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -432,9 +432,12 @@ export class PermissionsService {
namespaceId: string,
resourceId: string,
): Promise<string | null> {
const resource = await this.resourceRepository.findOneOrFail({
const resource = await this.resourceRepository.findOne({
where: { namespace: { id: namespaceId }, id: resourceId },
});
if (!resource) {
return null;
}
return resource.parentId;
}
}
5 changes: 3 additions & 2 deletions src/search/dto/indexed-doc.dto.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ export class IndexedResourceDto {
name: string;
content: string;
_vectors?: {
default: {
omniboxEmbed: {
embeddings: number[];
regenerate: boolean;
};
Expand All @@ -19,9 +19,10 @@ export class IndexedMessageDto {
id: string;
namespaceId: string;
userId: string;
conversationId: string;
content: string;
_vectors?: {
default: {
omniboxEmbed: {
embeddings: number[];
regenerate: boolean;
};
Expand Down
18 changes: 17 additions & 1 deletion src/search/search.controller.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import { Controller, Get, Param, Query, Req } from '@nestjs/common';
import { SearchService } from './search.service';
import { DocType } from './doc-type.enum';
import { Public } from 'src/auth/decorators/public.decorator';

@Controller('api/v1/namespaces/:namespaceId/search')
export class SearchController {
Expand All @@ -14,10 +15,25 @@ export class SearchController {
@Query('type') type?: DocType,
) {
return await this.searchService.search(
req.user.id,
namespaceId,
query,
type,
req.user.id,
);
}
}

@Controller('internal/api/v1/namespaces/:namespaceId/search')
export class InternalSearchController {
constructor(private readonly searchService: SearchService) {}

@Public()
@Get()
async search(
@Param('namespaceId') namespaceId: string,
@Query('query') query: string,
@Query('type') type?: DocType,
) {
return await this.searchService.search(namespaceId, query, type);
}
}
7 changes: 5 additions & 2 deletions src/search/search.module.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import { Module } from '@nestjs/common';
import { SearchService } from './search.service';
import { SearchController } from './search.controller';
import {
InternalSearchController,
SearchController,
} from './search.controller';
import { PermissionsModule } from 'src/permissions/permissions.module';

@Module({
exports: [SearchService],
providers: [SearchService],
controllers: [SearchController],
controllers: [SearchController, InternalSearchController],
imports: [PermissionsModule],
})
export class SearchModule {}
90 changes: 55 additions & 35 deletions src/search/search.service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@ import {
SearchParams,
} from 'meilisearch';
import { Resource } from 'src/resources/resources.entity';
import { Message } from 'src/messages/entities/message.entity';
import {
Message,
OpenAIMessageRole,
} from 'src/messages/entities/message.entity';
import { DocType } from './doc-type.enum';
import {
IndexedDocDto,
Expand All @@ -18,7 +21,7 @@ import {
import { PermissionsService } from 'src/permissions/permissions.service';
import { PermissionLevel } from 'src/permissions/permission-level.enum';

const indexUid = 'idx';
const indexUid = 'omniboxIdx';

@Injectable()
export class SearchService implements OnModuleInit {
Expand Down Expand Up @@ -71,9 +74,9 @@ export class SearchService implements OnModuleInit {
}

const embedders = await index.getEmbedders();
if (!embedders || !embedders.default) {
if (!embedders || !embedders.omniboxEmbed) {
await index.updateEmbedders({
default: {
omniboxEmbed: {
source: 'userProvided',
dimensions: 1024,
},
Expand All @@ -82,9 +85,6 @@ export class SearchService implements OnModuleInit {
}

async getEmbedding(input: string): Promise<number[]> {
if (!input) {
return new Array(1024).fill(0);
}
const resp = await this.openai.embeddings.create({
model: this.embeddingModel,
input,
Expand All @@ -101,17 +101,30 @@ export class SearchService implements OnModuleInit {
name: resource.name,
content: resource.content,
_vectors: {
default: {
embeddings: await this.getEmbedding(resource.content),
omniboxEmbed: {
embeddings: await this.getEmbedding(
`A resource named ${resource.name} with content: ${resource.content}`,
),
regenerate: false,
},
},
};
await index.addDocuments([doc]);
}

async addMessage(namespaceId: string, message: Message) {
if (!message.message.content) {
async addMessage(
namespaceId: string,
conversationId: string,
message: Message,
) {
if (!message.message.content?.trim()) {
return;
}
if (
[OpenAIMessageRole.TOOL, OpenAIMessageRole.SYSTEM].includes(
message.message.role,
)
) {
return;
}
const content = message.message.content;
Expand All @@ -121,10 +134,13 @@ export class SearchService implements OnModuleInit {
id: `message_${message.id}`,
namespaceId: namespaceId,
userId: message.user.id,
conversationId,
content,
_vectors: {
default: {
embeddings: await this.getEmbedding(content),
omniboxEmbed: {
embeddings: await this.getEmbedding(
`A message with content: ${content}`,
),
regenerate: false,
},
},
Expand All @@ -133,44 +149,48 @@ export class SearchService implements OnModuleInit {
}

async search(
userId: string,
namespaceId: string,
query: string,
type?: DocType,
userId?: string,
) {
const filter = [
`namespaceId = "${namespaceId}"`,
`userId NOT EXISTS OR userId = "${userId}"`,
];
const filter = [`namespaceId = "${namespaceId}"`];
if (userId) {
filter.push(`userId NOT EXISTS OR userId = "${userId}"`);
}
if (type) {
filter.push(`type = "${type}"`);
}
const searchParams: SearchParams = {
vector: await this.getEmbedding(query),
hybrid: {
embedder: 'default',
},
filter,
showRankingScore: true,
};
if (query) {
searchParams.vector = await this.getEmbedding(query);
searchParams.hybrid = {
embedder: 'omniboxEmbed',
};
}
const index = await this.meili.getIndex(indexUid);
const result = await index.search(query, searchParams);
const items: IndexedDocDto[] = [];
for (const hit of result.hits) {
hit.id = hit.id.replace(/^(message_|resource_)/, '');
if (hit.type === DocType.RESOURCE) {
const resource = hit as IndexedResourceDto;
const hasPermission = await this.permissionsService.userHasPermission(
namespaceId,
resource.id,
userId,
PermissionLevel.CAN_VIEW,
);
if (!hasPermission) {
continue;
if (userId) {
for (const hit of result.hits) {
hit.id = hit.id.replace(/^(message_|resource_)/, '');
if (hit.type === DocType.RESOURCE) {
const resource = hit as IndexedResourceDto;
const hasPermission = await this.permissionsService.userHasPermission(
namespaceId,
resource.id,
userId,
PermissionLevel.CAN_VIEW,
);
if (!hasPermission) {
continue;
}
}
items.push(hit as IndexedDocDto);
}
items.push(hit as IndexedDocDto);
}
return items;
}
Expand Down
25 changes: 17 additions & 8 deletions src/wizard/dto/chat-response.dto.ts
Original file line number Diff line number Diff line change
@@ -1,35 +1,44 @@
import { OpenAIMessage, OpenAIMessageRole, MessageAttrs } from 'src/messages/entities/message.entity';
import {
OpenAIMessage,
OpenAIMessageRole,
MessageAttrs,
} from 'src/messages/entities/message.entity';

export type ChatResponseType = "bos" | "delta" | "eos" | "done" | "error";
export type ChatResponseType = 'bos' | 'delta' | 'eos' | 'done' | 'error';

export interface ChatBaseResponse {
response_type: ChatResponseType;
}

export interface ChatBOSResponse extends ChatBaseResponse {
response_type: "bos";
response_type: 'bos';
role: OpenAIMessageRole;
id: string;
parentId?: string;
}

export interface ChatEOSResponse extends ChatBaseResponse {
response_type: "eos";
response_type: 'eos';
}

export interface ChatDeltaResponse extends ChatBaseResponse {
response_type: "delta";
response_type: 'delta';
message: Partial<OpenAIMessage>;
attrs?: MessageAttrs;
}

export interface ChatDoneResponse extends ChatBaseResponse {
response_type: "done";
response_type: 'done';
}

export interface ChatErrorResponse extends ChatBaseResponse {
response_type: "error";
response_type: 'error';
message: string;
}

export type ChatResponse = ChatBOSResponse | ChatDeltaResponse | ChatEOSResponse | ChatDoneResponse | ChatErrorResponse;
export type ChatResponse =
| ChatBOSResponse
| ChatDeltaResponse
| ChatEOSResponse
| ChatDoneResponse
| ChatErrorResponse;
Loading