Skip to content

Commit

Permalink
revert changes
Browse files Browse the repository at this point in the history
  • Loading branch information
spikechroma committed Sep 3, 2024
1 parent 214df9b commit 7d3b5aa
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 148 deletions.
15 changes: 7 additions & 8 deletions clients/js/src/ChromaClient.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import { AdminClient } from "./AdminClient";
import { authOptionsToAuthProvider, ClientAuthProvider } from "./auth";
import { chromaFetch } from "./ChromaFetch";
import { DefaultEmbeddingFunction } from "./embeddings/DefaultEmbeddingFunction";
import { ChromaConnectionError, ChromaServerError } from "./Errors";
import {
Configuration,
ApiApi as DefaultApi,
Expand All @@ -17,6 +18,7 @@ import type {
CreateCollectionParams,
DeleteCollectionParams,
DeleteParams,
Embedding,
Embeddings,
GetCollectionParams,
GetOrCreateCollectionParams,
Expand All @@ -31,7 +33,6 @@ import type {
} from "./types";
import {
prepareRecordRequest,
prepareRecordRequestWithIDsOptional,
toArray,
toArrayOfArrays,
validateTenantDatabase,
Expand Down Expand Up @@ -396,7 +397,7 @@ export class ChromaClient {
/**
* Add items to the collection
* @param {Object} params - The parameters for the query.
* @param {ID | IDs} [params.ids] - Optional IDs of the items to add.
* @param {ID | IDs} [params.ids] - IDs of the items to add.
* @param {Embedding | Embeddings} [params.embeddings] - Optional embeddings of the items to add.
* @param {Metadata | Metadatas} [params.metadatas] - Optional metadata of the items to add.
* @param {Document | Documents} [params.documents] - Optional documents of the items to add.
Expand All @@ -415,20 +416,18 @@ export class ChromaClient {
async addRecords(
collection: Collection,
params: AddRecordsParams,
): Promise<AddResponse> {
): Promise<void> {
await this.init();

const resp = (await this.api.add(
await this.api.add(
collection.id,
// TODO: For some reason the auto generated code requires metadata to be defined here.
(await prepareRecordRequestWithIDsOptional(
(await prepareRecordRequest(
params,
collection.embeddingFunction,
)) as GeneratedApi.AddEmbedding,
this.api.options,
)) as AddResponse;

return resp;
);
}

/**
Expand Down
60 changes: 2 additions & 58 deletions clients/js/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,7 @@ export type MultiQueryResponse = {

export type QueryResponse = SingleQueryResponse | MultiQueryResponse;

export type AddResponse = {
ids: IDs;
};
export type AddResponse = {};

export interface Collection {
name: string;
Expand Down Expand Up @@ -166,28 +164,13 @@ export type BaseRecordOperationParams = {
documents?: Document | Documents;
};

export type BaseRecordOperationParamsWithIDsOptional = {
ids?: ID | IDs;
embeddings?: Embedding | Embeddings;
metadatas?: Metadata | Metadatas;
documents?: Document | Documents;
};

export type SingleRecordOperationParams = BaseRecordOperationParams & {
ids: ID;
embeddings?: Embedding;
metadatas?: Metadata;
documents?: Document;
};

export type SingleRecordOperationParamsWithIDsOptional =
BaseRecordOperationParamsWithIDsOptional & {
ids?: ID;
embeddings?: Embedding;
metadatas?: Metadata;
documents?: Document;
};

type SingleEmbeddingRecordOperationParams = SingleRecordOperationParams & {
embeddings: Embedding;
};
Expand All @@ -200,31 +183,13 @@ export type SingleAddRecordOperationParams =
| SingleEmbeddingRecordOperationParams
| SingleContentRecordOperationParams;

type SingleEmbeddingRecordOperationParamsWithOptionalIDs =
BaseRecordOperationParamsWithIDsOptional & {
embeddings: Embedding;
};

type SingleContentRecordOperationParamsWithOptionalIDs =
BaseRecordOperationParamsWithIDsOptional & {
documents: Document;
};

export type MultiRecordOperationParams = BaseRecordOperationParams & {
ids: IDs;
embeddings?: Embeddings;
metadatas?: Metadatas;
documents?: Documents;
};

export type MultiRecordOperationParamsWithIDsOptional =
BaseRecordOperationParamsWithIDsOptional & {
ids?: IDs;
embeddings?: Embeddings;
metadatas?: Metadatas;
documents?: Documents;
};

type MultiEmbeddingRecordOperationParams = MultiRecordOperationParams & {
embeddings: Embeddings;
};
Expand All @@ -233,36 +198,15 @@ type MultiContentRecordOperationParams = MultiRecordOperationParams & {
documents: Documents;
};

type MultiEmbeddingRecordOperationParamsWithOptionalIDs =
MultiRecordOperationParamsWithIDsOptional & {
embeddings: Embeddings;
};

type MultiContentRecordOperationParamsWithOptionalIDs =
MultiRecordOperationParamsWithIDsOptional & {
documents: Documents;
};

export type SingleAddRecordOperationParamsWithOptionalIDs =
| SingleEmbeddingRecordOperationParamsWithOptionalIDs
| SingleContentRecordOperationParamsWithOptionalIDs;

export type MultiAddRecordsOperationParamsWithOptionalIDs =
| MultiEmbeddingRecordOperationParamsWithOptionalIDs
| MultiContentRecordOperationParamsWithOptionalIDs;

export type MultiAddRecordsOperationParams =
| MultiEmbeddingRecordOperationParams
| MultiContentRecordOperationParams;

export type AddRecordsParams =
| SingleAddRecordOperationParamsWithOptionalIDs
| MultiAddRecordsOperationParamsWithOptionalIDs;

export type UpsertRecordsParams =
| SingleAddRecordOperationParams
| MultiAddRecordsOperationParams;

export type UpsertRecordsParams = AddRecordsParams;
export type UpdateRecordsParams =
| MultiRecordOperationParams
| SingleRecordOperationParams;
Expand Down
93 changes: 23 additions & 70 deletions clients/js/src/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,11 @@ import { ChromaConnectionError } from "./Errors";
import { IEmbeddingFunction } from "./embeddings/IEmbeddingFunction";
import {
AddRecordsParams,
BaseRecordOperationParamsWithIDsOptional,
BaseRecordOperationParams,
Collection,
Embeddings,
Documents,
Metadata,
MultiRecordOperationParams,
MultiRecordOperationParamsWithIDsOptional,
UpdateRecordsParams,
UpsertRecordsParams,
} from "./types";

// a function to convert a non-Array object to an Array
Expand Down Expand Up @@ -86,10 +82,10 @@ export function isBrowser() {
}

function arrayifyParams(
params: BaseRecordOperationParamsWithIDsOptional,
): MultiRecordOperationParamsWithIDsOptional {
params: BaseRecordOperationParams,
): MultiRecordOperationParams {
return {
ids: params.ids !== undefined ? toArray(params.ids) : undefined,
ids: toArray(params.ids),
embeddings: params.embeddings
? toArrayOfArrays(params.embeddings)
: undefined,
Expand All @@ -101,72 +97,16 @@ function arrayifyParams(
}

export async function prepareRecordRequest(
reqParams: UpsertRecordsParams | UpdateRecordsParams,
reqParams: AddRecordsParams | UpdateRecordsParams,
embeddingFunction: IEmbeddingFunction,
update?: true,
): Promise<MultiRecordOperationParams> {
const {
ids = [],
embeddings,
metadatas,
documents,
} = arrayifyParams(reqParams);

if (!embeddings && !documents && !update) {
throw new Error("embeddings and documents cannot both be undefined");
}

validateIDs(ids);

const embeddingsArray = await computeEmbeddings(
embeddingFunction,
embeddings,
documents,
update,
);

return {
ids,
metadatas,
documents,
embeddings: embeddingsArray,
};
}

export async function prepareRecordRequestWithIDsOptional(
reqParams: AddRecordsParams,
embeddingFunction: IEmbeddingFunction,
): Promise<MultiRecordOperationParamsWithIDsOptional> {
const { ids, embeddings, metadatas, documents } = arrayifyParams(reqParams);

if (!embeddings && !documents) {
if (!embeddings && !documents && !update) {
throw new Error("embeddings and documents cannot both be undefined");
}

if (ids) {
validateIDs(ids);
}

const embeddingsArray = await computeEmbeddings(
embeddingFunction,
embeddings,
documents,
);

return {
ids,
metadatas,
documents,
embeddings: embeddingsArray,
};
}

async function computeEmbeddings(
embeddingFunction: IEmbeddingFunction,
embeddings?: Embeddings,
documents?: Documents,
update?: true,
): Promise<Embeddings | undefined> {
const embeddingsArray = embeddings
? embeddings
: documents
Expand All @@ -177,10 +117,6 @@ async function computeEmbeddings(
throw new Error("Failed to generate embeddings for your request.");
}

return embeddingsArray;
}

function validateIDs(ids: string[]) {
for (let i = 0; i < ids.length; i += 1) {
if (typeof ids[i] !== "string") {
throw new Error(
Expand All @@ -189,6 +125,16 @@ function validateIDs(ids: string[]) {
}
}

if (
(embeddingsArray !== undefined && ids.length !== embeddingsArray.length) ||
(metadatas !== undefined && ids.length !== metadatas.length) ||
(documents !== undefined && ids.length !== documents.length)
) {
throw new Error(
"ids, embeddings, metadatas, and documents must all be the same length",
);
}

const uniqueIds = new Set(ids);
if (uniqueIds.size !== ids.length) {
const duplicateIds = ids.filter(
Expand All @@ -198,6 +144,13 @@ function validateIDs(ids: string[]) {
`ID's must be unique, found duplicates for: ${duplicateIds}`,
);
}

return {
ids,
metadatas,
documents,
embeddings: embeddingsArray,
};
}

function notifyUserOfLegacyMethod(newMethod: string) {
Expand Down
20 changes: 8 additions & 12 deletions clients/js/test/add.collections.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -130,27 +130,23 @@ describe("add collections", () => {
const ids = IDS.concat(["test1"]);
const embeddings = EMBEDDINGS.concat([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]);
const metadatas = METADATAS.concat([{ test: "test1", float_value: 0.1 }]);
expect(async () => {
try {
await client.addRecords(collection, { ids, embeddings, metadatas });
}).rejects.toThrow("found duplicates");
});

test("It should generate IDs if not provided", async () => {
const collection = await client.createCollection({ name: "test" });
const embeddings = EMBEDDINGS.concat([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]);
const metadatas = METADATAS.concat([{ test: "test1", float_value: 0.1 }]);
const resp = await client.addRecords(collection, { embeddings, metadatas });
expect(resp.ids.length).toEqual(4);
} catch (e: any) {
expect(e.message).toMatch("duplicates");
}
});

test("should error on empty embedding", async () => {
const collection = await client.createCollection({ name: "test" });
const ids = ["id1"];
const embeddings = [[]];
const metadatas = [{ test: "test1", float_value: 0.1 }];
expect(async () => {
try {
await client.addRecords(collection, { ids, embeddings, metadatas });
}).rejects.toThrow("got empty embedding at pos");
} catch (e: any) {
expect(e.message).toMatch("got empty embedding at pos");
}
});

if (!process.env.OLLAMA_SERVER_URL) {
Expand Down

0 comments on commit 7d3b5aa

Please sign in to comment.