Skip to content

Commit fcfdd32

Browse files
committed
chore: update types, add tests for assertVectorSearchIndexExists
1 parent dc2f207 commit fcfdd32

File tree

4 files changed

+154
-10
lines changed

4 files changed

+154
-10
lines changed

src/common/search/embeddingsProvider.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ import { createFetch } from "@mongodb-js/devtools-proxy-support";
77
import { z } from "zod";
88

99
type EmbeddingsInput = string;
10-
type Embeddings = number[];
10+
type Embeddings = number[] | unknown[];
1111
export type EmbeddingParameters = {
1212
inputType: "query" | "document";
1313
};

src/common/search/vectorSearchEmbeddingsManager.ts

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -108,13 +108,14 @@ export class VectorSearchEmbeddingsManager {
108108
{ database, collection }: { database: string; collection: string },
109109
documents: Document[]
110110
): Promise<void> {
111-
const embeddingValidationResults = await Promise.all(
112-
documents.map((document) => this.findFieldsWithWrongEmbeddings({ database, collection }, document))
113-
);
114-
const embeddingValidations = new Set(embeddingValidationResults.flat());
115-
116-
if (embeddingValidations.size > 0) {
117-
const embeddingValidationMessages = Array.from(embeddingValidations).map(
111+
const embeddingValidationResults = (
112+
await Promise.all(
113+
documents.map((document) => this.findFieldsWithWrongEmbeddings({ database, collection }, document))
114+
)
115+
).flat();
116+
117+
if (embeddingValidationResults.length > 0) {
118+
const embeddingValidationMessages = embeddingValidationResults.map(
118119
(validation) =>
119120
`- Field ${validation.path} is an embedding with ${validation.expectedNumDimensions} dimensions and ${validation.expectedQuantization}` +
120121
` quantization, and the provided value is not compatible. Actual dimensions: ${validation.actualNumDimensions}, ` +
@@ -293,7 +294,7 @@ export class VectorSearchEmbeddingsManager {
293294
rawValues: string[];
294295
embeddingParameters: SupportedEmbeddingParameters;
295296
inputType: EmbeddingParameters["inputType"];
296-
}): Promise<number[][]> {
297+
}): Promise<unknown[][]> {
297298
const provider = await this.atlasSearchEnabledProvider();
298299
if (!provider) {
299300
throw new MongoDBError(

src/tools/mongodb/read/aggregate.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,7 @@ export class AggregateTool extends MongoDBToolBase {
298298
// $vectorSearch.queryVector can be a BSON.Binary: that it's not either number or an array.
299299
// It's not exactly valid from the LLM perspective (they can't provide binaries).
300300
// That's why we overwrite the stage in an untyped way, as what we expose and what LLMs can use is different.
301-
vectorSearchStage.queryVector = embeddings;
301+
vectorSearchStage.queryVector = embeddings as string | number[];
302302
}
303303
}
304304

tests/unit/common/search/vectorSearchEmbeddingsManager.test.ts

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -390,6 +390,149 @@ describe("VectorSearchEmbeddingsManager", () => {
390390
});
391391
});
392392

393+
describe("assertFieldsHaveCorrectEmbeddings", () => {
394+
it("does not throw for invalid documents when validation is disabled", async () => {
395+
const embeddings = new VectorSearchEmbeddingsManager(
396+
embeddingValidationDisabled,
397+
connectionManager,
398+
embeddingConfig
399+
);
400+
await expect(
401+
embeddings.assertFieldsHaveCorrectEmbeddings({ database, collection }, [
402+
{ embedding_field: "some text" },
403+
{ embedding_field: [1, 2, 3] },
404+
])
405+
).resolves.not.toThrow();
406+
});
407+
408+
describe("when validation is enabled", () => {
409+
let embeddings: VectorSearchEmbeddingsManager;
410+
411+
beforeEach(() => {
412+
embeddings = new VectorSearchEmbeddingsManager(
413+
embeddingValidationEnabled,
414+
connectionManager,
415+
embeddingConfig
416+
);
417+
});
418+
419+
it("does not throw when all documents are valid", async () => {
420+
await expect(
421+
embeddings.assertFieldsHaveCorrectEmbeddings({ database, collection }, [
422+
{ embedding_field: [1, 2, 3, 4, 5, 6, 7, 8] },
423+
{ embedding_field: [9, 10, 11, 12, 13, 14, 15, 16] },
424+
{ field: "no embeddings here" },
425+
])
426+
).resolves.not.toThrow();
427+
});
428+
429+
it("throws error when one document has wrong dimensions", async () => {
430+
await expect(
431+
embeddings.assertFieldsHaveCorrectEmbeddings({ database, collection }, [
432+
{ embedding_field: [1, 2, 3] },
433+
])
434+
).rejects.toThrow(/Field embedding_field is an embedding with 8 dimensions/);
435+
});
436+
437+
it("throws error when one document has wrong type", async () => {
438+
await expect(
439+
embeddings.assertFieldsHaveCorrectEmbeddings({ database, collection }, [
440+
{ embedding_field: "some text" },
441+
])
442+
).rejects.toThrow(/Field embedding_field is an embedding with 8 dimensions/);
443+
});
444+
445+
it("throws error when one document has non-numeric values", async () => {
446+
await expect(
447+
embeddings.assertFieldsHaveCorrectEmbeddings({ database, collection }, [
448+
{ embedding_field: ["1", "2", "3", "4", "5", "6", "7", "8"] },
449+
])
450+
).rejects.toThrow(/Field embedding_field is an embedding with 8 dimensions/);
451+
});
452+
453+
it("throws error with details about dimension mismatch", async () => {
454+
await expect(
455+
embeddings.assertFieldsHaveCorrectEmbeddings({ database, collection }, [
456+
{ embedding_field: [1, 2, 3] },
457+
])
458+
).rejects.toThrow(/Actual dimensions: 3/);
459+
});
460+
461+
it("throws error with details about quantization", async () => {
462+
await expect(
463+
embeddings.assertFieldsHaveCorrectEmbeddings({ database, collection }, [
464+
{ embedding_field: [1, 2, 3] },
465+
])
466+
).rejects.toThrow(/actual quantization: scalar/);
467+
});
468+
469+
it("throws error with details about error type", async () => {
470+
await expect(
471+
embeddings.assertFieldsHaveCorrectEmbeddings({ database, collection }, [
472+
{ embedding_field: [1, 2, 3] },
473+
])
474+
).rejects.toThrow(/Error: dimension-mismatch/);
475+
});
476+
477+
it("throws error when multiple documents have invalid embeddings", async () => {
478+
try {
479+
await embeddings.assertFieldsHaveCorrectEmbeddings({ database, collection }, [
480+
{ embedding_field: [1, 2, 3] },
481+
{ embedding_field: "some text" },
482+
]);
483+
expect.fail("Should have thrown an error");
484+
} catch (error) {
485+
expect((error as Error).message).toContain("Field embedding_field");
486+
expect((error as Error).message).toContain("dimension-mismatch");
487+
}
488+
});
489+
490+
it("handles documents with multiple invalid fields", async () => {
491+
try {
492+
await embeddings.assertFieldsHaveCorrectEmbeddings({ database, collection }, [
493+
{
494+
embedding_field: [1, 2, 3],
495+
embedding_field_binary: "not binary",
496+
},
497+
]);
498+
expect.fail("Should have thrown an error");
499+
} catch (error) {
500+
expect((error as Error).message).toContain("Field embedding_field");
501+
expect((error as Error).message).toContain("Field embedding_field_binary");
502+
}
503+
});
504+
505+
it("handles mix of valid and invalid documents", async () => {
506+
try {
507+
await embeddings.assertFieldsHaveCorrectEmbeddings({ database, collection }, [
508+
{ embedding_field: [1, 2, 3, 4, 5, 6, 7, 8] }, // valid
509+
{ embedding_field: [1, 2, 3] }, // invalid
510+
{ valid_field: "no embeddings" }, // valid (no embedding field)
511+
]);
512+
expect.fail("Should have thrown an error");
513+
} catch (error) {
514+
expect((error as Error).message).toContain("Field embedding_field");
515+
expect((error as Error).message).toContain("dimension-mismatch");
516+
expect((error as Error).message).not.toContain("Field valid_field");
517+
}
518+
});
519+
520+
it("handles nested fields with validation errors", async () => {
521+
await expect(
522+
embeddings.assertFieldsHaveCorrectEmbeddings({ database, collection }, [
523+
{ a: { nasty: { scalar: { field: [1, 2, 3] } } } },
524+
])
525+
).rejects.toThrow(/Field a\.nasty\.scalar\.field/);
526+
});
527+
528+
it("handles empty document array", async () => {
529+
await expect(
530+
embeddings.assertFieldsHaveCorrectEmbeddings({ database, collection }, [])
531+
).resolves.not.toThrow();
532+
});
533+
});
534+
});
535+
393536
describe("generate embeddings", () => {
394537
const embeddingToGenerate = {
395538
database: "mydb",

0 commit comments

Comments
 (0)