Skip to content

Commit

Permalink
Bugfix/SQLite agent memory node (#3650)
Browse files Browse the repository at this point in the history
* add dedicated agent memory nodes

* sqlite agent memory fix

* Update pnpm-lock.yaml
  • Loading branch information
HenryHengZJ authored Dec 6, 2024
1 parent cadc3b8 commit 7d1234a
Showing 1 changed file with 75 additions and 43 deletions.
Original file line number Diff line number Diff line change
@@ -1,49 +1,47 @@
import { BaseCheckpointSaver, Checkpoint, CheckpointMetadata } from '@langchain/langgraph'
import { RunnableConfig } from '@langchain/core/runnables'
import { BaseMessage } from '@langchain/core/messages'
import { DataSource, QueryRunner } from 'typeorm'
import { DataSource } from 'typeorm'
import { CheckpointTuple, SaverOptions, SerializerProtocol } from '../interface'
import { IMessage, MemoryMethods } from '../../../../src/Interface'
import { mapChatMessageToBaseMessage } from '../../../../src/utils'

export class SqliteSaver extends BaseCheckpointSaver implements MemoryMethods {
protected isSetup: boolean

datasource: DataSource

queryRunner: QueryRunner

config: SaverOptions

threadId: string

tableName = 'checkpoints'

constructor(config: SaverOptions, serde?: SerializerProtocol<Checkpoint>) {
super(serde)
this.config = config
const { datasourceOptions, threadId } = config
const { threadId } = config
this.threadId = threadId
this.datasource = new DataSource(datasourceOptions)
}

private async setup(): Promise<void> {
private async getDataSource(): Promise<DataSource> {
const { datasourceOptions } = this.config
const dataSource = new DataSource(datasourceOptions)
await dataSource.initialize()
return dataSource
}

private async setup(dataSource: DataSource): Promise<void> {
if (this.isSetup) {
return
}

try {
const appDataSource = await this.datasource.initialize()

this.queryRunner = appDataSource.createQueryRunner()
await this.queryRunner.manager.query(`
const queryRunner = dataSource.createQueryRunner()
await queryRunner.manager.query(`
CREATE TABLE IF NOT EXISTS ${this.tableName} (
thread_id TEXT NOT NULL,
checkpoint_id TEXT NOT NULL,
parent_id TEXT,
checkpoint BLOB,
metadata BLOB,
PRIMARY KEY (thread_id, checkpoint_id));`)
await queryRunner.release()
} catch (error) {
console.error(`Error creating ${this.tableName} table`, error)
throw new Error(`Error creating ${this.tableName} table`)
Expand All @@ -53,16 +51,20 @@ CREATE TABLE IF NOT EXISTS ${this.tableName} (
}

async getTuple(config: RunnableConfig): Promise<CheckpointTuple | undefined> {
await this.setup()
const dataSource = await this.getDataSource()
await this.setup(dataSource)

const thread_id = config.configurable?.thread_id || this.threadId
const checkpoint_id = config.configurable?.checkpoint_id

if (checkpoint_id) {
try {
const queryRunner = dataSource.createQueryRunner()
const keys = [thread_id, checkpoint_id]
const sql = `SELECT checkpoint, parent_id, metadata FROM ${this.tableName} WHERE thread_id = ? AND checkpoint_id = ?`

const rows = await this.queryRunner.manager.query(sql, [...keys])
const rows = await queryRunner.manager.query(sql, [...keys])
await queryRunner.release()

if (rows && rows.length > 0) {
return {
Expand All @@ -82,39 +84,53 @@ CREATE TABLE IF NOT EXISTS ${this.tableName} (
} catch (error) {
console.error(`Error retrieving ${this.tableName}`, error)
throw new Error(`Error retrieving ${this.tableName}`)
} finally {
await dataSource.destroy()
}
} else {
const keys = [thread_id]
const sql = `SELECT thread_id, checkpoint_id, parent_id, checkpoint, metadata FROM ${this.tableName} WHERE thread_id = ? ORDER BY checkpoint_id DESC LIMIT 1`
try {
const queryRunner = dataSource.createQueryRunner()
const keys = [thread_id]
const sql = `SELECT thread_id, checkpoint_id, parent_id, checkpoint, metadata FROM ${this.tableName} WHERE thread_id = ? ORDER BY checkpoint_id DESC LIMIT 1`

const rows = await this.queryRunner.manager.query(sql, [...keys])
const rows = await queryRunner.manager.query(sql, [...keys])
await queryRunner.release()

if (rows && rows.length > 0) {
return {
config: {
configurable: {
thread_id: rows[0].thread_id,
checkpoint_id: rows[0].checkpoint_id
}
},
checkpoint: (await this.serde.parse(rows[0].checkpoint)) as Checkpoint,
metadata: (await this.serde.parse(rows[0].metadata)) as CheckpointMetadata,
parentConfig: rows[0].parent_id
? {
configurable: {
thread_id: rows[0].thread_id,
checkpoint_id: rows[0].parent_id
if (rows && rows.length > 0) {
return {
config: {
configurable: {
thread_id: rows[0].thread_id,
checkpoint_id: rows[0].checkpoint_id
}
},
checkpoint: (await this.serde.parse(rows[0].checkpoint)) as Checkpoint,
metadata: (await this.serde.parse(rows[0].metadata)) as CheckpointMetadata,
parentConfig: rows[0].parent_id
? {
configurable: {
thread_id: rows[0].thread_id,
checkpoint_id: rows[0].parent_id
}
}
}
: undefined
: undefined
}
}
} catch (error) {
console.error(`Error retrieving ${this.tableName}`, error)
throw new Error(`Error retrieving ${this.tableName}`)
} finally {
await dataSource.destroy()
}
}
return undefined
}

async *list(config: RunnableConfig, limit?: number, before?: RunnableConfig): AsyncGenerator<CheckpointTuple> {
await this.setup()
const dataSource = await this.getDataSource()
await this.setup(dataSource)

const queryRunner = dataSource.createQueryRunner()
const thread_id = config.configurable?.thread_id || this.threadId
let sql = `SELECT thread_id, checkpoint_id, parent_id, checkpoint, metadata FROM ${this.tableName} WHERE thread_id = ? ${
before ? 'AND checkpoint_id < ?' : ''
Expand All @@ -125,7 +141,8 @@ CREATE TABLE IF NOT EXISTS ${this.tableName} (
const args = [thread_id, before?.configurable?.checkpoint_id].filter(Boolean)

try {
const rows = await this.queryRunner.manager.query(sql, [...args])
const rows = await queryRunner.manager.query(sql, [...args])
await queryRunner.release()

if (rows && rows.length > 0) {
for (const row of rows) {
Expand All @@ -152,13 +169,18 @@ CREATE TABLE IF NOT EXISTS ${this.tableName} (
} catch (error) {
console.error(`Error listing ${this.tableName}`, error)
throw new Error(`Error listing ${this.tableName}`)
} finally {
await dataSource.destroy()
}
}

async put(config: RunnableConfig, checkpoint: Checkpoint, metadata: CheckpointMetadata): Promise<RunnableConfig> {
await this.setup()
const dataSource = await this.getDataSource()
await this.setup(dataSource)

if (!config.configurable?.checkpoint_id) return {}
try {
const queryRunner = dataSource.createQueryRunner()
const row = [
config.configurable?.thread_id || this.threadId,
checkpoint.id,
Expand All @@ -169,10 +191,13 @@ CREATE TABLE IF NOT EXISTS ${this.tableName} (

const query = `INSERT OR REPLACE INTO ${this.tableName} (thread_id, checkpoint_id, parent_id, checkpoint, metadata) VALUES (?, ?, ?, ?, ?)`

await this.queryRunner.manager.query(query, row)
await queryRunner.manager.query(query, row)
await queryRunner.release()
} catch (error) {
console.error('Error saving checkpoint', error)
throw new Error('Error saving checkpoint')
} finally {
await dataSource.destroy()
}

return {
Expand All @@ -187,13 +212,20 @@ CREATE TABLE IF NOT EXISTS ${this.tableName} (
if (!threadId) {
return
}
await this.setup()

const dataSource = await this.getDataSource()
await this.setup(dataSource)

const query = `DELETE FROM "${this.tableName}" WHERE thread_id = ?;`

try {
await this.queryRunner.manager.query(query, [threadId])
const queryRunner = dataSource.createQueryRunner()
await queryRunner.manager.query(query, [threadId])
await queryRunner.release()
} catch (error) {
console.error(`Error deleting thread_id ${threadId}`, error)
} finally {
await dataSource.destroy()
}
}

Expand Down

0 comments on commit 7d1234a

Please sign in to comment.