-
-
Notifications
You must be signed in to change notification settings - Fork 82
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(memory): Add support for ConversationTokenBufferMemory (#26)
Co-authored-by: David Miguel <me@davidmiguel.com>
- Loading branch information
1 parent
0be06e0
commit 8113d1c
Showing
5 changed files
with
231 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
import '../model_io/chat_models/models/models.dart'; | ||
import '../model_io/chat_models/utils.dart'; | ||
import '../model_io/language_models/language_models.dart'; | ||
import '../model_io/prompts/prompts.dart'; | ||
import 'buffer_window.dart'; | ||
import 'chat.dart'; | ||
import 'models/models.dart'; | ||
import 'stores/message/in_memory.dart'; | ||
|
||
/// {@template conversation_token_buffer_memory}
/// Rolling buffer for storing a conversation and then retrieving the messages
/// at a later time.
///
/// It uses token length (rather than number of interactions like
/// [ConversationBufferWindowMemory]) to determine when to flush old
/// interactions from the buffer. This allows it to keep more context while
/// staying under a max token limit.
///
/// It uses [ChatMessageHistory] as in-memory storage by default.
///
/// Example:
/// ```dart
/// final memory = ConversationTokenBufferMemory(llm: OpenAI(apiKey: '...'));
/// await memory.saveContext(
///   inputValues: {'foo': 'bar'},
///   outputValues: {'bar': 'foo'},
/// );
/// final res = await memory.loadMemoryVariables();
/// // {'history': 'Human: bar\nAI: foo'}
/// ```
/// {@endtemplate}
final class ConversationTokenBufferMemory<
    LLMInput extends Object,
    LLMOptions extends LanguageModelOptions,
    LLMOutput extends Object> extends BaseChatMemory {
  /// {@macro conversation_token_buffer_memory}
  ConversationTokenBufferMemory({
    super.chatHistory,
    super.inputKey,
    super.outputKey,
    super.returnMessages = false,
    required this.llm,
    this.humanPrefix = 'Human',
    this.aiPrefix = 'AI',
    this.memoryKey = 'history',
    this.maxTokenLimit = 2000,
  });

  /// Language model to use for counting tokens.
  final BaseLanguageModel<LLMInput, LLMOptions, LLMOutput> llm;

  /// The prefix to use for human messages.
  final String humanPrefix;

  /// The prefix to use for AI messages.
  final String aiPrefix;

  /// The memory key to use for the chat history.
  final String memoryKey;

  /// Maximum number of tokens (as counted by [llm]) that the stored history
  /// may occupy. Older messages are pruned in [saveContext] once this limit
  /// is exceeded.
  final int maxTokenLimit;

  @override
  Set<String> get memoryKeys => {memoryKey};

  /// Loads the stored conversation under [memoryKey].
  ///
  /// When `returnMessages` is true the value is the list of chat messages;
  /// otherwise it is a single string built with [humanPrefix] and [aiPrefix].
  /// The [values] argument is not used by this implementation.
  @override
  Future<MemoryVariables> loadMemoryVariables([
    final MemoryInputValues values = const {},
  ]) async {
    final messages = await chatHistory.getChatMessages();
    if (returnMessages) {
      return {memoryKey: messages};
    }
    return {
      memoryKey: messages.toBufferString(
        humanPrefix: humanPrefix,
        aiPrefix: aiPrefix,
      ),
    };
  }

  /// Saves the interaction and then prunes the oldest messages until the
  /// history fits within [maxTokenLimit] tokens.
  ///
  /// Pruning removes one message at a time from the front of the history and
  /// re-counts the remaining tokens with [llm] after each removal.
  @override
  Future<void> saveContext({
    required final MemoryInputValues inputValues,
    required final MemoryOutputValues outputValues,
  }) async {
    await super.saveContext(
      inputValues: inputValues,
      outputValues: outputValues,
    );
    List<ChatMessage> buffer = await chatHistory.getChatMessages();
    int currentBufferLength = await llm.countTokens(PromptValue.chat(buffer));
    // Prune buffer while it exceeds the max token limit. The `isNotEmpty`
    // guard stops the loop once there is nothing left to remove, so a limit
    // that even an empty history cannot satisfy does not keep calling
    // `removeFirst` on an empty history.
    while (currentBufferLength > maxTokenLimit && buffer.isNotEmpty) {
      await chatHistory.removeFirst();
      buffer = await chatHistory.getChatMessages();
      currentBufferLength = await llm.countTokens(PromptValue.chat(buffer));
    }
  }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
import 'package:langchain/src/memory/memory.dart'; | ||
import 'package:langchain/src/memory/token_buffer.dart'; | ||
import 'package:langchain/src/model_io/chat_models/chat_models.dart'; | ||
import 'package:langchain/src/model_io/llms/fake.dart'; | ||
import 'package:test/test.dart'; | ||
|
||
/// Test suite for [ConversationTokenBufferMemory].
void main() {
  group('ConversationTokenBufferMemory tests', () {
    test('Test buffer memory', () async {
      const model = FakeEchoLLM();
      final memory = ConversationTokenBufferMemory(llm: model);
      final result1 = await memory.loadMemoryVariables();
      expect(result1, {'history': ''});

      await memory.saveContext(
        inputValues: {'foo': 'bar'},
        outputValues: {'bar': 'foo'},
      );
      const expectedString = 'Human: bar\nAI: foo';
      final result2 = await memory.loadMemoryVariables();
      expect(result2, {'history': expectedString});
    });

    test('Test buffer memory return messages', () async {
      const model = FakeEchoLLM();
      final memory = ConversationTokenBufferMemory(
        llm: model,
        returnMessages: true,
        maxTokenLimit: 4,
      );
      final result1 = await memory.loadMemoryVariables();
      expect(result1, {'history': <ChatMessage>[]});

      await memory.saveContext(
        inputValues: {'foo': 'bar'},
        outputValues: {'bar': 'foo'},
      );
      final expectedResult = [
        ChatMessage.human('bar'),
        ChatMessage.ai('foo'),
      ];
      final result2 = await memory.loadMemoryVariables();
      expect(result2, {'history': expectedResult});

      await memory.saveContext(
        inputValues: {'foo': 'bar1'},
        outputValues: {'bar': 'foo1'},
      );

      // The second interaction pushes the buffer over the 4-token limit, so
      // the oldest message (the first human message) is expected to be pruned.
      final expectedResult2 = [
        ChatMessage.ai('foo'),
        ChatMessage.human('bar1'),
        ChatMessage.ai('foo1'),
      ];
      final result3 = await memory.loadMemoryVariables();
      expect(result3, {'history': expectedResult2});
    });

    test('Test buffer memory with pre-loaded history', () async {
      final pastMessages = [
        ChatMessage.human("My name's Jonas"),
        ChatMessage.ai('Nice to meet you, Jonas!'),
      ];
      const model = FakeEchoLLM();
      final memory = ConversationTokenBufferMemory(
        llm: model,
        maxTokenLimit: 3,
        returnMessages: true,
        chatHistory: ChatMessageHistory(messages: pastMessages),
      );
      // Pruning only happens in saveContext, so loading returns the
      // pre-loaded history untouched despite the small token limit.
      final result = await memory.loadMemoryVariables();
      expect(result, {'history': pastMessages});
    });

    test('Test clear memory', () async {
      // Exercise the class under test rather than ConversationBufferMemory.
      const model = FakeEchoLLM();
      final memory = ConversationTokenBufferMemory(llm: model);
      await memory.saveContext(
        inputValues: {'foo': 'bar'},
        outputValues: {'bar': 'foo'},
      );
      const expectedString = 'Human: bar\nAI: foo';
      final result1 = await memory.loadMemoryVariables();
      expect(result1, {'history': expectedString});

      // Await the clear so the assertion below cannot race with it.
      await memory.clear();
      final result2 = await memory.loadMemoryVariables();
      expect(result2, {'history': ''});
    });
  });
}