# Import libraries
from langchain.vectorstores.cassandra import Cassandra
from langchain.indexes.vectorstore import VectorStoreIndexWrapper
from langchain.llms import OpenAI
from langchain.embeddings import OpenAIEmbeddings
from langchain.schema import Document
from cassandra.cluster import Cluster
from cassandra.auth import PlainTextAuthProvider
import tiktoken
import sqlite3
import pandas as pd
import threading
import time
# Set up environment variables
import os
from dotenv import load_dotenv

load_dotenv()  # take environment variables from .env
ASTRA_DB_BUNDLE_PATH = os.getenv("ASTRA_DB_BUNDLE_PATH")
ASTRA_DB_TOKEN = os.getenv("TOKEN")
ASTRA_DB_CLIENT_ID = os.getenv("CLIENT_ID")
ASTRA_DB_CLIENT_SECRET = os.getenv("SECRET")
ASTRA_DB_KEYSPACE = os.getenv("KEYSPACE")
OPENAI_KEY = os.getenv("OPENAI_KEY")
# Define a function that counts the tokens in a string
def num_tokens_from_string(string: str, encoding_name: str) -> int:
    """Returns the number of tokens in a text string."""
    encoding = tiktoken.get_encoding(encoding_name)
    num_tokens = len(encoding.encode(string))
    return num_tokens
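
# Illustrative usage ("cl100k_base" is the encoding used for the rate-limit
# accounting further down):
#   num_tokens_from_string("hello world", "cl100k_base")  # -> 2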
# Import listing data
def get_listings():
    """Returns a list of all listings."""
    # Connect to the database.
    db = sqlite3.connect("db/listings.db")
    cursor = db.cursor()
    # Get all listings from the database.
    cursor.execute("SELECT * FROM listings")
    # Collect every row into a list.
    listings = cursor.fetchall()
    # Close the connection to the database.
    db.close()
    # Return the list of listings.
    return listings
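
# Rows are assumed to follow the (id, title, price, link) schema used for the
# DataFrame below. A one-line pandas equivalent, as a sketch:
#   pd.read_sql_query("SELECT * FROM listings", sqlite3.connect("db/listings.db"))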
# Configure Astra
cloud_config = {
    "secure_connect_bundle": ASTRA_DB_BUNDLE_PATH
}
auth_provider = PlainTextAuthProvider(ASTRA_DB_CLIENT_ID, ASTRA_DB_CLIENT_SECRET)
cluster = Cluster(cloud=cloud_config, auth_provider=auth_provider)
astra_session = cluster.connect()
llm = OpenAI(openai_api_key=OPENAI_KEY)
myEmbedding = OpenAIEmbeddings(openai_api_key=OPENAI_KEY)
print("Configuration complete")
# Create the Cassandra store (and its backing table, if it doesn't exist)
listingCassandraStore = Cassandra(
    embedding=myEmbedding,
    session=astra_session,
    keyspace=ASTRA_DB_KEYSPACE,
    table_name="listings",
)
print("Cassandra Store created")
# Retrieve listings from SQLite
items = get_listings()
print("Listings fetched")
# Create the dataframe and clean it up
df = pd.DataFrame(items, columns=["id", "title", "price", "link"])
df["price"] = df["price"].apply(lambda p: int(p.replace("$", "").replace(",", "")))
df = df[["title", "price", "link"]]
df = df[df["price"] != 0]
# Sort by descending price
df = df.sort_values("price", ascending=False)
print("Listings cleaned")
# Set the size of the chunks to publish. Larger chunks make embedding much faster.
chunkSize = 100
# OpenAI limits tokens per minute; chunking will not help with that limit, so
# batches may have to run at separate times.
# TODO: replace the notebook with a script that counts token use and cools down
sizeToEmbed = len(df)
startIndex = 0
currentTokenCount = 0
tokenLimit = 1000000
textListings = []
for i, row in df.iterrows():
    textListings.append(row["title"] + ": $" + str(row["price"]) + "\n")
print("Listings converted to text successfully")
print("Beginning embedding with chunk of size: ",str(chunkSize), " At: ", str(startIndex), " Total Size: ", str(sizeToEmbed),"/",str(len(df)))
for i in range(startIndex, sizeToEmbed, chunkSize):
    # Clamp the end index so the final chunk doesn't run past the list.
    i_end = min(i + chunkSize, sizeToEmbed)
    chunk = textListings[i:i_end]
    chunk_flat = "".join(chunk)
    currentTokenCount += num_tokens_from_string(chunk_flat, "cl100k_base")
    # If the tokens overflow, wait 60 seconds after all pushing to the db is complete
    if currentTokenCount > tokenLimit:
        print("Hit token limit: ", currentTokenCount, " tokens")
        print(" Total Size: ", str(sizeToEmbed), "/", str(len(df)))
        # task.join()
        time.sleep(60)
        print("Waiting complete")
        currentTokenCount = 0
    # Push the chunk to the vector store (threaded publishing is left disabled)
    print("pushing chunk...")
    val = listingCassandraStore.add_documents(documents=[Document(page_content=chunk_flat)])
    print(f"Added {len(val)} documents")
    # task = threading.Thread(target=listingCassandraStore.add_documents,
    #                         kwargs={"documents": [Document(page_content=chunk_flat)]})
    # task.start()
print(f"\n Embedded {sizeToEmbed}/{len(df)} listings")
# Perform a query
# vectorIndex = VectorStoreIndexWrapper(vectorstore=listingCassandraStore)
# query = "What is a normal price for a Mercedes e350?"
# answer = vectorIndex.query(question=query, llm=llm).strip()
# print(answer)
# print("Docs by relevance")
# for doc, score in listingCassandraStore.similarity_search_with_score(query, k=4):
#     print("Score:\t", score, "\n", doc)