-
Notifications
You must be signed in to change notification settings - Fork 0
/
chain.py
89 lines (72 loc) · 2.52 KB
/
chain.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import os
from langchain.chat_models import ChatOpenAI
from langchain.document_loaders import PyPDFLoader
from langchain.embeddings import OpenAIEmbeddings
from langchain.prompts import ChatPromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import MongoDBAtlasVectorSearch
from langchain_core.output_parsers import StrOutputParser
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables import (
RunnableLambda,
RunnableParallel,
RunnablePassthrough,
)
from pymongo import MongoClient
# Set DB
# Read the Atlas connection string exactly once. An unset *or empty* MONGO_URI
# is rejected here with a clear message instead of being handed to MongoClient.
# RuntimeError subclasses Exception, so existing `except Exception` handlers
# keep working.
MONGO_URI = os.environ.get("MONGO_URI")
if not MONGO_URI:
    raise RuntimeError("Missing `MONGO_URI` environment variable.")

# NOTE(review): placeholder names -- replace with the real Atlas database,
# collection and Atlas Vector Search index names before use.
DB_NAME = "database-name"
COLLECTION_NAME = "collection-name"
ATLAS_VECTOR_SEARCH_INDEX_NAME = "vector-index-name"

# Neighbours requested from the vector search (`k` in the retriever).
SEARCH_K_VALUE = 100
# `$limit` stage placed in the retriever's `post_filter_pipeline`.
POST_FILTER_PIPELINE_LIMIT = 1

# NOTE(review): dated "-0613" snapshot, which OpenAI has announced for
# retirement -- confirm it is still served, or switch to a current model.
OPENAI_MODEL_NAME = "gpt-3.5-turbo-16k-0613"

# Shared client/collection handle; `_ingest` writes its chunks into
# MONGODB_COLLECTION.
client = MongoClient(MONGO_URI)
db = client[DB_NAME]
MONGODB_COLLECTION = db[COLLECTION_NAME]
# Read from MongoDB Atlas Vector Search
# Build the store directly on the collection opened above. Calling
# `from_connection_string(MONGO_URI, ...)` would construct a second MongoClient
# (with its own connection pool) for the very same cluster and namespace.
vectorstore = MongoDBAtlasVectorSearch(
    collection=MONGODB_COLLECTION,
    embedding=OpenAIEmbeddings(disallowed_special=()),
    index_name=ATLAS_VECTOR_SEARCH_INDEX_NAME,
)

retriever = vectorstore.as_retriever(
    search_type="similarity",
    search_kwargs={
        "k": SEARCH_K_VALUE,
        # Extra aggregation stages applied to the search results: keep only
        # the first POST_FILTER_PIPELINE_LIMIT hits and drop Mongo's `_id`
        # (presumably so the ObjectId does not land in Document metadata).
        # NOTE(review): with k=100 and $limit=1 only a single chunk reaches
        # the prompt -- confirm that is intended.
        "post_filter_pipeline": [
            {"$limit": POST_FILTER_PIPELINE_LIMIT},
            {"$project": {"_id": 0}},
        ],
    },
)
# RAG prompt
# Grounding instruction for the model; the `{context}` and `{question}`
# placeholders are filled from the mapping the chain passes to the prompt.
template = (
    "Answer the question based only on the following context:\n"
    "{context}\n"
    "Question: {question}\n"
)
prompt = ChatPromptTemplate.from_template(template)
# RAG
# temperature=0 keeps answers as repeatable as the API allows.
model = ChatOpenAI(model_name=OPENAI_MODEL_NAME, temperature=0)


def _format_docs(docs) -> str:
    """Join the text of retrieved documents into one plain-text context block.

    Without this step the retriever's list of Document objects would be
    interpolated into ``{context}`` as its Python repr. The model would then see
    ``Document(page_content='...', metadata=...)`` with escaped newlines and
    metadata noise instead of the passage text.

    Args:
        docs: Retrieved documents (objects exposing ``page_content``).

    Returns:
        The documents' ``page_content`` values separated by blank lines.
    """
    return "\n\n".join(doc.page_content for doc in docs)


# The input question goes both to the retriever (whose hits are flattened to
# text) and straight through as `question`. The resulting mapping fills the
# prompt, and the model's reply is reduced to a plain string.
chain = (
    RunnableParallel(
        {"context": retriever | _format_docs, "question": RunnablePassthrough()}
    )
    | prompt
    | model
    | StrOutputParser()
)
# Add typing for input
# Declares the chain's input as a bare string. In pydantic v1, a model whose
# only field is `__root__` takes that field's type as its whole schema.
# Deliberately commented with `#` rather than a docstring: pydantic copies a
# model docstring into the generated schema's "description".
class Question(BaseModel):
    __root__: str


# Rebind `chain` to a version that declares `Question` as its input type.
chain = chain.with_types(input_type=Question)
def _ingest(url: str) -> dict:
    """Load a PDF, chunk it, and store the embedded chunks in Atlas.

    Args:
        url: Location of the PDF (the name suggests a URL); it is passed
            unchanged to ``PyPDFLoader``.

    Returns:
        Always an empty dict; the useful work is the insert side effect.
    """
    pdf_loader = PyPDFLoader(url)
    pages = pdf_loader.load()

    # Fixed-size 500-character chunks with no overlap between neighbours.
    chunker = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=0)
    chunks = chunker.split_documents(pages)

    # Embed every chunk and write it into MONGODB_COLLECTION. The store object
    # that from_documents hands back is not needed afterwards.
    MongoDBAtlasVectorSearch.from_documents(
        documents=chunks,
        embedding=OpenAIEmbeddings(disallowed_special=()),
        collection=MONGODB_COLLECTION,
        index_name=ATLAS_VECTOR_SEARCH_INDEX_NAME,
    )
    return {}


# Runnable wrapper so ingestion shares the Runnable interface used by `chain`.
ingest = RunnableLambda(_ingest)