core[patch]: Add support for messages in/messages out for RunnableWithMessageHistory #5517

Merged 2 commits on May 22, 2024
1 change: 1 addition & 0 deletions langchain-core/.eslintrc.cjs
@@ -69,6 +69,7 @@ module.exports = {
"new-cap": ["error", { properties: false, capIsNew: false }],
'jest/no-focused-tests': 'error',
"arrow-body-style": 0,
"prefer-destructuring": 0,
},
overrides: [
{
14 changes: 14 additions & 0 deletions langchain-core/src/chat_history.ts
@@ -16,6 +16,20 @@ export abstract class BaseChatMessageHistory extends Serializable {

public abstract addAIChatMessage(message: string): Promise<void>;

/**
* Add a list of messages.
*
* Implementations should override this method to handle bulk addition of messages
* in an efficient manner to avoid unnecessary round-trips to the underlying store.
*
* @param messages - A list of BaseMessage objects to store.
*/
public async addMessages(messages: BaseMessage[]): Promise<void> {
for (const message of messages) {
await this.addMessage(message);
}
}

public abstract clear(): Promise<void>;
}

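The base-class default above simply loops over addMessage, so existing histories keep working; a concrete store can override addMessages to write the whole batch at once. A minimal sketch of such an override, assuming the published @langchain/core entrypoints (the ArrayChatMessageHistory class is illustrative and not part of this PR):

import { BaseChatMessageHistory } from "@langchain/core/chat_history";
import { AIMessage, BaseMessage, HumanMessage } from "@langchain/core/messages";

// Illustrative in-memory history that overrides addMessages to append a whole
// turn in a single operation instead of one call per message.
class ArrayChatMessageHistory extends BaseChatMessageHistory {
  lc_namespace = ["langchain_core", "message", "in_memory"];

  private messages: BaseMessage[] = [];

  async getMessages(): Promise<BaseMessage[]> {
    return this.messages;
  }

  async addMessage(message: BaseMessage): Promise<void> {
    this.messages.push(message);
  }

  async addUserMessage(text: string): Promise<void> {
    return this.addMessage(new HumanMessage(text));
  }

  async addAIChatMessage(text: string): Promise<void> {
    return this.addMessage(new AIMessage(text));
  }

  // Bulk override: one append for the whole batch.
  async addMessages(messages: BaseMessage[]): Promise<void> {
    this.messages.push(...messages);
  }

  async clear(): Promise<void> {
    this.messages = [];
  }
}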
129 changes: 88 additions & 41 deletions langchain-core/src/runnables/history.ts
@@ -149,44 +149,92 @@ export class RunnableWithMessageHistory<
}

_getInputMessages(
inputValue: string | BaseMessage | Array<BaseMessage>
// eslint-disable-next-line @typescript-eslint/no-explicit-any
inputValue: string | BaseMessage | Array<BaseMessage> | Record<string, any>
): Array<BaseMessage> {
if (typeof inputValue === "string") {
return [new HumanMessage(inputValue)];
} else if (Array.isArray(inputValue)) {
return inputValue;
let parsedInputValue;
if (
typeof inputValue === "object" &&
!Array.isArray(inputValue) &&
!isBaseMessage(inputValue)
) {
let key;
if (this.inputMessagesKey) {
key = this.inputMessagesKey;
} else if (Object.keys(inputValue).length === 1) {
key = Object.keys(inputValue)[0];
} else {
key = "input";
}
if (Array.isArray(inputValue[key]) && Array.isArray(inputValue[key][0])) {
parsedInputValue = inputValue[key][0];
} else {
parsedInputValue = inputValue[key];
}
} else {
return [inputValue];
parsedInputValue = inputValue;
}
if (typeof parsedInputValue === "string") {
return [new HumanMessage(parsedInputValue)];
} else if (Array.isArray(parsedInputValue)) {
return parsedInputValue;
} else if (isBaseMessage(parsedInputValue)) {
return [parsedInputValue];
} else {
throw new Error(
`Expected a string, BaseMessage, or array of BaseMessages.\nGot ${JSON.stringify(
parsedInputValue,
null,
2
)}`
);
}
}

_getOutputMessages(
// eslint-disable-next-line @typescript-eslint/no-explicit-any
outputValue: string | BaseMessage | Array<BaseMessage> | Record<string, any>
): Array<BaseMessage> {
let newOutputValue = outputValue;
let parsedOutputValue;
if (
!Array.isArray(outputValue) &&
!isBaseMessage(outputValue) &&
typeof outputValue !== "string"
) {
newOutputValue = outputValue[this.outputMessagesKey ?? "output"];
let key;
if (this.outputMessagesKey !== undefined) {
key = this.outputMessagesKey;
} else if (Object.keys(outputValue).length === 1) {
key = Object.keys(outputValue)[0];
} else {
key = "output";
}
// If you are wrapping a chat model directly, the run output is a
// generations object rather than a plain message.
if (outputValue.generations !== undefined) {
parsedOutputValue = outputValue.generations[0][0].message;
} else {
parsedOutputValue = outputValue[key];
}
} else {
parsedOutputValue = outputValue;
}

if (typeof newOutputValue === "string") {
return [new AIMessage(newOutputValue)];
} else if (Array.isArray(newOutputValue)) {
return newOutputValue;
} else if (isBaseMessage(newOutputValue)) {
return [newOutputValue];
if (typeof parsedOutputValue === "string") {
return [new AIMessage(parsedOutputValue)];
} else if (Array.isArray(parsedOutputValue)) {
return parsedOutputValue;
} else if (isBaseMessage(parsedOutputValue)) {
return [parsedOutputValue];
} else {
throw new Error(
`Expected a string, BaseMessage, or array of BaseMessages. Received: ${JSON.stringify(
parsedOutputValue,
null,
2
)}`
);
}
throw new Error(
`Expected a string, BaseMessage, or array of BaseMessages. Received: ${JSON.stringify(
newOutputValue,
null,
2
)}`
);
}

async _enterHistory(
@@ -195,29 +243,31 @@ export class RunnableWithMessageHistory<
kwargs?: { config?: RunnableConfig }
): Promise<BaseMessage[]> {
const history = kwargs?.config?.configurable?.messageHistory;

if (this.historyMessagesKey) {
return history.getMessages();
const messages = await history.getMessages();
if (this.historyMessagesKey === undefined) {
return messages.concat(this._getInputMessages(input));
}

const inputVal =
input ||
(this.inputMessagesKey ? input[this.inputMessagesKey] : undefined);
const historyMessages = history ? await history.getMessages() : [];
const returnType = [
...historyMessages,
...this._getInputMessages(inputVal),
];
return returnType;
return messages;
}

async _exitHistory(run: Run, config: RunnableConfig): Promise<void> {
const history = config.configurable?.messageHistory;

// Get input messages
const { inputs } = run;
const inputValue = inputs[this.inputMessagesKey ?? "input"];
const inputMessages = this._getInputMessages(inputValue);
let inputs;
// Chat model inputs are nested arrays
if (Array.isArray(run.inputs) && Array.isArray(run.inputs[0])) {
inputs = run.inputs[0];
} else {
inputs = run.inputs;
}
let inputMessages = this._getInputMessages(inputs);
// If historic messages were prepended to the input messages, remove them to
// avoid adding duplicate messages to history.
if (this.historyMessagesKey === undefined) {
const existingMessages = await history.getMessages();
inputMessages = inputMessages.slice(existingMessages.length);
}
// Get output messages
const outputValue = run.outputs;
if (!outputValue) {
@@ -230,10 +280,7 @@
);
}
const outputMessages = this._getOutputMessages(outputValue);

for await (const message of [...inputMessages, ...outputMessages]) {
await history.addMessage(message);
}
await history.addMessages([...inputMessages, ...outputMessages]);
}

async _mergeConfig(...configs: Array<RunnableConfig | undefined>) {
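At the call site, the history.ts changes mean a chat model can be wrapped directly, with bare message arrays in and chat messages out and no inputMessagesKey or historyMessagesKey configured; _exitHistory also drops any history messages that were prepended to the input so they are not written back twice. A hedged usage sketch, reusing the same testing utilities the tests below import (the session-map helper and the relative import paths are illustrative):

import { HumanMessage } from "../../messages/index.js";
import { RunnableWithMessageHistory } from "../history.js";
import {
  FakeChatMessageHistory,
  FakeListChatModel,
} from "../../utils/testing/index.js";

// Illustrative session store: one history object per session id.
const sessions = new Map<string, FakeChatMessageHistory>();
const getMessageHistory = (sessionId: string) => {
  if (!sessions.has(sessionId)) {
    sessions.set(sessionId, new FakeChatMessageHistory());
  }
  return sessions.get(sessionId)!;
};

// Wrap a chat model directly: messages in, messages out.
const withHistory = new RunnableWithMessageHistory({
  runnable: new FakeListChatModel({ responses: ["Hello world!"] }),
  config: {},
  getMessageHistory,
});

const output = await withHistory.invoke([new HumanMessage("hello")], {
  configurable: { sessionId: "example" },
});
// output.content === "Hello world!"; the human turn and the AI reply are both
// persisted to the session history via addMessages.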
87 changes: 83 additions & 4 deletions langchain-core/src/runnables/tests/runnable_history.test.ts
@@ -1,4 +1,9 @@
import { BaseMessage, HumanMessage } from "../../messages/index.js";
import {
AIMessage,
AIMessageChunk,
BaseMessage,
HumanMessage,
} from "../../messages/index.js";
import { RunnableLambda } from "../base.js";
import { RunnableConfig } from "../config.js";
import { RunnableWithMessageHistory } from "../history.js";
@@ -10,6 +15,7 @@ import {
FakeChatMessageHistory,
FakeLLM,
FakeListChatMessageHistory,
FakeListChatModel,
FakeStreamingLLM,
} from "../../utils/testing/index.js";
import { ChatPromptTemplate, MessagesPlaceholder } from "../../prompts/chat.js";
@@ -73,6 +79,79 @@ test("Runnable with message history", async () => {
expect(output).toBe("you said: hello\ngood bye");
});

test("Runnable with message history with a chat model", async () => {
const runnable = new FakeListChatModel({
responses: ["Hello world!"],
});

const getMessageHistory = await getGetSessionHistory();
const withHistory = new RunnableWithMessageHistory({
runnable,
config: {},
getMessageHistory,
});
const config: RunnableConfig = { configurable: { sessionId: "2" } };
const output = await withHistory.invoke([new HumanMessage("hello")], config);
expect(output.content).toBe("Hello world!");
const stream = await withHistory.stream(
[new HumanMessage("good bye")],
config
);
const chunks = [];
for await (const chunk of stream) {
console.log(chunk);
chunks.push(chunk);
}
expect(chunks.map((chunk) => chunk.content).join("")).toEqual("Hello world!");
const sessionHistory = await getMessageHistory("2");
expect(await sessionHistory.getMessages()).toEqual([
new HumanMessage("hello"),
new AIMessage("Hello world!"),
new HumanMessage("good bye"),
new AIMessageChunk("Hello world!"),
]);
});

test("Runnable with message history with a messages in, messages out chain", async () => {
const prompt = ChatPromptTemplate.fromMessages([
["system", "you are a robot"],
["placeholder", "{messages}"],
]);
const model = new FakeListChatModel({
responses: ["So long and thanks for the fish!!"],
});
const runnable = prompt.pipe(model);

const getMessageHistory = await getGetSessionHistory();
const withHistory = new RunnableWithMessageHistory({
runnable,
config: {},
getMessageHistory,
});
const config: RunnableConfig = { configurable: { sessionId: "2" } };
const output = await withHistory.invoke([new HumanMessage("hello")], config);
expect(output.content).toBe("So long and thanks for the fish!!");
const stream = await withHistory.stream(
[new HumanMessage("good bye")],
config
);
const chunks = [];
for await (const chunk of stream) {
console.log(chunk);
chunks.push(chunk);
}
expect(chunks.map((chunk) => chunk.content).join("")).toEqual(
"So long and thanks for the fish!!"
);
const sessionHistory = await getMessageHistory("2");
expect(await sessionHistory.getMessages()).toEqual([
new HumanMessage("hello"),
new AIMessage("So long and thanks for the fish!!"),
new HumanMessage("good bye"),
new AIMessageChunk("So long and thanks for the fish!!"),
]);
});

test("Runnable with message history work with chat list memory", async () => {
const runnable = new RunnableLambda({
func: (messages: BaseMessage[]) =>
@@ -88,7 +167,7 @@ test("Runnable with message history work with chat list memory", async () => {
config: {},
getMessageHistory: getListMessageHistory,
});
const config: RunnableConfig = { configurable: { sessionId: "1" } };
const config: RunnableConfig = { configurable: { sessionId: "3" } };
let output = await withHistory.invoke([new HumanMessage("hello")], config);
expect(output).toBe("you said: hello");
output = await withHistory.invoke([new HumanMessage("good bye")], config);
@@ -112,7 +191,7 @@ test("Runnable with message history and RunnableSequence", async () => {
inputMessagesKey: "input",
historyMessagesKey: "history",
});
const config: RunnableConfig = { configurable: { sessionId: "1" } };
const config: RunnableConfig = { configurable: { sessionId: "4" } };
let output = await withHistory.invoke({ input: "hello" }, config);
expect(output).toBe("AI: You are a helpful assistant\nHuman: hello");
output = await withHistory.invoke({ input: "good bye" }, config);
@@ -140,7 +219,7 @@ test("Runnable with message history should stream through", async () => {
inputMessagesKey: "input",
historyMessagesKey: "history",
}).pipe(new StringOutputParser());
const config: RunnableConfig = { configurable: { sessionId: "1" } };
const config: RunnableConfig = { configurable: { sessionId: "5" } };
const stream = await withHistory.stream({ input: "hello" }, config);
const chunks = [];
for await (const chunk of stream) {