|
| 1 | +import { |
| 2 | + BaseChain, |
| 3 | + ChainValues, |
| 4 | + SerializedStuffDocumentsChain, |
| 5 | + StuffDocumentsChain, |
| 6 | + SerializedLLMChain, |
| 7 | + loadQAChain, |
| 8 | + LLMChain, |
| 9 | + } from "./index"; |
| 10 | + |
| 11 | + import { PromptTemplate } from "../prompt"; |
| 12 | + |
| 13 | + import { VectorStore } from "../vectorstores/base"; |
| 14 | + import { BaseLLM } from "../llms"; |
| 15 | + |
| 16 | + import { resolveConfigFromFile } from "../util"; |
| 17 | + // eslint-disable-next-line @typescript-eslint/no-explicit-any |
| 18 | + export type LoadValues = Record<string, any>; |
| 19 | + |
| 20 | +const question_generator_template = `Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question. |
| 21 | +
|
| 22 | +Chat History: |
| 23 | +{chat_history} |
| 24 | +Follow Up Input: {question} |
| 25 | +Standalone question:`; |
| 26 | +const question_generator_prompt = PromptTemplate.fromTemplate(question_generator_template); |
| 27 | + |
| 28 | +const qa_template = `Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer. |
| 29 | +
|
| 30 | +{context} |
| 31 | +
|
| 32 | +Question: {question} |
| 33 | +Helpful Answer:`; |
| 34 | +const qa_prompt = PromptTemplate.fromTemplate(qa_template); |
| 35 | + |
| 36 | + |
| 37 | + export interface ChatVectorDBQAChainInput { |
| 38 | + vectorstore: VectorStore; |
| 39 | + k: number; |
| 40 | + combineDocumentsChain: StuffDocumentsChain; |
| 41 | + questionGeneratorChain: LLMChain; |
| 42 | + outputKey: string; |
| 43 | + inputKey: string; |
| 44 | + } |
| 45 | + |
| 46 | + export type SerializedChatVectorDBQAChain = { |
| 47 | + _type: "chat-vector-db"; |
| 48 | + k: number; |
| 49 | + combine_documents_chain: SerializedStuffDocumentsChain; |
| 50 | + combine_documents_chain_path?: string; |
| 51 | + question_generator: SerializedLLMChain; |
| 52 | + }; |
| 53 | + |
| 54 | + export class ChatVectorDBQAChain extends BaseChain implements ChatVectorDBQAChainInput { |
| 55 | + k = 4; |
| 56 | + |
| 57 | + inputKey = "question"; |
| 58 | + |
| 59 | + chatHistoryKey = "chat_history"; |
| 60 | + |
| 61 | + outputKey = "result"; |
| 62 | + |
| 63 | + vectorstore: VectorStore; |
| 64 | + |
| 65 | + combineDocumentsChain: StuffDocumentsChain; |
| 66 | + |
| 67 | + questionGeneratorChain: LLMChain; |
| 68 | + |
| 69 | + constructor(fields: { |
| 70 | + vectorstore: VectorStore; |
| 71 | + combineDocumentsChain: StuffDocumentsChain; |
| 72 | + questionGeneratorChain: LLMChain; |
| 73 | + inputKey?: string; |
| 74 | + outputKey?: string; |
| 75 | + k?: number; |
| 76 | + }) { |
| 77 | + super(); |
| 78 | + this.vectorstore = fields.vectorstore; |
| 79 | + this.combineDocumentsChain = fields.combineDocumentsChain; |
| 80 | + this.questionGeneratorChain = fields.questionGeneratorChain; |
| 81 | + this.inputKey = fields.inputKey ?? this.inputKey; |
| 82 | + this.outputKey = fields.outputKey ?? this.outputKey; |
| 83 | + this.k = fields.k ?? this.k; |
| 84 | + } |
| 85 | + |
| 86 | + async _call(values: ChainValues): Promise<ChainValues> { |
| 87 | + if (!(this.inputKey in values)) { |
| 88 | + throw new Error(`Question key ${this.inputKey} not found.`); |
| 89 | + } |
| 90 | + if (!(this.chatHistoryKey in values)) { |
| 91 | + throw new Error(`chat history key ${this.inputKey} not found.`); |
| 92 | + } |
| 93 | + const question: string = values[this.inputKey]; |
| 94 | + const chatHistory: string = values[this.chatHistoryKey]; |
| 95 | + let newQuestion = question; |
| 96 | + if (chatHistory.length > 0){ |
| 97 | + const result = await this.questionGeneratorChain.call({question, chat_history: chatHistory}); |
| 98 | + const keys = Object.keys(result); |
| 99 | + if (keys.length === 1) { |
| 100 | + newQuestion = result[keys[0]]; |
| 101 | + } else { |
| 102 | + throw new Error( |
| 103 | + "Return from llm chain has multiple values, only single values supported." |
| 104 | + ); |
| 105 | + |
| 106 | + } |
| 107 | + } |
| 108 | + const docs = await this.vectorstore.similaritySearch(newQuestion, this.k); |
| 109 | + const inputs = { question, input_documents: docs, chat_history: chatHistory}; |
| 110 | + const result = await this.combineDocumentsChain.call(inputs); |
| 111 | + return result; |
| 112 | + } |
| 113 | + |
| 114 | + _chainType() { |
| 115 | + return "chat-vector-db" as const; |
| 116 | + } |
| 117 | + |
| 118 | + static async deserialize( |
| 119 | + data: SerializedChatVectorDBQAChain, |
| 120 | + values: LoadValues |
| 121 | + ) { |
| 122 | + if (!("vectorstore" in values)) { |
| 123 | + throw new Error( |
| 124 | + `Need to pass in a vectorstore to deserialize VectorDBQAChain` |
| 125 | + ); |
| 126 | + } |
| 127 | + const { vectorstore } = values; |
| 128 | + const serializedCombineDocumentsChain = resolveConfigFromFile< |
| 129 | + "combine_documents_chain", |
| 130 | + SerializedStuffDocumentsChain |
| 131 | + >("combine_documents_chain", data); |
| 132 | + const serializedQuestionGeneratorChain = resolveConfigFromFile< |
| 133 | + "question_generator", |
| 134 | + SerializedLLMChain |
| 135 | + >("question_generator", data); |
| 136 | + |
| 137 | + return new ChatVectorDBQAChain({ |
| 138 | + combineDocumentsChain: await StuffDocumentsChain.deserialize( |
| 139 | + serializedCombineDocumentsChain |
| 140 | + ), |
| 141 | + questionGeneratorChain: await LLMChain.deserialize( |
| 142 | + serializedQuestionGeneratorChain |
| 143 | + ), |
| 144 | + k: data.k, |
| 145 | + vectorstore, |
| 146 | + }); |
| 147 | + } |
| 148 | + |
| 149 | + serialize(): SerializedChatVectorDBQAChain { |
| 150 | + return { |
| 151 | + _type: this._chainType(), |
| 152 | + combine_documents_chain: this.combineDocumentsChain.serialize(), |
| 153 | + question_generator: this.questionGeneratorChain.serialize(), |
| 154 | + k: this.k, |
| 155 | + }; |
| 156 | + } |
| 157 | + |
| 158 | + static fromLLM(llm: BaseLLM, vectorstore: VectorStore): ChatVectorDBQAChain { |
| 159 | + const qaChain = loadQAChain(llm, qa_prompt); |
| 160 | + const questionGeneratorChain = new LLMChain({prompt: question_generator_prompt, llm}); |
| 161 | + const instance = new this({ vectorstore, combineDocumentsChain: qaChain, questionGeneratorChain}); |
| 162 | + return instance; |
| 163 | + } |
| 164 | + } |
0 commit comments