Skip to content

Commit 46b1485

Browse files
authored
chat vector db chain (#22)
* cr * cr
1 parent 0de457d commit 46b1485

File tree

13 files changed

+298
-17
lines changed

13 files changed

+298
-17
lines changed
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# Chat Vector DB QA Chain
2+
3+
A Chat Vector DB QA chain takes as input a question and chat history.
4+
It first combines the chat history and the question into a standalone question, then looks up relevant documents from the vector database, and then passes those documents and the question to a question answering chain to return a response.
5+
6+
To create one, you will need a vectorstore, which can be created from embeddings.
7+
8+
Below is an end-to-end example of doing question answering over a recent state of the union address.
9+
10+
```typescript
11+
import { OpenAI } from "langchain/llms";
12+
import { ChatVectorDBQAChain } from "langchain/chains";
13+
import { HNSWLib } from "langchain/vectorstores";
14+
import { OpenAIEmbeddings } from "langchain/embeddings";
15+
import { RecursiveCharacterTextSplitter } from "langchain/text_splitter";
16+
import * as fs from 'fs';
17+
18+
19+
/* Initialize the LLM to use to answer the question */
20+
const model = new OpenAI({});
21+
/* Load in the file we want to do question answering over */
22+
const text = fs.readFileSync('state_of_the_union.txt','utf8');
23+
/* Split the text into chunks */
24+
const textSplitter = new RecursiveCharacterTextSplitter({chunkSize: 1000});
25+
const docs = textSplitter.createDocuments([text]);
26+
/* Create the vectorstore */
27+
const vectorStore = await HNSWLib.fromDocuments(
28+
docs,
29+
new OpenAIEmbeddings()
30+
);
31+
/* Create the chain */
32+
const chain = ChatVectorDBQAChain.fromLLM(model, vectorStore);
33+
/* Ask it a question */
34+
const question = "What did the president say about Justice Breyer?"
35+
const res = await chain.call({ question: question, chat_history: [] });
36+
console.log(res);
37+
/* Ask it a follow up question */
38+
const chatHistory = question + res["text"]
39+
const followUpRes = await chain.call({ question: "Was that nice?", chat_history: chatHistory });
40+
console.log(followUpRes);
41+
42+
```
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import { OpenAI } from "langchain/llms";
2+
import { ChatVectorDBQAChain } from "langchain/chains";
3+
import { HNSWLib } from "langchain/vectorstores";
4+
import { OpenAIEmbeddings } from "langchain/embeddings";
5+
import { RecursiveCharacterTextSplitter } from "langchain/text_splitter";
6+
import * as fs from 'fs';
7+
8+
export const run = async () => {
9+
/* Initialize the LLM to use to answer the question */
10+
const model = new OpenAI({});
11+
/* Load in the file we want to do question answering over */
12+
const text = fs.readFileSync('state_of_the_union.txt','utf8');
13+
/* Split the text into chunks */
14+
const textSplitter = new RecursiveCharacterTextSplitter({chunkSize: 1000});
15+
const docs = textSplitter.createDocuments([text]);
16+
/* Create the vectorstore */
17+
const vectorStore = await HNSWLib.fromDocuments(
18+
docs,
19+
new OpenAIEmbeddings()
20+
);
21+
/* Create the chain */
22+
const chain = ChatVectorDBQAChain.fromLLM(model, vectorStore);
23+
/* Ask it a question */
24+
const question = "What did the president say about Justice Breyer?";
25+
const res = await chain.call({ question, chat_history: [] });
26+
console.log(res);
27+
/* Ask it a follow up question */
28+
const chatHistory = question + res.text;
29+
const followUpRes = await chain.call({ question: "Was that nice?", chat_history: chatHistory });
30+
console.log(followUpRes);
31+
};

langchain/agents/index.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,6 @@ export { Agent, StaticAgent, staticImplements, AgentInput } from "./agent";
99
export { AgentExecutor } from "./executor";
1010
export { ZeroShotAgent, SerializedZeroShotAgent } from "./mrkl";
1111
export { Tool } from "./tools";
12-
export {initializeAgentExecutor} from "./initialize"
12+
export {initializeAgentExecutor} from "./initialize";
1313

1414
export { loadAgent } from "./load";

langchain/agents/initialize.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ import { Tool } from "./tools";
22
import { BaseLLM } from "../llms";
33
import { AgentExecutor } from "./executor";
44
import { ZeroShotAgent } from "./mrkl";
5+
56
export const initializeAgentExecutor = async (
67
tools: Tool[],
78
llm: BaseLLM,
@@ -15,7 +16,7 @@ export const initializeAgentExecutor = async (
1516
tools,
1617
returnIntermediateSteps: true,
1718
});
18-
return executor
19+
return executor;
1920
default:
2021
throw new Error("Unknown agent type");
2122
}

