From 3429110db31947a1c8bb2fcd59fa958ef10e6a13 Mon Sep 17 00:00:00 2001 From: Jason Nathan Date: Sat, 10 Feb 2024 07:42:14 +0800 Subject: [PATCH] community[minor]: Added `SQLiteRecordManager` (#4321) * created sqlite record manager and integration tests * Updated tests and implementation * Updated tests and implementation Updated implementation to make it cleaner * some tiny refactors * cr * update peer deps, add to config file * format --------- Co-authored-by: bracesproul --- libs/langchain-community/.gitignore | 4 + libs/langchain-community/langchain.config.js | 2 + libs/langchain-community/package.json | 19 ++ .../langchain-community/src/indexes/sqlite.ts | 234 ++++++++++++++++++ .../src/indexes/tests/sqlite.int.test.ts | 177 +++++++++++++ yarn.lock | 25 ++ 6 files changed, 461 insertions(+) create mode 100644 libs/langchain-community/src/indexes/sqlite.ts create mode 100644 libs/langchain-community/src/indexes/tests/sqlite.int.test.ts diff --git a/libs/langchain-community/.gitignore b/libs/langchain-community/.gitignore index ced5a2948a41..416c6d5382c5 100644 --- a/libs/langchain-community/.gitignore +++ b/libs/langchain-community/.gitignore @@ -666,6 +666,10 @@ indexes/memory.cjs indexes/memory.js indexes/memory.d.ts indexes/memory.d.cts +indexes/sqlite.cjs +indexes/sqlite.js +indexes/sqlite.d.ts +indexes/sqlite.d.cts util/convex.cjs util/convex.js util/convex.d.ts diff --git a/libs/langchain-community/langchain.config.js b/libs/langchain-community/langchain.config.js index c351659761b0..e0d347a64606 100644 --- a/libs/langchain-community/langchain.config.js +++ b/libs/langchain-community/langchain.config.js @@ -209,6 +209,7 @@ export const config = { "indexes/base": "indexes/base", "indexes/postgres": "indexes/postgres", "indexes/memory": "indexes/memory", + "indexes/sqlite": "indexes/sqlite", // utils "util/convex": "utils/convex", "utils/event_source_parse": "utils/event_source_parse", @@ -334,6 +335,7 @@ export const config = { "util/convex", // indexes 
"indexes/postgres", + "indexes/sqlite", ], packageSuffix: "community", tsConfigPath: resolve("./tsconfig.json"), diff --git a/libs/langchain-community/package.json b/libs/langchain-community/package.json index 956301c1a366..920076c332f8 100644 --- a/libs/langchain-community/package.json +++ b/libs/langchain-community/package.json @@ -93,6 +93,7 @@ "@tensorflow/tfjs-converter": "^3.6.0", "@tensorflow/tfjs-core": "^3.6.0", "@tsconfig/recommended": "^1.0.2", + "@types/better-sqlite3": "^7.6.9", "@types/flat": "^5.0.2", "@types/html-to-text": "^9", "@types/jsdom": "^21.1.1", @@ -111,6 +112,7 @@ "@xata.io/client": "^0.28.0", "@xenova/transformers": "^2.5.4", "@zilliz/milvus2-sdk-node": ">=2.2.11", + "better-sqlite3": "^9.4.0", "cassandra-driver": "^4.7.2", "chromadb": "^1.5.3", "closevector-common": "0.1.3", @@ -210,6 +212,7 @@ "@xata.io/client": "^0.28.0", "@xenova/transformers": "^2.5.4", "@zilliz/milvus2-sdk-node": ">=2.2.7", + "better-sqlite3": "^9.4.0", "cassandra-driver": "^4.7.2", "chromadb": "*", "closevector-common": "0.1.3", @@ -381,6 +384,9 @@ "@zilliz/milvus2-sdk-node": { "optional": true }, + "better-sqlite3": { + "optional": true + }, "cassandra-driver": { "optional": true }, @@ -2000,6 +2006,15 @@ "import": "./indexes/memory.js", "require": "./indexes/memory.cjs" }, + "./indexes/sqlite": { + "types": { + "import": "./indexes/sqlite.d.ts", + "require": "./indexes/sqlite.d.cts", + "default": "./indexes/sqlite.d.ts" + }, + "import": "./indexes/sqlite.js", + "require": "./indexes/sqlite.cjs" + }, "./util/convex": { "types": { "import": "./util/convex.d.ts", @@ -2699,6 +2714,10 @@ "indexes/memory.js", "indexes/memory.d.ts", "indexes/memory.d.cts", + "indexes/sqlite.cjs", + "indexes/sqlite.js", + "indexes/sqlite.d.ts", + "indexes/sqlite.d.cts", "util/convex.cjs", "util/convex.js", "util/convex.d.ts", diff --git a/libs/langchain-community/src/indexes/sqlite.ts b/libs/langchain-community/src/indexes/sqlite.ts new file mode 100644 index 000000000000..78d7f8faf34b 
--- /dev/null
+++ b/libs/langchain-community/src/indexes/sqlite.ts
@@ -0,0 +1,293 @@
+// eslint-disable-next-line import/no-extraneous-dependencies
+import Database, { Database as DatabaseType, Statement } from "better-sqlite3";
+import {
+  ListKeyOptions,
+  RecordManagerInterface,
+  UpdateOptions,
+} from "./base.js";
+
+/** Row returned by the current-time query issued in `getTime`. */
+interface TimeRow {
+  epoch: number;
+}
+
+/** Row returned by the `SELECT key` queries in `exists` and `listKeys`. */
+interface KeyRecord {
+  key: string;
+}
+
+/**
+ * Options for configuring the SQLiteRecordManager class.
+ */
+export type SQLiteRecordManagerOptions = {
+  /**
+   * The file path of the SQLite database.
+   * One of either `localPath` or `connectionString` is required.
+   */
+  localPath?: string;
+  /**
+   * The connection string of the SQLite database.
+   * One of either `localPath` or `connectionString` is required.
+   */
+  connectionString?: string;
+  /**
+   * The name of the table in the SQLite database.
+   */
+  tableName: string;
+};
+
+/**
+ * Record manager backed by a SQLite table, accessed through `better-sqlite3`.
+ *
+ * Each row holds a key, the namespace it belongs to, the time it was last
+ * upserted and an optional group id. Every query is scoped to the namespace
+ * passed to the constructor. Call `createSchema` before the other methods
+ * unless the table already exists.
+ */
+export class SQLiteRecordManager implements RecordManagerInterface {
+  lc_namespace = ["langchain", "recordmanagers", "sqlite"];
+
+  tableName: string;
+
+  db: DatabaseType;
+
+  namespace: string;
+
+  /**
+   * Opens the database and stores the table name and namespace to use.
+   *
+   * @param namespace Value stored in, and filtered on, the `namespace` column.
+   * @param config Database location and table name.
+   * @throws If neither or both of `localPath` and `connectionString` are set.
+   */
+  constructor(namespace: string, config: SQLiteRecordManagerOptions) {
+    const { localPath, connectionString, tableName } = config;
+    if (!connectionString && !localPath) {
+      throw new Error(
+        "One of either `localPath` or `connectionString` is required."
+      );
+    }
+    if (connectionString && localPath) {
+      throw new Error(
+        "Only one of either `localPath` or `connectionString` is allowed."
+      );
+    }
+    this.namespace = namespace;
+    this.tableName = tableName;
+    this.db = new Database(connectionString ?? localPath);
+  }
+
+  /**
+   * Creates the record table and its indexes if they do not exist yet.
+   *
+   * @throws Rethrows any database error after logging it.
+   */
+  async createSchema(): Promise<void> {
+    try {
+      // NOTE(review): the index names are not derived from the table name, so
+      // a second record table in the same database file would hit
+      // `IF NOT EXISTS` on them and get no indexes of its own - confirm whether
+      // several tables per database need to be supported.
+      this.db.exec(`
+CREATE TABLE IF NOT EXISTS "${this.tableName}" (
+  uuid TEXT PRIMARY KEY DEFAULT (lower(hex(randomblob(16)))),
+  key TEXT NOT NULL,
+  namespace TEXT NOT NULL,
+  updated_at REAL NOT NULL,
+  group_id TEXT,
+  UNIQUE (key, namespace)
+);
+CREATE INDEX IF NOT EXISTS updated_at_index ON "${this.tableName}" (updated_at);
+CREATE INDEX IF NOT EXISTS key_index ON "${this.tableName}" (key);
+CREATE INDEX IF NOT EXISTS namespace_index ON "${this.tableName}" (namespace);
+CREATE INDEX IF NOT EXISTS group_id_index ON "${this.tableName}" (group_id);`);
+    } catch (error) {
+      console.error("Error creating schema");
+      throw error; // Re-throw the error to let the caller handle it
+    }
+  }
+
+  /**
+   * Reads the current time from the database clock.
+   *
+   * @returns Seconds since the Unix epoch, including the fractional part.
+   * @throws Rethrows any database error after logging it.
+   */
+  async getTime(): Promise<number> {
+    try {
+      // `strftime('%s', 'now')` only has whole-second resolution, so records
+      // written within the same second could not be told apart by the
+      // `before` / `after` filters. The julian day form keeps the fraction.
+      const statement: Statement<[]> = this.db.prepare(
+        "SELECT (julianday('now') - 2440587.5) * 86400.0 AS epoch"
+      );
+      const { epoch } = statement.get() as TimeRow;
+      return Number(epoch);
+    } catch (error) {
+      console.error("Error getting time in SQLiteRecordManager:");
+      throw error;
+    }
+  }
+
+  /**
+   * Upserts the given keys, refreshing the timestamp of keys that exist.
+   *
+   * @param keys Keys to upsert in this manager's namespace.
+   * @param updateOptions Optional `timeAtLeast` lower bound for the database
+   * time and `groupIds`, one per key in the same order.
+   * @throws If the database time is earlier than `timeAtLeast`, or if
+   * `groupIds` and `keys` differ in length.
+   */
+  async update(keys: string[], updateOptions?: UpdateOptions): Promise<void> {
+    if (keys.length === 0) {
+      return;
+    }
+
+    const updatedAt = await this.getTime();
+    const { timeAtLeast, groupIds: _groupIds } = updateOptions ?? {};
+
+    if (timeAtLeast && updatedAt < timeAtLeast) {
+      throw new Error(
+        `Time sync issue with database ${updatedAt} < ${timeAtLeast}`
+      );
+    }
+
+    const groupIds = _groupIds ?? keys.map(() => null);
+
+    if (groupIds.length !== keys.length) {
+      throw new Error(
+        `Number of keys (${keys.length}) does not match number of group_ids (${groupIds.length})`
+      );
+    }
+
+    const recordsToUpsert = keys.map((key, i) => [
+      key,
+      this.namespace,
+      updatedAt,
+      groupIds[i] ?? null, // Ensure groupIds[i] is null if undefined
+    ]);
+
+    // Prepare the statement once and reuse it for every row. On conflict only
+    // `updated_at` is refreshed; the `group_id` already stored is kept.
+    const upsert = this.db.prepare(`
+INSERT INTO "${this.tableName}" (key, namespace, updated_at, group_id)
+VALUES (?, ?, ?, ?)
+ON CONFLICT (key, namespace) DO UPDATE SET updated_at = excluded.updated_at`);
+
+    // A single transaction makes the batch all-or-nothing.
+    const updateTransaction = this.db.transaction(() => {
+      for (const row of recordsToUpsert) {
+        upsert.run(...row);
+      }
+    });
+    updateTransaction();
+  }
+
+  /**
+   * Checks which of the given keys exist in this manager's namespace.
+   *
+   * @param keys Keys to look up.
+   * @returns One boolean per input key, in the same order.
+   * @throws Rethrows any database error after logging it.
+   */
+  async exists(keys: string[]): Promise<boolean[]> {
+    if (keys.length === 0) {
+      return [];
+    }
+
+    const placeholders = keys.map(() => `?`).join(", ");
+    const sql = `
+SELECT key
+FROM "${this.tableName}"
+WHERE namespace = ? AND key IN (${placeholders})`;
+
+    try {
+      const rows = this.db
+        .prepare(sql)
+        .all(this.namespace, ...keys) as KeyRecord[];
+      // A set gives a cheap membership test for every requested key.
+      const existingKeys = new Set(rows.map((row) => row.key));
+      return keys.map((key) => existingKeys.has(key));
+    } catch (error) {
+      console.error("Error checking existence of keys");
+      throw error; // Allow the caller to handle the error
+    }
+  }
+
+  /**
+   * Lists the keys stored in this manager's namespace, optionally filtered.
+   *
+   * @param options Optional filters: `before` / `after` are exclusive bounds
+   * on `updated_at`, `groupIds` restricts the result to those group ids
+   * (`null` entries are ignored) and `limit` caps the number of rows.
+   * @returns The matching keys.
+   * @throws Rethrows any database error after logging it.
+   */
+  async listKeys(options?: ListKeyOptions): Promise<string[]> {
+    const { before, after, limit, groupIds } = options ?? {};
+    let query = `SELECT key FROM "${this.tableName}" WHERE namespace = ?`;
+    const values: (string | number)[] = [this.namespace];
+
+    if (before) {
+      query += ` AND updated_at < ?`;
+      values.push(before);
+    }
+
+    if (after) {
+      query += ` AND updated_at > ?`;
+      values.push(after);
+    }
+
+    if (groupIds && Array.isArray(groupIds)) {
+      const nonNullGroupIds = groupIds.filter(
+        (gid): gid is string => gid !== null
+      );
+      const placeholders = nonNullGroupIds.map(() => "?").join(", ");
+      query += ` AND group_id IN (${placeholders})`;
+      values.push(...nonNullGroupIds);
+    }
+
+    // LIMIT has to come after the complete WHERE clause, so it is appended
+    // only once every filter condition has been added.
+    if (limit) {
+      query += ` LIMIT ?`;
+      values.push(limit);
+    }
+
+    query += ";";
+
+    try {
+      const result = this.db.prepare(query).all(...values) as { key: string }[];
+      return result.map((row) => row.key);
+    } catch (error) {
+      console.error("Error listing keys.");
+      throw error; // Re-throw the error to be handled by the caller
+    }
+  }
+
+  /**
+   * Deletes the given keys from this manager's namespace.
+   *
+   * @param keys Keys to delete; an empty list is a no-op.
+   * @throws Rethrows any database error after logging it.
+   */
+  async deleteKeys(keys: string[]): Promise<void> {
+    if (keys.length === 0) {
+      return;
+    }
+
+    const placeholders = keys.map(() => "?").join(", ");
+    const query = `DELETE FROM "${this.tableName}" WHERE namespace = ? AND key IN (${placeholders});`;
+    // Bind every value as a string, converting any that is not one already.
+    const values = [this.namespace, ...keys].map((v) =>
+      typeof v !== "string" ? `${v}` : v
+    );
+
+    try {
+      this.db.prepare(query).run(...values);
+    } catch (error) {
+      console.error("Error deleting keys");
+      throw error; // Re-throw the error to be handled by the caller
+    }
+  }
+}
diff --git a/libs/langchain-community/src/indexes/tests/sqlite.int.test.ts b/libs/langchain-community/src/indexes/tests/sqlite.int.test.ts
new file mode 100644
index 000000000000..251dc08cbcd1
--- /dev/null
+++ b/libs/langchain-community/src/indexes/tests/sqlite.int.test.ts
@@ -0,0 +1,177 @@
+import { describe, expect, test, jest } from "@jest/globals";
+import { SQLiteRecordManager } from "../sqlite.js";
+
+describe("SQLiteRecordManager", () => {
+  const tableName = "upsertion_record";
+  let recordManager: SQLiteRecordManager;
+
+  beforeAll(async () => {
+    const localPath = ":memory:";
+    recordManager = new SQLiteRecordManager("test", {
+      tableName,
+      localPath,
+    });
+    await recordManager.createSchema();
+  });
+
+  afterEach(async () => {
+    recordManager.db.exec(`DELETE FROM "${tableName}"`);
+    await recordManager.createSchema();
+  });
+
+  afterAll(() => {
+    recordManager.db.close();
+  });
+
+  test("Test upsertion", async () => {
+    const keys = ["a", "b", "c"];
+    await recordManager.update(keys);
+    const readKeys = await recordManager.listKeys();
+    expect(readKeys).toEqual(expect.arrayContaining(keys));
+    expect(readKeys).toHaveLength(keys.length);
+  });
+
+  test("Test upsertion with timeAtLeast", async () => {
+    // Mock getTime to return 100.
+    const unmockedGetTime = recordManager.getTime;
+    recordManager.getTime = jest.fn(() => Promise.resolve(100));
+
+    const keys = ["a", "b", "c"];
+    await expect(
+      recordManager.update(keys, { timeAtLeast: 110 })
+    ).rejects.toThrowError();
+    const readKeys = await recordManager.listKeys();
+    expect(readKeys).toHaveLength(0);
+
+    // Set getTime back to normal.
+    recordManager.getTime = unmockedGetTime;
+  });
+
+  interface RecordRow {
+    // Mirrors the columns created by SQLiteRecordManager.createSchema().
+    uuid: string;
+    key: string;
+    namespace: string;
+    updated_at: number;
+    group_id: string | null;
+  }
+
+  test("Test update timestamp", async () => {
+    const unmockedGetTime = recordManager.getTime;
+    recordManager.getTime = jest.fn(() => Promise.resolve(100));
+    try {
+      const keys = ["a", "b", "c"];
+      await recordManager.update(keys);
+      const rows = recordManager.db
+        .prepare(`SELECT * FROM "${tableName}"`)
+        .all() as RecordRow[];
+      rows.forEach((row) => expect(row.updated_at).toEqual(100));
+
+      recordManager.getTime = jest.fn(() => Promise.resolve(200));
+      await recordManager.update(keys); // second upsert of the same keys
+      const rows2 = recordManager.db
+        .prepare(`SELECT * FROM "${tableName}"`)
+        .all() as RecordRow[];
+      rows2.forEach((row) => expect(row.updated_at).toEqual(200));
+    } finally {
+      recordManager.getTime = unmockedGetTime;
+    }
+  });
+
+  test("Test update with groupIds", async () => {
+    const keys = ["a", "b", "c"];
+    await recordManager.update(keys, {
+      groupIds: ["group1", "group1", "group2"],
+    });
+    const rows = recordManager.db
+      .prepare(`SELECT * FROM "${tableName}" WHERE group_id = ?`)
+      .all("group1") as RecordRow[];
+    expect(rows.length).toEqual(2);
+    rows.forEach((row) => expect(row.group_id).toEqual("group1"));
+  });
+
+  test("Exists", async () => {
+    const keys = ["a", "b", "c"];
+    await recordManager.update(keys);
+    const exists = await recordManager.exists(keys);
+    expect(exists).toEqual([true, true, true]);
+
+    const nonExistentKeys = ["d", "e", "f"];
+    const nonExists = await recordManager.exists(nonExistentKeys);
+    expect(nonExists).toEqual([false, false, false]);
+
+    const mixedKeys = ["a", "e", "c"];
+    const mixedExists = await recordManager.exists(mixedKeys);
+    expect(mixedExists).toEqual([true, false, true]);
+  });
+
+  test("Delete", async () => {
+    const keys = ["a", "b", "c"];
+    await recordManager.update(keys);
+    await recordManager.deleteKeys(["a", "c"]); // only "b" should remain
+    const readKeys = await recordManager.listKeys();
+    expect(readKeys).toEqual(["b"]);
+  });
+
+  test("List keys", async () => {
+    const unmockedGetTime = recordManager.getTime;
+    recordManager.getTime = jest.fn(() => Promise.resolve(100));
+    try {
+      const keys = ["a", "b", "c"];
+      await recordManager.update(keys);
+      const readKeys = await recordManager.listKeys();
+      expect(readKeys).toEqual(expect.arrayContaining(keys));
+      expect(readKeys).toHaveLength(keys.length);
+
+      // All keys inserted after 90: should be all keys
+      const readKeysAfterInsertedAfter = await recordManager.listKeys({
+        after: 90,
+      });
+      expect(readKeysAfterInsertedAfter).toEqual(expect.arrayContaining(keys));
+
+      // All keys inserted after 110: should be none
+      const readKeysAfterInsertedBefore = await recordManager.listKeys({
+        after: 110,
+      });
+      expect(readKeysAfterInsertedBefore).toEqual([]);
+
+      // All keys inserted before 110: should be all keys
+      const readKeysBeforeInsertedBefore = await recordManager.listKeys({
+        before: 110,
+      });
+      expect(readKeysBeforeInsertedBefore).toEqual(
+        expect.arrayContaining(keys)
+      );
+
+      // All keys inserted before 90: should be none
+      const readKeysBeforeInsertedAfter = await recordManager.listKeys({
+        before: 90,
+      });
+      expect(readKeysBeforeInsertedAfter).toEqual([]);
+
+      // Set one key to updated at 120 and one at 80
+      recordManager.getTime = jest.fn(() => Promise.resolve(120));
+      await recordManager.update(["a"]);
+      recordManager.getTime = jest.fn(() => Promise.resolve(80));
+      await recordManager.update(["b"]);
+
+      // All keys updated after 90 and before 110: should only be "c" now
+      const readKeysBeforeAndAfter = await recordManager.listKeys({
+        before: 110,
+        after: 90,
+      });
+      expect(readKeysBeforeAndAfter).toEqual(["c"]);
+    } finally {
+      recordManager.getTime = unmockedGetTime;
+    }
+  });
+
+  test("List keys with groupIds", async () => {
+    const keys = ["a", "b",
"c"]; + await recordManager.update(keys, { + groupIds: ["group1", "group1", "group2"], + }); + const readKeys = await recordManager.listKeys({ groupIds: ["group1"] }); + expect(readKeys).toEqual(["a", "b"]); + }); +}); diff --git a/yarn.lock b/yarn.lock index 79c51a3d8d62..cb5d2110618a 100644 --- a/yarn.lock +++ b/yarn.lock @@ -8911,6 +8911,7 @@ __metadata: "@tensorflow/tfjs-converter": ^3.6.0 "@tensorflow/tfjs-core": ^3.6.0 "@tsconfig/recommended": ^1.0.2 + "@types/better-sqlite3": ^7.6.9 "@types/flat": ^5.0.2 "@types/html-to-text": ^9 "@types/jsdom": ^21.1.1 @@ -8929,6 +8930,7 @@ __metadata: "@xata.io/client": ^0.28.0 "@xenova/transformers": ^2.5.4 "@zilliz/milvus2-sdk-node": ">=2.2.11" + better-sqlite3: ^9.4.0 cassandra-driver: ^4.7.2 chromadb: ^1.5.3 closevector-common: 0.1.3 @@ -9031,6 +9033,7 @@ __metadata: "@xata.io/client": ^0.28.0 "@xenova/transformers": ^2.5.4 "@zilliz/milvus2-sdk-node": ">=2.2.7" + better-sqlite3: ^9.4.0 cassandra-driver: ^4.7.2 chromadb: "*" closevector-common: 0.1.3 @@ -9157,6 +9160,8 @@ __metadata: optional: true "@zilliz/milvus2-sdk-node": optional: true + better-sqlite3: + optional: true cassandra-driver: optional: true chromadb: @@ -13664,6 +13669,15 @@ __metadata: languageName: node linkType: hard +"@types/better-sqlite3@npm:^7.6.9": + version: 7.6.9 + resolution: "@types/better-sqlite3@npm:7.6.9" + dependencies: + "@types/node": "*" + checksum: 6572076639dde1e65ad8fe0e319e797fa40793ca805ec39aa1d072d3f145f218775eafd27d7266fb4e42a6291d8b8b836278e7880b15ef728c750dfcbba2ee52 + languageName: node + linkType: hard + "@types/body-parser@npm:*": version: 1.19.2 resolution: "@types/body-parser@npm:1.19.2" @@ -16566,6 +16580,17 @@ __metadata: languageName: node linkType: hard +"better-sqlite3@npm:^9.4.0": + version: 9.4.0 + resolution: "better-sqlite3@npm:9.4.0" + dependencies: + bindings: ^1.5.0 + node-gyp: latest + prebuild-install: ^7.1.1 + checksum: 
a1a470fae20dfba82d6e74ae90b35ea8996c60922e95574162732d6e076e84c0c90fc4ff77ab8c27554671899eb15f284e2c8de5e4ee406aa9f7eb170eca5bee + languageName: node + linkType: hard + "big-integer@npm:^1.6.44": version: 1.6.51 resolution: "big-integer@npm:1.6.51"