# main.py
import chromadb
from tqdm import tqdm
import warnings
warnings.filterwarnings("ignore")
from llm import utils as llm_utils
from llm import llm_config
from database import database_utils
from youtube_utils import get_video_ids, get_video_transcript
from llm import chroma_embedding_wrapper
from utils import utils
from utils import cli_messages
from concurrent.futures import ThreadPoolExecutor, as_completed
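
# Local modules as used below: llm provides the chat helper (llm_utils), the base
# prompt/messages (llm_config) and the Chroma embedding function wrapper;
# database holds the SQLite helpers; youtube_utils downloads video ids and
# transcripts; utils carries the text chunking, Chroma result cleanup and CLI
# message helpers.
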
def process_video(video_data, collection, youtuber):
    """Chunk one video's transcript and add each chunk to the Chroma collection."""
    full_text = f"{video_data.video_transcript}"
    chunks = utils.chunk_text(full_text)
    for i, chunk in enumerate(chunks):
        chunk_id = f"{video_data.video_id}|{i}"  # natural key: "<video_id>|<chunk_index>"
        collection.add(
            ids=[chunk_id],
            documents=[chunk],
            metadatas=[{"channel_name": youtuber, "title": video_data.video_title}]
        )
    return video_data.video_id
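

# Illustration only (not used by the script): utils.chunk_text is the repo's own
# helper and its implementation is not shown in this file. The call above only
# assumes it takes the full transcript text and returns a list of string chunks;
# a minimal word-window version under that assumption could look like this:
def _chunk_text_sketch(text, max_words=200):
    # Group whitespace-delimited words into fixed-size windows; each window
    # becomes one document (and therefore one embedding) in the collection.
    words = text.split()
    return [" ".join(words[i:i + max_words]) for i in range(0, len(words), max_words)]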


if __name__ == '__main__':
    """
    This is the main flow of the script: what you see happen in the CLI and the
    main application logic all live here.
    """
    # Print the greeting message
    cli_messages.print_greeting_message()
    youtuber = input("Enter the name of the youtuber you want to talk to: ").strip()

    # Make the SQLite connection
    conn = database_utils.create_connection(db_name='./database/talk_to_youtuber_db.sqlite')
    # Create the videos table
    database_utils.create_table(conn)

    # Start downloading the video ids
    get_video_ids.get_videos(youtuber, conn)
    # Fetch the transcripts and write them into the database
    get_video_transcript.get_transcripts_and_add_to_db(youtuber, conn)
    # Read all of the channel's videos (with transcripts) back out of the database
    all_videos = database_utils.get_videos_by_channel(conn, youtuber)
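    # Each video_data record is expected to carry (as used below and in
    # process_video) a video_id, a video_title and the full video_transcript text.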

    # Start the Chroma client; it persists its data in a local chroma/ directory
    client = chromadb.PersistentClient(path="chroma/")
    # Get or create the collection that holds the transcript chunk embeddings
    collection = client.get_or_create_collection(
        name="video_transcript_embeddings",
        metadata={"hnsw:space": "cosine"},
        embedding_function=chroma_embedding_wrapper.openai_ef
    )
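    # "hnsw:space": "cosine" makes the collection rank matches by cosine distance,
    # and chroma_embedding_wrapper.openai_ef (presumably a wrapper around Chroma's
    # OpenAI embedding function) is used to embed both the documents added above
    # and the query texts issued below.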

    # Check which videos already exist in the collection: recover each video id
    # from the natural key "<video_id>|<chunk_index>"
    collection_ids = [id.split('|')[0] for id in collection.get()['ids']]
    # Keep only videos that are not already in the collection
    non_overlapping_videos = [video_data for video_data in all_videos
                              if f"{video_data.video_id}" not in collection_ids]
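    # Design note: each process_video call spends most of its time on network
    # calls (the chunks are embedded as they are added to the collection), which
    # is presumably why a large thread pool is used instead of a simple loop.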
    with ThreadPoolExecutor(max_workers=100) as executor:
        futures = []
        for video_data in non_overlapping_videos:
            future = executor.submit(process_video, video_data, collection, youtuber)
            futures.append(future)
        # Use tqdm to show progress
        for future in tqdm(as_completed(futures), total=len(futures), desc="Embedding and adding to vector DB"):
            video_id = future.result()

    # LLM chat loop
    cli_messages.print_llm_message()
    messages = llm_config.messages
    while True:
        user_input = input("User: ")
        # RAG logic: a question containing '??' triggers retrieval from Chroma
        if '??' in user_input:
            # Query Chroma with the user's question, restricted to this channel
            matched_document = collection.query(
                query_texts=[user_input],
                n_results=1,
                where={"channel_name": youtuber}
            )
            results = utils.clean_chroma_query_most_similar_document(matched_document)
            result_video = results['id']
            result_document = results['document']
            result_title = results['title']
            cli_messages.print_intercepting_message(result_document, result_video, result_title)
            # Rework the prompt so it carries the retrieved transcript chunk
            user_input = cli_messages.intercept_string(user_input, result_document)
        messages.append({'role': 'user', 'content': user_input})
        result = llm_utils.get_openai_chat(messages)
        print(f"\nBot: {result}\n")
        messages.append({'role': 'assistant', 'content': result})
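

# Illustration only (not used by the script): cli_messages.intercept_string is the
# repo's own helper. From the call above it is assumed to merge the retrieved
# transcript chunk into the user's question before it is sent to the LLM; a
# minimal version under that assumption could look like this:
def _intercept_string_sketch(user_input, retrieved_chunk):
    # Prepend the retrieved transcript text as context and keep the original
    # question, so the model answers grounded in the youtuber's own words.
    return (
        "Use the following transcript excerpt as context:\n"
        f"{retrieved_chunk}\n\n"
        f"Question: {user_input}"
    )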