From 69702b6559ff1ac80de98ee28c16df789eb98c98 Mon Sep 17 00:00:00 2001 From: Tomas Dvorak Date: Tue, 8 Oct 2024 11:25:12 +0200 Subject: [PATCH] feat(memory): rename SlidingWindowMemory, add removalSelector functionality Signed-off-by: Tomas Dvorak --- src/internals/types.ts | 4 ++ src/memory/base.ts | 5 +- src/memory/slidingMemory.test.ts | 86 ++++++++++++++++++++++++++++ src/memory/slidingMemory.ts | 94 +++++++++++++++++++++++++++++++ src/memory/slidingWindowMemory.ts | 55 ------------------ 5 files changed, 187 insertions(+), 57 deletions(-) create mode 100644 src/memory/slidingMemory.test.ts create mode 100644 src/memory/slidingMemory.ts delete mode 100644 src/memory/slidingWindowMemory.ts diff --git a/src/internals/types.ts b/src/internals/types.ts index 6370795..58d8eaf 100644 --- a/src/internals/types.ts +++ b/src/internals/types.ts @@ -64,6 +64,10 @@ export type RequiredAll = { [P in keyof T]-?: NonNullable; }; +export type RequiredNested = { + [P in keyof T]-?: Required; +}; + export type OmitEmpty = OmitType; export type NonEmptyArray = [T, ...T[]]; export type Unwrap = T extends (infer X)[] ? X : T; diff --git a/src/memory/base.ts b/src/memory/base.ts index 66b8df8..83dcbb1 100644 --- a/src/memory/base.ts +++ b/src/memory/base.ts @@ -15,15 +15,16 @@ */ import { BaseMessage } from "@/llms/primitives/message.js"; -import { FrameworkError } from "@/errors.js"; +import { FrameworkError, FrameworkErrorOptions } from "@/errors.js"; import { Serializable } from "@/internals/serializable.js"; export class MemoryError extends FrameworkError {} export class MemoryFatalError extends MemoryError { - constructor(message: string, errors?: Error[]) { + constructor(message: string, errors?: Error[], options?: FrameworkErrorOptions) { super(message, errors, { isFatal: true, isRetryable: false, + ...options, }); } } diff --git a/src/memory/slidingMemory.test.ts b/src/memory/slidingMemory.test.ts new file mode 100644 index 0000000..9c61ea6 --- /dev/null +++ b/src/memory/slidingMemory.test.ts @@ -0,0 +1,86 @@ +/** + * Copyright 2024 IBM Corp. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { BaseMessage, Role } from "@/llms/primitives/message.js"; +import { verifyDeserialization } from "@tests/e2e/utils.js"; +import { SlidingMemory } from "@/memory/slidingMemory.js"; + +describe("Sliding Memory", () => { + it("Removes old messages", async () => { + const instance = new SlidingMemory({ + size: 2, + }); + await instance.addMany(["A", "B", "C"].map((text) => BaseMessage.of({ role: "user", text }))); + expect(Array.from(instance).map((msg) => msg.text)).toStrictEqual(["B", "C"]); + }); + + it("Removes messages by removalSelector", async () => { + const instance = new SlidingMemory({ + size: 2, + handlers: { + removalSelector: (messages) => messages.find((msg) => msg.role !== "system")!, + }, + }); + await instance.add(BaseMessage.of({ role: "system", text: "You are a helpful assistant." })); + await instance.addMany(["A", "B", "C"].map((text) => BaseMessage.of({ role: "user", text }))); + expect(Array.from(instance).map((msg) => msg.text)).toStrictEqual([ + "You are a helpful assistant.", + "C", + ]); + }); + + it("Removes multiple messages by removalSelector", async () => { + const instance = new SlidingMemory({ + size: 5, + handlers: { + removalSelector: (messages) => { + const index = messages.findIndex((msg, index, arr) => { + const next = arr[index + 1]; + return ( + (msg.role === "user" && next?.role === "assistant") || + (msg.role === "assistant" && next?.role === "user") + ); + }); + if (index === -1) { + return messages.find((msg) => msg.role === "user")!; + } + return [messages[index], messages[index + 1]]; + }, + }, + }); + await instance.addMany( + ["user", "assistant", "user", "assistant", "user", "user"].map((role, i) => + BaseMessage.of({ role, text: `${i + 1}` }), + ), + ); + expect(Array.from(instance).map((msg) => msg.text)).toStrictEqual(["3", "4", "5", "6"]); + }); + + it("Serializes", async () => { + const instance = new SlidingMemory({ + size: 5, + }); + await instance.add( + BaseMessage.of({ + text: "Hello!", + role: Role.USER, + }), + ); + const serialized = instance.serialize(); + const deserialized = SlidingMemory.fromSerialized(serialized); + verifyDeserialization(instance, deserialized); + }); +}); diff --git a/src/memory/slidingMemory.ts b/src/memory/slidingMemory.ts new file mode 100644 index 0000000..d8b34cd --- /dev/null +++ b/src/memory/slidingMemory.ts @@ -0,0 +1,94 @@ +/** + * Copyright 2024 IBM Corp. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { BaseMessage } from "@/llms/primitives/message.js"; +import { BaseMemory, MemoryFatalError } from "@/memory/base.js"; +import { shallowCopy } from "@/serializer/utils.js"; +import { filter, forEach, isTruthy, pipe } from "remeda"; +import { castArray } from "@/internals/helpers/array.js"; +import { RequiredNested } from "@/internals/types.js"; + +export interface Handlers { + removalSelector: (messages: BaseMessage[]) => BaseMessage | BaseMessage[]; +} + +export interface SlidingWindowMemoryInput { + size: number; + handlers?: Partial; +} + +export class SlidingMemory extends BaseMemory { + public readonly messages: BaseMessage[] = []; + public readonly config: RequiredNested; + + constructor(config: SlidingWindowMemoryInput) { + super(); + this.config = { + ...config, + handlers: { + removalSelector: + config.handlers?.removalSelector ?? ((messages: BaseMessage[]) => [messages[0]]), + }, + }; + } + + static { + const aliases = ["SlidingWindowMemory"]; + this.register(aliases); + } + + async add(message: BaseMessage) { + const { size, handlers } = this.config; + const isOverflow = () => this.messages.length + 1 > size; + + if (isOverflow()) { + pipe( + this.messages, + handlers.removalSelector, + castArray, + filter(isTruthy), + forEach((message) => { + const index = this.messages.indexOf(message); + if (index === -1) { + throw new MemoryFatalError(`Cannot delete non existing message.`, [], { + context: { message, messages: this.messages }, + }); + } + this.messages.splice(index, 1); + }), + ); + + if (isOverflow()) { + throw new MemoryFatalError( + `Custom memory removalSelector did not return any message. Memory overflow has occurred.`, + ); + } + } + this.messages.push(message); + } + + reset() { + this.messages.length = 0; + } + + createSnapshot() { + return { config: shallowCopy(this.config), messages: shallowCopy(this.messages) }; + } + + loadSnapshot(state: ReturnType) { + Object.assign(this, state); + } +} diff --git a/src/memory/slidingWindowMemory.ts b/src/memory/slidingWindowMemory.ts deleted file mode 100644 index 44c0741..0000000 --- a/src/memory/slidingWindowMemory.ts +++ /dev/null @@ -1,55 +0,0 @@ -/** - * Copyright 2024 IBM Corp. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -import { BaseMessage } from "@/llms/primitives/message.js"; -import { BaseMemory } from "@/memory/base.js"; -import { shallowCopy } from "@/serializer/utils.js"; - -export interface SlidingWindowMemoryInput { - size: number; -} - -export class SlidingWindowMemory extends BaseMemory { - public readonly messages: BaseMessage[]; - - constructor(public config: SlidingWindowMemoryInput) { - super(); - this.messages = []; - } - - static { - this.register(); - } - - async add(message: BaseMessage) { - if (this.messages.length + 1 > this.config.size) { - this.messages.shift(); - } - this.messages.push(message); - } - - reset() { - this.messages.length = 0; - } - - createSnapshot() { - return { config: shallowCopy(this.config), messages: shallowCopy(this.messages) }; - } - - loadSnapshot(state: ReturnType) { - Object.assign(this, state); - } -}