Skip to content

Commit

Permalink
feat(memory): rename SlidingWindowMemory, add removalSelector functio…
Browse files Browse the repository at this point in the history
…nality

Signed-off-by: Tomas Dvorak <toomas2d@gmail.com>
  • Loading branch information
Tomas2D committed Oct 8, 2024
1 parent 23bba65 commit 69702b6
Show file tree
Hide file tree
Showing 5 changed files with 187 additions and 57 deletions.
4 changes: 4 additions & 0 deletions src/internals/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ export type RequiredAll<T> = {
[P in keyof T]-?: NonNullable<T[P]>;
};

export type RequiredNested<T> = {
[P in keyof T]-?: Required<T[P]>;
};

export type OmitEmpty<T> = OmitType<T, never | void>;
export type NonEmptyArray<T> = [T, ...T[]];
export type Unwrap<T> = T extends (infer X)[] ? X : T;
Expand Down
5 changes: 3 additions & 2 deletions src/memory/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,16 @@
*/

import { BaseMessage } from "@/llms/primitives/message.js";
import { FrameworkError } from "@/errors.js";
import { FrameworkError, FrameworkErrorOptions } from "@/errors.js";
import { Serializable } from "@/internals/serializable.js";

export class MemoryError extends FrameworkError {}
export class MemoryFatalError extends MemoryError {
constructor(message: string, errors?: Error[]) {
constructor(message: string, errors?: Error[], options?: FrameworkErrorOptions) {
super(message, errors, {
isFatal: true,
isRetryable: false,
...options,
});
}
}
Expand Down
86 changes: 86 additions & 0 deletions src/memory/slidingMemory.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
/**
* Copyright 2024 IBM Corp.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

import { BaseMessage, Role } from "@/llms/primitives/message.js";
import { verifyDeserialization } from "@tests/e2e/utils.js";
import { SlidingMemory } from "@/memory/slidingMemory.js";

describe("Sliding Memory", () => {
it("Removes old messages", async () => {
const instance = new SlidingMemory({
size: 2,
});
await instance.addMany(["A", "B", "C"].map((text) => BaseMessage.of({ role: "user", text })));
expect(Array.from(instance).map((msg) => msg.text)).toStrictEqual(["B", "C"]);
});

it("Removes messages by removalSelector", async () => {
const instance = new SlidingMemory({
size: 2,
handlers: {
removalSelector: (messages) => messages.find((msg) => msg.role !== "system")!,
},
});
await instance.add(BaseMessage.of({ role: "system", text: "You are a helpful assistant." }));
await instance.addMany(["A", "B", "C"].map((text) => BaseMessage.of({ role: "user", text })));
expect(Array.from(instance).map((msg) => msg.text)).toStrictEqual([
"You are a helpful assistant.",
"C",
]);
});

it("Removes multiple messages by removalSelector", async () => {
const instance = new SlidingMemory({
size: 5,
handlers: {
removalSelector: (messages) => {
const index = messages.findIndex((msg, index, arr) => {
const next = arr[index + 1];
return (
(msg.role === "user" && next?.role === "assistant") ||
(msg.role === "assistant" && next?.role === "user")
);
});
if (index === -1) {
return messages.find((msg) => msg.role === "user")!;
}
return [messages[index], messages[index + 1]];
},
},
});
await instance.addMany(
["user", "assistant", "user", "assistant", "user", "user"].map((role, i) =>
BaseMessage.of({ role, text: `${i + 1}` }),
),
);
expect(Array.from(instance).map((msg) => msg.text)).toStrictEqual(["3", "4", "5", "6"]);
});

it("Serializes", async () => {
const instance = new SlidingMemory({
size: 5,
});
await instance.add(
BaseMessage.of({
text: "Hello!",
role: Role.USER,
}),
);
const serialized = instance.serialize();
const deserialized = SlidingMemory.fromSerialized(serialized);
verifyDeserialization(instance, deserialized);
});
});
94 changes: 94 additions & 0 deletions src/memory/slidingMemory.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
/**
* Copyright 2024 IBM Corp.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

import { BaseMessage } from "@/llms/primitives/message.js";
import { BaseMemory, MemoryFatalError } from "@/memory/base.js";
import { shallowCopy } from "@/serializer/utils.js";
import { filter, forEach, isTruthy, pipe } from "remeda";
import { castArray } from "@/internals/helpers/array.js";
import { RequiredNested } from "@/internals/types.js";

export interface Handlers {
removalSelector: (messages: BaseMessage[]) => BaseMessage | BaseMessage[];
}

export interface SlidingWindowMemoryInput {
size: number;
handlers?: Partial<Handlers>;
}

export class SlidingMemory extends BaseMemory {
public readonly messages: BaseMessage[] = [];
public readonly config: RequiredNested<SlidingWindowMemoryInput>;

constructor(config: SlidingWindowMemoryInput) {
super();
this.config = {
...config,
handlers: {
removalSelector:
config.handlers?.removalSelector ?? ((messages: BaseMessage[]) => [messages[0]]),
},
};
}

static {
const aliases = ["SlidingWindowMemory"];
this.register(aliases);
}

async add(message: BaseMessage) {
const { size, handlers } = this.config;
const isOverflow = () => this.messages.length + 1 > size;

if (isOverflow()) {
pipe(
this.messages,
handlers.removalSelector,
castArray,
filter(isTruthy),
forEach((message) => {
const index = this.messages.indexOf(message);
if (index === -1) {
throw new MemoryFatalError(`Cannot delete non existing message.`, [], {
context: { message, messages: this.messages },
});
}
this.messages.splice(index, 1);
}),
);

if (isOverflow()) {
throw new MemoryFatalError(
`Custom memory removalSelector did not return any message. Memory overflow has occurred.`,
);
}
}
this.messages.push(message);
}

reset() {
this.messages.length = 0;
}

createSnapshot() {
return { config: shallowCopy(this.config), messages: shallowCopy(this.messages) };
}

loadSnapshot(state: ReturnType<typeof this.createSnapshot>) {
Object.assign(this, state);
}
}
55 changes: 0 additions & 55 deletions src/memory/slidingWindowMemory.ts

This file was deleted.

0 comments on commit 69702b6

Please sign in to comment.