
Improve error handling, update default filename, create file for shared logic called functions.py #4

Open · wants to merge 8 commits into base: main
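The error-handling part of this PR wraps the OpenAI chat call in an exponential-backoff retry loop (see `get_summary` in the new `functions.py` below). A minimal standalone sketch of that pattern, assuming the pre-1.0 `openai` SDK and a hypothetical `call_with_backoff` helper:

```python
import time

import openai


def call_with_backoff(msg, model, max_attempts=6):
    """Hypothetical helper: retry a chat completion with exponential backoff (1, 2, 4, ... seconds)."""
    for delay_secs in (2 ** attempt for attempt in range(max_attempts)):
        try:
            return openai.ChatCompletion.create(
                model=model,
                messages=[{"role": "user", "content": msg}],
            )
        except openai.error.OpenAIError as e:
            # Rate limits, timeouts and other transient API errors are retried after a growing delay.
            print(f"Error: {e}. Retrying in {delay_secs} seconds.")
            time.sleep(delay_secs)
    raise RuntimeError(f"OpenAI call failed after {max_attempts} attempts")
```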
130 changes: 5 additions & 125 deletions cron.py
@@ -1,35 +1,17 @@
#!/usr/bin/env python3

import json
import re
import requests
import os
from multiprocessing.pool import ThreadPool

import openai
import srt

import numpy as np
import pandas as pd

from datetime import datetime, timedelta

from langchain.embeddings import OpenAIEmbeddings
from langchain.schema import Document
from requests.exceptions import HTTPError
from sklearn.cluster import KMeans

from functions import load_inventory, select_docs, get_summary
from functions import THREAD_COUNT, OUTPUT_FOLDER_NAME

VISEXP = "https://storage.googleapis.com/data.gdeltproject.org/gdeltv3/iatv/visualexplorer"
VICUNA = "http://fc6000.sf.archive.org:8000/v1"
OUTPUT_FOLDER_NAME = "summaries"

LLM_MODELS = {
"OpenAI": "gpt-3.5-turbo",
"Vicuna": "text-embedding-ada-002",
}

IDDTRE = re.compile(r"^.+_(\d{8}_\d{6})")
VICUNA = "http://fc6000.sf.archive.org:8000/v1"

CHANNELS = [
"ESPRESO",
@@ -47,108 +29,6 @@
DT = (datetime.now() - timedelta(hours=30)).date().strftime("%Y%m%d") # date
LG = "English" # language

THREAD_COUNT = 10

def load_srt(id, lg):
lang = "" if lg == "Original" else ".en"
r = requests.get(f"{VISEXP}/{id}.transcript{lang}.srt")
r.raise_for_status()
return r.content


def load_inventory(ch, dt, lg):
r = requests.get(f"{VISEXP}/{ch}.{dt}.inventory.json")
r.raise_for_status()
return pd.json_normalize(r.json(), record_path="shows").sort_values("start_time", ignore_index=True)


def create_doc(txt, id, start, end):
return Document(page_content=txt, metadata={"id": id, "start": round(start.total_seconds()), "end": round(end.total_seconds())})


def chunk_srt(sr, id, lim=3.0):
docs = []
ln = 0
txt = ""
start = end = timedelta()
for s in srt.parse(sr.decode()):
cl = (s.end - s.start).total_seconds()
if ln + cl > lim:
if txt:
docs.append(create_doc(txt, id, start, end))
ln = cl
txt = s.content
start = s.start
end = s.end
else:
ln += cl
txt += " " + s.content
end = s.end
if txt:
docs.append(create_doc(txt, id, start, end))
return docs


def load_chunks(inventory, lg, ck):
chks = []
for i, r in inventory.iterrows():
try:
sr = load_srt(r.id, lg)
except HTTPError as _:
continue
chks += chunk_srt(sr, r.id, lim=ck)
return chks


def load_vectors(d, llm):
embed = OpenAIEmbeddings(model=LLM_MODELS[llm])
result = embed.embed_query(d.page_content)
return result

def select_docs(dt, ch, lg, lm, ck, ct):
print("loading chunks...")
docs = load_chunks(inventory, lg, ck)
docs_list = [(d,lm) for d in docs]

print("loading vectors...")
with ThreadPool(THREAD_COUNT) as pool:
vectors = pool.starmap(load_vectors, docs_list)

print("number of vectors =", len(vectors))
kmeans = KMeans(n_clusters=ct, random_state=10, n_init=10).fit(vectors)
cent = sorted([np.argmin(np.linalg.norm(vectors - c, axis=1)) for c in kmeans.cluster_centers_])
return [docs[i] for i in cent]


def id_to_time(id, start=0):
dt = IDDTRE.match(id).groups()[0]
return datetime.strptime(dt, "%Y%m%d_%H%M%S") + timedelta(seconds=start)


def get_summary(d, llm):
msg = f"""
```{d}```

Create the most prominent headline from the text enclosed in three backticks (```) above, describe it in a paragraph, assign a category to it, determine whether it is of international interest, determine whether it is an advertisement, and assign the top three keywords in the following JSON format:

{{
"title": "<TITLE>",
"description": "<DESCRIPTION>",
"category": "<CATEGORY>",
"international_interest": true|false,
"advertisement": true|false,
"keywords": ["<KEYWORD1>", "<KEYWORD2>", "<KEYWORD3>"]
}}
"""
res = openai.ChatCompletion.create(
model=LLM_MODELS[llm],
messages=[{"role": "user", "content": msg}]
)
result = json.loads(res.choices[0].message.content.strip())
result = result | d.metadata
result["transcript"] = d.page_content.strip()
return result


if LM == "Vicuna":
openai.api_key = "EMPTY"
@@ -169,7 +49,7 @@ def get_summary(d, llm):
print(f"Inventory for `{ch}` channel is not available for `{DT[:4]}-{DT[4:6]}-{DT[6:8]}` yet, try selecting another date!", icon="⚠️")

print("loading documents...")
seldocs = select_docs(DT, ch, LG, LM, CK, CT)
seldocs = select_docs(DT, ch, LG, LM, CK, CT, inventory)

print("begin summarizing each document...")

@@ -178,6 +58,6 @@ def get_summary(d, llm):
summaries = pool.starmap(get_summary, summary_args)

print("writing results...")
with open(f"{OUTPUT_FOLDER_NAME}/{ch}-{DT}-{LM}-{LG}.json", 'w+') as f:
with open(f"{OUTPUT_FOLDER_NAME}/{DT}-{ch}-{LM}-{LG}.json", 'w+') as f:
f.write(json.dumps(summaries, indent=2))
print(f"finished {ch}")
134 changes: 134 additions & 0 deletions functions.py
@@ -0,0 +1,134 @@
#!/usr/bin/env python3

import json
import logging
import re
import requests
from multiprocessing.pool import ThreadPool
import os
import time

import openai
import srt

import numpy as np
import pandas as pd

from datetime import datetime, timedelta

from langchain.embeddings import OpenAIEmbeddings
from langchain.llms import OpenAI
from langchain.schema import Document
from requests.exceptions import HTTPError
from sklearn.cluster import KMeans


THREAD_COUNT = 10
OUTPUT_FOLDER_NAME = "summaries"
VISEXP = "https://storage.googleapis.com/data.gdeltproject.org/gdeltv3/iatv/visualexplorer"
LLM_MODELS = {
"OpenAI": "gpt-3.5-turbo",
"Vicuna": "text-embedding-ada-002",
}


def load_srt(id, lg):
lang = "" if lg == "Original" else ".en"
r = requests.get(f"{VISEXP}/{id}.transcript{lang}.srt")
r.raise_for_status()
return r.content

def load_inventory(ch, dt, lg):
r = requests.get(f"{VISEXP}/{ch}.{dt}.inventory.json")
r.raise_for_status()
return pd.json_normalize(r.json(), record_path="shows").sort_values("start_time", ignore_index=True)

def create_doc(txt, id, start, end):
return Document(page_content=txt, metadata={"id": id, "start": round(start.total_seconds()), "end": round(end.total_seconds())})

def chunk_srt(sr, id, lim=3.0):
docs = []
ln = 0
txt = ""
start = end = timedelta()
for s in srt.parse(sr.decode()):
cl = (s.end - s.start).total_seconds()
if ln + cl > lim:
if txt:
docs.append(create_doc(txt, id, start, end))
ln = cl
txt = s.content
start = s.start
end = s.end
else:
ln += cl
txt += " " + s.content
end = s.end
if txt:
docs.append(create_doc(txt, id, start, end))
return docs

def load_chunks(inventory, lg, ck):
# msg = "Loading SRT files..."
# prog = st.progress(0.0, text=msg)
chks = []
sz = len(inventory)
for i, r in inventory.iterrows():
try:
sr = load_srt(r.id, lg)
except HTTPError as _:
continue
chks += chunk_srt(sr, r.id, lim=ck)
# prog.progress((i+1)/sz, text=msg)
# prog.empty()
return chks


def load_vectors(doc, llm):
embed = OpenAIEmbeddings(model=LLM_MODELS[llm])
return embed.embed_query(doc.page_content)

def select_docs(dt, ch, lg, lm, ck, ct, inventory):
print("loading chunks...")
docs = load_chunks(inventory, lg, ck)
docs_list = [(d,lm) for d in docs]

print("loading vectors...")
with ThreadPool(THREAD_COUNT) as pool:
vectors = pool.starmap(load_vectors, docs_list)

print("number of vectors =", len(vectors))
kmeans = KMeans(n_clusters=ct, random_state=10, n_init=10).fit(vectors)
cent = sorted([np.argmin(np.linalg.norm(vectors - c, axis=1)) for c in kmeans.cluster_centers_])
return [docs[i] for i in cent]

def get_summary(d, llm):
msg = f"""
```{d}```

Create the most prominent headline from the text enclosed in three backticks (```) above, describe it in a paragraph, assign a category to it, determine whether it is of international interest, determine whether it is an advertisement, and assign the top three keywords in the following JSON format:

{{
"title": "<TITLE>",
"description": "<DESCRIPTION>",
"category": "<CATEGORY>",
"international_interest": true|false,
"advertisement": true|false,
"keywords": ["<KEYWORD1>", "<KEYWORD2>", "<KEYWORD3>"]
}}
"""
result = None  # fall back to None if every retry attempt fails, instead of raising an unbound-name error below
for delay_secs in (2**x for x in range(0,6)):
try:
res = openai.ChatCompletion.create(
model=LLM_MODELS[llm],
messages=[{"role": "user", "content": msg}]
)
result = json.loads(res.choices[0].message.content.strip())
result = result | d.metadata
result["transcript"] = d.page_content.strip()
break
except openai.error.OpenAIError as e:
print(f"Error: {e}. Retrying in {round(delay_secs, 2)} seconds.")
time.sleep(delay_secs)
continue
return result
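With the shared logic extracted into `functions.py`, other entry points can reuse the same pipeline. A minimal usage sketch (the parameter values are illustrative, and a valid `OPENAI_API_KEY` plus network access are assumed):

```python
import json
from datetime import datetime, timedelta
from multiprocessing.pool import ThreadPool

from functions import (THREAD_COUNT, OUTPUT_FOLDER_NAME,
                       load_inventory, select_docs, get_summary)

ch = "ESPRESO"                 # channel (illustrative)
LG, LM = "English", "OpenAI"   # language and LLM backend
CK, CT = 30.0, 10              # chunk length (seconds) and cluster count (illustrative)
DT = (datetime.now() - timedelta(hours=30)).date().strftime("%Y%m%d")

inventory = load_inventory(ch, DT, LG)                    # shows available for this channel/date
seldocs = select_docs(DT, ch, LG, LM, CK, CT, inventory)  # cluster chunks, keep representative docs

with ThreadPool(THREAD_COUNT) as pool:
    summaries = pool.starmap(get_summary, [(d, LM) for d in seldocs])

with open(f"{OUTPUT_FOLDER_NAME}/{DT}-{ch}-{LM}-{LG}.json", "w+") as f:
    f.write(json.dumps(summaries, indent=2))
```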