langchain/chains/base.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
1-
import { LLMChain, StuffDocumentsChain, VectorDBQAChain } from "./index";
1+
import { LLMChain, StuffDocumentsChain, VectorDBQAChain, ChatVectorDBQAChain } from "./index";
22
import { BaseMemory } from "../memory";
33

44
// eslint-disable-next-line @typescript-eslint/no-explicit-any
55
export type ChainValues = Record<string, any>;
66
// eslint-disable-next-line @typescript-eslint/no-explicit-any
77
export type LoadValues = Record<string, any>;
88

9-
const chainClasses = [LLMChain, StuffDocumentsChain, VectorDBQAChain];
9+
const chainClasses = [LLMChain, StuffDocumentsChain, VectorDBQAChain, ChatVectorDBQAChain];
1010

1111
export type SerializedBaseChain = ReturnType<
1212
InstanceType<(typeof chainClasses)[number]>["serialize"]
Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
import {
2+
BaseChain,
3+
ChainValues,
4+
SerializedStuffDocumentsChain,
5+
StuffDocumentsChain,
6+
SerializedLLMChain,
7+
loadQAChain,
8+
LLMChain,
9+
} from "./index";
10+
11+
import { PromptTemplate } from "../prompt";
12+
13+
import { VectorStore } from "../vectorstores/base";
14+
import { BaseLLM } from "../llms";
15+
16+
import { resolveConfigFromFile } from "../util";
17+
// eslint-disable-next-line @typescript-eslint/no-explicit-any
18+
export type LoadValues = Record<string, any>;
19+
20+
// Prompt used to condense (chat history + follow-up question) into a single
// standalone question that can be answered without the conversation context.
const question_generator_template = `Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question.

Chat History:
{chat_history}
Follow Up Input: {question}
Standalone question:`;
const question_generator_prompt = PromptTemplate.fromTemplate(question_generator_template);

// Prompt used to answer the standalone question from the retrieved documents,
// instructing the model to admit ignorance rather than fabricate an answer.
const qa_template = `Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.

{context}

Question: {question}
Helpful Answer:`;
const qa_prompt = PromptTemplate.fromTemplate(qa_template);
35+
36+
37+
/**
 * Constructor/field contract for ChatVectorDBQAChain.
 */
export interface ChatVectorDBQAChainInput {
  // Vector store queried for documents relevant to the (rephrased) question.
  vectorstore: VectorStore;
  // Number of documents to retrieve via similarity search.
  k: number;
  // Chain that stuffs the retrieved documents into the QA prompt.
  combineDocumentsChain: StuffDocumentsChain;
  // Chain that condenses chat history + follow-up into a standalone question.
  questionGeneratorChain: LLMChain;
  // Key under which the answer is placed in the chain's output values.
  outputKey: string;
  // Key under which the incoming question is expected in the input values.
  inputKey: string;
}

/**
 * JSON-serializable form of a ChatVectorDBQAChain, produced by serialize()
 * and consumed by deserialize().
 */
