Skip to content

Commit 9e3ecf1

Browse files
committed
WIP - insert many with generated embeddings
This is still very WIP and not ready for review, just running accuracy tests.
1 parent fe38463 commit 9e3ecf1

File tree

2 files changed

+522
-11
lines changed

2 files changed

+522
-11
lines changed

src/tools/mongodb/create/insertMany.ts

Lines changed: 207 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,10 @@ import type { CallToolResult } from "@modelcontextprotocol/sdk/types.js";
33
import { DbOperationArgs, MongoDBToolBase } from "../mongodbTool.js";
44
import { type ToolArgs, type OperationType, formatUntrustedData } from "../../tool.js";
55
import { zEJSON } from "../../args.js";
6+
import { type Document } from "bson";
7+
import { zSupportedEmbeddingParameters } from "../../../common/search/embeddingsProvider.js";
8+
import { ErrorCodes, MongoDBError } from "../../../common/errors.js";
9+
import type { VectorFieldIndexDefinition } from "../../../common/search/vectorSearchEmbeddingsManager.js";
610

711
export class InsertManyTool extends MongoDBToolBase {
812
public name = "insert-many";
@@ -12,7 +16,19 @@ export class InsertManyTool extends MongoDBToolBase {
1216
documents: z
1317
.array(zEJSON().describe("An individual MongoDB document"))
1418
.describe(
15-
"The array of documents to insert, matching the syntax of the document argument of db.collection.insertMany()"
19+
"The array of documents to insert, matching the syntax of the document argument of db.collection.insertMany(). For fields that have vector search indexes, you can provide raw text strings that will be automatically converted to embeddings if embeddingParameters is provided."
20+
),
21+
embeddingParameters: zSupportedEmbeddingParameters
22+
.extend({
23+
input: z
24+
.array(z.record(z.string(), z.string()).optional())
25+
.describe(
26+
"An array of objects (one per document) that maps field paths to plain-text content for generating embeddings. Each object should have keys matching the vector search index field paths (in dot notation), with string values containing the text to embed. If provided, these texts will be used to generate embeddings instead of looking for raw text in the document fields themselves. Example: [{'content': 'Input text to create embeddings from for first doc'}, {'content': 'Input text to create embeddings from for second doc'}]"
27+
),
28+
})
29+
.optional()
30+
.describe(
31+
"The embedding model and its parameters to use to generate embeddings for fields that have vector search indexes. Required when fields associated with vector search indexes contain raw text strings. Note to LLM: If unsure, ask the user before providing one."
1632
),
1733
};
1834
public operationType: OperationType = "create";
@@ -21,23 +37,36 @@ export class InsertManyTool extends MongoDBToolBase {
2137
database,
2238
collection,
2339
documents,
40+
embeddingsInput,
41+
embeddingParameters,
2442
}: ToolArgs<typeof this.argsShape>): Promise<CallToolResult> {
2543
const provider = await this.ensureConnected();
2644

27-
const embeddingValidations = new Set(
28-
...(await Promise.all(
29-
documents.flatMap((document) =>
30-
this.session.vectorSearchEmbeddingsManager.findFieldsWithWrongEmbeddings(
31-
{ database, collection },
32-
document
33-
)
34-
)
35-
))
45+
// Get vector search indexes for the collection
46+
const vectorIndexes = await this.session.vectorSearchEmbeddingsManager.embeddingsForNamespace({
47+
database,
48+
collection,
49+
});
50+
51+
// Process documents to replace raw string values with generated embeddings
52+
documents = await this.replaceRawValuesWithEmbeddingsIfNecessary({
53+
database,
54+
collection,
55+
documents,
56+
vectorIndexes,
57+
embeddingsInput,
58+
embeddingParameters,
59+
});
60+
61+
const embeddingValidationPromises = documents.map((document) =>
62+
this.session.vectorSearchEmbeddingsManager.findFieldsWithWrongEmbeddings({ database, collection }, document)
3663
);
64+
const embeddingValidationResults = await Promise.all(embeddingValidationPromises);
65+
const embeddingValidations = new Set(embeddingValidationResults.flat());
3766

3867
if (embeddingValidations.size > 0) {
3968
// tell the LLM what happened
40-
const embeddingValidationMessages = [...embeddingValidations].map(
69+
const embeddingValidationMessages = Array.from(embeddingValidations).map(
4170
(validation) =>
4271
`- Field ${validation.path} is an embedding with ${validation.expectedNumDimensions} dimensions and ${validation.expectedQuantization}` +
4372
` quantization, and the provided value is not compatible. Actual dimensions: ${validation.actualNumDimensions}, ` +
@@ -63,4 +92,171 @@ export class InsertManyTool extends MongoDBToolBase {
6392
content,
6493
};
6594
}
95+
96+
private async replaceRawValuesWithEmbeddingsIfNecessary({
97+
database,
98+
collection,
99+
documents,
100+
vectorIndexes,
101+
embeddingsInput,
102+
embeddingParameters,
103+
}: {
104+
database: string;
105+
collection: string;
106+
documents: Document[];
107+
vectorIndexes: VectorFieldIndexDefinition[];
108+
embeddingsInput?: Array<Record<string, string> | undefined>;
109+
embeddingParameters?: z.infer<typeof zSupportedEmbeddingParameters>;
110+
}): Promise<Document[]> {
111+
// If no vector indexes, return documents as-is
112+
if (vectorIndexes.length === 0) {
113+
return documents;
114+
}
115+
116+
const processedDocuments: Document[] = [];
117+
118+
for (let i = 0; i < documents.length; i++) {
119+
const document = documents[i];
120+
if (!document) {
121+
continue;
122+
}
123+
const documentEmbeddingsInput = embeddingsInput?.[i];
124+
const processedDoc = await this.processDocumentForEmbeddings(
125+
database,
126+
collection,
127+
document,
128+
vectorIndexes,
129+
documentEmbeddingsInput,
130+
embeddingParameters
131+
);
132+
processedDocuments.push(processedDoc);
133+
}
134+
135+
return processedDocuments;
136+
}
137+
138+
private async processDocumentForEmbeddings(
139+
database: string,
140+
collection: string,
141+
document: Document,
142+
vectorIndexes: VectorFieldIndexDefinition[],
143+
embeddingsInput?: Record<string, string>,
144+
embeddingParameters?: z.infer<typeof zSupportedEmbeddingParameters>
145+
): Promise<Document> {
146+
// Find all fields in the document that match vector search indexed fields and need embeddings
147+
const fieldsNeedingEmbeddings: Array<{
148+
path: string;
149+
rawValue: string;
150+
indexDef: VectorFieldIndexDefinition;
151+
}> = [];
152+
153+
for (const indexDef of vectorIndexes) {
154+
// First, check if embeddingsInput provides text for this field
155+
if (embeddingsInput && indexDef.path in embeddingsInput) {
156+
const inputText = embeddingsInput[indexDef.path];
157+
if (typeof inputText === "string" && inputText.length > 0) {
158+
fieldsNeedingEmbeddings.push({
159+
path: indexDef.path,
160+
rawValue: inputText,
161+
indexDef,
162+
});
163+
continue;
164+
}
165+
}
166+
167+
// Otherwise, check if the field exists in the document and is a string (raw text)
168+
const fieldValue = this.getFieldValue(document, indexDef.path);
169+
if (typeof fieldValue === "string") {
170+
fieldsNeedingEmbeddings.push({
171+
path: indexDef.path,
172+
rawValue: fieldValue,
173+
indexDef,
174+
});
175+
}
176+
}
177+
178+
// If no fields need embeddings, return document as-is
179+
if (fieldsNeedingEmbeddings.length === 0) {
180+
return document;
181+
}
182+
183+
// Check if embeddingParameters is provided
184+
if (!embeddingParameters) {
185+
const fieldPaths = fieldsNeedingEmbeddings.map((f) => f.path).join(", ");
186+
throw new MongoDBError(
187+
ErrorCodes.AtlasVectorSearchInvalidQuery,
188+
`Fields [${fieldPaths}] have vector search indexes but ${embeddingsInput ? "embeddingsInput contains text for these fields" : "contain raw text strings"}. The embeddingParameters parameter is required to generate embeddings for these fields.`
189+
);
190+
}
191+
192+
// Generate embeddings for all fields
193+
const embeddingsMap = new Map<string, number[]>();
194+
195+
for (const field of fieldsNeedingEmbeddings) {
196+
const embeddings = await this.session.vectorSearchEmbeddingsManager.generateEmbeddings({
197+
database,
198+
collection,
199+
path: field.path,
200+
rawValues: [field.rawValue],
201+
embeddingParameters,
202+
inputType: "document",
203+
});
204+
205+
if (embeddings.length > 0 && Array.isArray(embeddings[0])) {
206+
embeddingsMap.set(field.path, embeddings[0] as number[]);
207+
}
208+
}
209+
210+
// Replace raw string values with generated embeddings
211+
const processedDoc = this.cloneDocument(document);
212+
213+
for (const field of fieldsNeedingEmbeddings) {
214+
const embedding = embeddingsMap.get(field.path);
215+
if (embedding) {
216+
this.setFieldValue(processedDoc, field.path, embedding);
217+
}
218+
}
219+
220+
return processedDoc;
221+
}
222+
223+
private getFieldValue(document: Document, path: string): unknown {
224+
const parts = path.split(".");
225+
let current: unknown = document;
226+
227+
for (const part of parts) {
228+
if (current && typeof current === "object" && part in current) {
229+
current = (current as Record<string, unknown>)[part];
230+
} else {
231+
return undefined;
232+
}
233+
}
234+
235+
return current;
236+
}
237+
238+
private setFieldValue(document: Document, path: string, value: unknown): void {
239+
const parts = path.split(".");
240+
let current: Record<string, unknown> = document;
241+
242+
for (let i = 0; i < parts.length - 1; i++) {
243+
const part = parts[i];
244+
if (!part) {
245+
continue;
246+
}
247+
if (!(part in current) || typeof current[part] !== "object") {
248+
current[part] = {};
249+
}
250+
current = current[part] as Record<string, unknown>;
251+
}
252+
253+
const lastPart = parts[parts.length - 1];
254+
if (lastPart) {
255+
current[lastPart] = value;
256+
}
257+
}
258+
259+
private cloneDocument(document: Document): Document {
260+
return JSON.parse(JSON.stringify(document)) as Document;
261+
}
66262
}

0 commit comments

Comments
 (0)