export type SerializedChatVectorDBQAChain = {
  _type: "chat-vector-db";
  k: number;
  combine_documents_chain: SerializedStuffDocumentsChain;
  // Optional path to an on-disk config; resolved via resolveConfigFromFile.
  combine_documents_chain_path?: string;
  question_generator: SerializedLLMChain;
};
53+
54+
export class ChatVectorDBQAChain extends BaseChain implements ChatVectorDBQAChainInput {
55+
k = 4;
56+
57+
inputKey = "question";
58+
59+
chatHistoryKey = "chat_history";
60+
61+
outputKey = "result";
62+
63+
vectorstore: VectorStore;
64+
65+
combineDocumentsChain: StuffDocumentsChain;
66+
67+
questionGeneratorChain: LLMChain;
68+
69+
constructor(fields: {
70+
vectorstore: VectorStore;
71+
combineDocumentsChain: StuffDocumentsChain;
72+
questionGeneratorChain: LLMChain;
73+
inputKey?: string;
74+
outputKey?: string;
75+
k?: number;
76+
}) {
77+
super();
78+
this.vectorstore = fields.vectorstore;
79+
this.combineDocumentsChain = fields.combineDocumentsChain;
80+
this.questionGeneratorChain = fields.questionGeneratorChain;
81+
this.inputKey = fields.inputKey ?? this.inputKey;
82+
this.outputKey = fields.outputKey ?? this.outputKey;
83+
this.k = fields.k ?? this.k;
84+
}
85+
86+
async _call(values: ChainValues): Promise<ChainValues> {
87+
if (!(this.inputKey in values)) {
88+
throw new Error(`Question key ${this.inputKey} not found.`);
89+
}
90+
if (!(this.chatHistoryKey in values)) {
91+
throw new Error(`chat history key ${this.inputKey} not found.`);
92+
}
93+
const question: string = values[this.inputKey];
94+
const chatHistory: string = values[this.chatHistoryKey];
95+
let newQuestion = question;
96+
if (chatHistory.length > 0){
97+
const result = await this.questionGeneratorChain.call({question, chat_history: chatHistory});
98+
const keys = Object.keys(result);
99+
if (keys.length === 1) {
100+
newQuestion = result[keys[0]];
101+
} else {
102+
throw new Error(
103+
"Return from llm chain has multiple values, only single values supported."
104+
);
105+
106+
}
107+
}
108+
const docs = await this.vectorstore.similaritySearch(newQuestion, this.k);
109+
const inputs = { question, input_documents: docs, chat_history: chatHistory};
110+
const result = await this.combineDocumentsChain.call(inputs);
111+
return result;
112+
}
113+
114+
_chainType() {
115+
return "chat-vector-db" as const;
116+
}
117+
118+
static async deserialize(
119+
data: SerializedChatVectorDBQAChain,
120+
values: LoadValues
121+
) {
122+
if (!("vectorstore" in values)) {
123+
throw new Error(
124+
`Need to pass in a vectorstore to deserialize VectorDBQAChain`
125+
);
126+
}
127+
const { vectorstore } = values;
128+
const serializedCombineDocumentsChain = resolveConfigFromFile<
129+
"combine_documents_chain",
130+
SerializedStuffDocumentsChain
131+
>("combine_documents_chain", data);
132+
const serializedQuestionGeneratorChain = resolveConfigFromFile<
133+
"question_generator",
134+
SerializedLLMChain
135+
>("question_generator", data);
136+
137+
return new ChatVectorDBQAChain({
138+
combineDocumentsChain: await StuffDocumentsChain.deserialize(
139+
serializedCombineDocumentsChain
140+
),
141+
questionGeneratorChain: await LLMChain.deserialize(
142+
serializedQuestionGeneratorChain
143+
),
144+
k: data.k,
145+
vectorstore,
146+
});
147+
}
148+
149+
serialize(): SerializedChatVectorDBQAChain {
150+
return {
151+
_type: this._chainType(),
152+
combine_documents_chain: this.combineDocumentsChain.serialize(),
153+
question_generator: this.questionGeneratorChain.serialize(),
154+
k: this.k,
155+
};
156+
}
157+
158+
static fromLLM(llm: BaseLLM, vectorstore: VectorStore): ChatVectorDBQAChain {
159+
const qaChain = loadQAChain(llm, qa_prompt);
160+
const questionGeneratorChain = new LLMChain({prompt: question_generator_prompt, llm});
161+
const instance = new this({ vectorstore, combineDocumentsChain: qaChain, questionGeneratorChain});
162+
return instance;
163+
}
164+
}

langchain/chains/index.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ export {
44
SerializedStuffDocumentsChain,
55
StuffDocumentsChain,
66
} from "./combine_docs_chain";
7+
export { ChatVectorDBQAChain, SerializedChatVectorDBQAChain} from "./chat_vector_db_chain";
78
export { VectorDBQAChain, SerializedVectorDBQAChain } from "./vector_db_qa";
89
export { loadChain } from "./load";
910
export { loadQAChain } from "./question_answering/load";

langchain/chains/question_answering/load.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
import { BaseLLM } from "../../llms";
22
import { LLMChain } from "../llm_chain";
33
import { StuffDocumentsChain } from "../combine_docs_chain";
4-
import { prompt } from "./stuff_prompts";
4+
import { DEFAULT_QA_PROMPT } from "./stuff_prompts";
55

66

7-
export const loadQAChain = (llm: BaseLLM) => {
7+
export const loadQAChain = (llm: BaseLLM, prompt = DEFAULT_QA_PROMPT) => {
88
const llmChain = new LLMChain({ prompt, llm });
99
const chain = new StuffDocumentsChain({llmChain});
1010
return chain;

langchain/chains/question_answering/stuff_prompts.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
/* eslint-disable */
22
import { PromptTemplate } from "../../prompt";
33

4-
export const prompt = new PromptTemplate({
4+
// Default prompt for stuff-documents QA: answer {question} strictly from
// {context}, admitting ignorance rather than fabricating an answer.
export const DEFAULT_QA_PROMPT = new PromptTemplate({
  template: "Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.\n\n{context}\n\nQuestion: {question}\nHelpful Answer:",
  inputVariables: ["context", "question"],
});
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
import { test } from "@jest/globals";
2+
import { OpenAI } from "../../llms/openai";
3+
import { PromptTemplate } from "../../prompt";
4+
import { LLMChain } from "../llm_chain";
5+
import { StuffDocumentsChain } from "../combine_docs_chain";
6+
import { ChatVectorDBQAChain } from "../chat_vector_db_chain";
7+
import { HNSWLib } from "../../vectorstores/hnswlib";
8+
import { OpenAIEmbeddings } from "../../embeddings";
9+
10+
test("Test ChatVectorDBQAChain", async () => {
11+
const model = new OpenAI({});
12+
const prompt = PromptTemplate.fromTemplate("Print {question}, and ignore {chat_history}");
13+
const vectorStore = await HNSWLib.fromTexts(
14+
["Hello world", "Bye bye", "hello nice world", "bye", "hi"],
15+
[{ id: 2 }, { id: 1 }, { id: 3 }, { id: 4 }, { id: 5 }],
16+
new OpenAIEmbeddings()
17+
);
18+
const llmChain = new LLMChain({ prompt, llm: model });
19+
const combineDocsChain = new StuffDocumentsChain({
20+
llmChain,
21+
documentVariableName: "foo",
22+
});
23+
const chain = new ChatVectorDBQAChain({
24+
combineDocumentsChain: combineDocsChain,
25+
vectorstore: vectorStore,
26+
questionGeneratorChain: llmChain,
27+
});
28+
const res = await chain.call({ question: "foo", chat_history: "bar" });
29+
console.log({ res });
30+
});
31+
32+
test("Test ChatVectorDBQAChain from LLM", async () => {
33+
const model = new OpenAI({});
34+
const vectorStore = await HNSWLib.fromTexts(
35+
["Hello world", "Bye bye", "hello nice world", "bye", "hi"],
36+
[{ id: 2 }, { id: 1 }, { id: 3 }, { id: 4 }, { id: 5 }],
37+
new OpenAIEmbeddings()
38+
);
39+
const chain = ChatVectorDBQAChain.fromLLM(model, vectorStore);
40+
const res = await chain.call({ question: "foo", chat_history: "bar" });
41+
console.log({ res });
42+
});

0 commit comments

Comments
 (0)