diff --git a/governenv/llm.py b/governenv/llm.py
index afda960..0583128 100644
--- a/governenv/llm.py
+++ b/governenv/llm.py
@@ -3,6 +3,7 @@
 """
 
 from openai import OpenAI
+from tenacity import retry, stop_after_attempt, wait_exponential
 
 from governenv.settings import OPENAI_API_KEY
 
@@ -36,6 +37,9 @@ def _build_prompt(
 
         return prompt
 
+    @retry(
+        stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10)
+    )
     def __call__(
         self,
         message: str,
diff --git a/governenv/prompts.py b/governenv/prompts.py
index 44323fa..b4e2816 100644
--- a/governenv/prompts.py
+++ b/governenv/prompts.py
@@ -2,7 +2,7 @@
 Prompts and instructions
 """
 
-IDF_PROMPT = """Given the following website HTTP response, determine \
+IDF_PROMPT = """Given the following HTTP website response, determine \
 whether it satisfies the following criteria. If it satisfies all three criteria, \
 return "Yes". Otherwise, return "No".
 
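Review note: the @retry decorator added to ChatGPT.__call__ above retries a failed LLM call up to three times with exponential backoff (roughly multiplier * 2^attempt seconds, clamped to the 4-10 second window). A minimal, self-contained sketch of the same tenacity pattern, with a hypothetical flaky_request standing in for the OpenAI call:

    from tenacity import retry, stop_after_attempt, wait_exponential

    attempts = {"n": 0}

    @retry(
        stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10)
    )
    def flaky_request() -> str:
        """Hypothetical stand-in for ChatGPT.__call__: fails twice, then succeeds."""
        attempts["n"] += 1
        if attempts["n"] < 3:
            raise RuntimeError("transient API error")
        return "ok"

    print(flaky_request())  # two retries, then prints "ok"

If all three attempts fail, tenacity raises RetryError wrapping the last exception, so callers of ChatGPT.__call__ still see a failure after the final attempt.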
diff --git a/governenv/utils.py b/governenv/utils.py
new file mode 100644
index 0000000..88aed9d
--- /dev/null
+++ b/governenv/utils.py
@@ -0,0 +1,23 @@
+"""
+Utility functions
+"""
+
+from governenv.constants import EXKW
+
+
+def kw_filt(data: dict[str, str]) -> dict[str, str]:
+    """
+    Function to filter discussions based on keywords
+    """
+
+    return {k: v for k, v in data.items() if not any([i in v for i in EXKW])}
+
+
+def slash_filt(data: dict[str, str]) -> dict[str, str]:
+    """
+    Function to filter discussions based on slashes
+    """
+
+    # typically, a discussion has at least 4 levels of slashes
+    # if the slash count is less than 4, remove the discussion
+    return {k: v for k, v in data.items() if v.count("/") >= 4}
diff --git a/scripts/fetch_html.py b/scripts/fetch_html.py
new file mode 100644
index 0000000..8775897
--- /dev/null
+++ b/scripts/fetch_html.py
@@ -0,0 +1,55 @@
+"""
+Fetch the http response of the discussion links
+"""
+
+import pickle
+import time
+from glob import glob
+
+import requests
+from tqdm import tqdm
+
+from governenv.constants import DATA_DIR, HEADERS
+from governenv.utils import kw_filt, slash_filt
+
+
+def fetch_http_response(url: str, timeout: int = 10) -> str:
+    """
+    Fetches the HTTP response from a given URL.
+    """
+    response = requests.get(url, headers=HEADERS, timeout=timeout)
+
+    # if the status_code is not 200, raise an error
+    if response.status_code != 200:
+        raise Exception(f"Status code: {response.status_code}")
+
+    return response.text
+
+
+if __name__ == "__main__":
+    # unpickle data_unique
+    with open(DATA_DIR / "discussion_links.pkl", "rb") as f:
+        data_unique = pickle.load(f)
+    print(f"Data length before filtering: {len(data_unique)}")
+
+    # filter discussions
+    data_unique = slash_filt(kw_filt(data_unique))
+    print(f"Data length after filtering: {len(data_unique)}")
+
+    fetched_data = [
+        _.split("/")[-1].split(".")[0] for _ in glob(str(DATA_DIR / "html" / "*.html"))
+    ]
+
+    # fetch http response
+    for i, (k, v) in tqdm(enumerate(data_unique.items()), total=len(data_unique)):
+        if str(i) in fetched_data:
+            continue
+        try:
+            # save the html
+            html = fetch_http_response(v)
+            with open(DATA_DIR / "html_200" / f"{i}.html", "w", encoding="utf-8") as f:
+                f.write(html)
+        except Exception as e:
+            print(f"Error fetching {v}: {e}")
+
+        time.sleep(2)
diff --git a/scripts/fetch_http.py b/scripts/fetch_http.py
deleted file mode 100644
index c9d67a3..0000000
--- a/scripts/fetch_http.py
+++ /dev/null
@@ -1,92 +0,0 @@
-"""
-Fetch the http response of the discussion links
-"""
-
-import asyncio
-import gzip
-import json
-import pickle
-
-import aiohttp
-import requests
-
-from governenv.constants import DATA_DIR, EXKW, HEADERS
-
-
-def kw_filt(data: dict[str, str]) -> dict[str, str]:
-    """
-    Function to filter discussions based on keywords
-    """
-
-    return {k: v for k, v in data.items() if not any([i in v for i in EXKW])}
-
-
-def slash_filt(data: dict[str, str]) -> dict[str, str]:
-    """
-    Function to filter discussions based on slashes
-    """
-
-    # typically, a discussion has at least 4 levels of slashes
-    # if the slash count is less than 4, remove the discussion
-    return {k: v for k, v in data.items() if v.count("/") >= 4}
-
-
-def fetch_http_response(url: str, timeout: int = 60) -> str:
-    """
-    Fetches the HTTP response from a given URL.
-    """
-    return requests.get(url, headers=HEADERS, timeout=timeout).text
-
-
-async def fetch(session, url: str) -> str:
-    """
-    Fetch the HTTP response from a given URL
-    """
-    async with session.get(
-        url, ssl=True
-    ) as response:  # Use ssl=True for default SSL context
-        return (
-            await response.text()
-        )  # Use .text() for HTML/text response or .json() for JSON
-
-
-async def fetch_all(
-    urls: list[str], time_out: int = 60
-) -> dict[str, str | BaseException]:
-    """
-    Fetch all HTTP responses from a list of URLs
-    """
-    async with aiohttp.ClientSession(
-        timeout=aiohttp.ClientTimeout(total=time_out)
-    ) as session:  # No need for loop argument
-        results = await asyncio.gather(
-            *[fetch(session, url) for url in urls], return_exceptions=True
-        )
-        return {
-            url: html
-            for url, html in zip(urls, results)
-            if not isinstance(html, Exception)
-        }
-
-
-if __name__ == "__main__":
-    # unpickle data_unique
-    with open(DATA_DIR / "discussion_links.pkl", "rb") as f:
-        data_unique = pickle.load(f)
-    print(f"Data length before filtering: {len(data_unique)}")
-
-    # filter discussions
-    data_unique = slash_filt(kw_filt(data_unique))
-    print(f"Data length after filtering: {len(data_unique)}")
-
-    urls = list(data_unique.values())
-    htmls = asyncio.run(fetch_all(urls))  # Use asyncio.run for a cleaner main loop
-
-    # print the length of the correct responses
-    print(htmls)
-    print(len(htmls))
-
-    # save the htmls to jsonl.gz
-    with gzip.open(DATA_DIR / "htmls.jsonl.gz", "wt") as f:
-        for url, html in htmls.items():
-            f.write(json.dumps({"url": url, "html": html}) + "\n")
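Review note: a self-contained sketch of how the two filters now living in governenv/utils.py compose. EXKW is assumed to be a list of excluded keyword substrings; the stand-in value and the toy URLs below are invented for illustration:

    # Stand-in for governenv.constants.EXKW; the real list is defined in the repo.
    EXKW = ["twitter.com"]

    def kw_filt(data: dict[str, str]) -> dict[str, str]:
        # drop links whose URL contains any excluded keyword
        return {k: v for k, v in data.items() if not any(i in v for i in EXKW)}

    def slash_filt(data: dict[str, str]) -> dict[str, str]:
        # keep only URLs with at least 4 slashes, i.e. a path below the forum root
        return {k: v for k, v in data.items() if v.count("/") >= 4}

    links = {
        "a": "https://forum.example.org/t/proposal-1/123",  # kept
        "b": "https://twitter.com/someone/status/1",  # dropped by kw_filt
        "c": "https://forum.example.org/",  # dropped by slash_filt
    }
    print(slash_filt(kw_filt(links)))
    # {'a': 'https://forum.example.org/t/proposal-1/123'}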
diff --git a/scripts/process_html.py b/scripts/process_html.py
new file mode 100644
index 0000000..e31cbd9
--- /dev/null
+++ b/scripts/process_html.py
@@ -0,0 +1,68 @@
+"""
+Script to aggregate all the html files in the data folder into a single jsonl file
+"""
+
+import gzip
+import json
+import pickle
+from glob import glob
+
+from bs4 import BeautifulSoup
+from tqdm import tqdm
+
+from governenv.constants import DATA_DIR
+from governenv.utils import kw_filt, slash_filt
+
+
+def distill_html(html: str) -> str:
+    """
+    Function to distill the html
+    """
+    # Parse the HTML
+    soup = BeautifulSoup(html, "html.parser")
+
+    # Remove irrelevant tags (scripts, styles, footers, navs, etc.)
+    for tag in soup(
+        ["script", "style", "header", "footer", "nav", "aside", "form", "link", "meta"]
+    ):
+        tag.decompose()
+
+    # Extract text content from discussion-relevant tags
+    relevant_content = soup.find_all(["div", "p", "li", "article", "section"])
+
+    # Combine and clean the text
+    cleaned_text = "\n\n".join(
+        tag.get_text(strip=True) for tag in relevant_content if tag.get_text(strip=True)
+    )
+
+    return cleaned_text
+
+
+if __name__ == "__main__":
+
+    # unpickle data_unique
+    with open(DATA_DIR / "discussion_links.pkl", "rb") as f:
+        data_unique = pickle.load(f)
+    print(f"Data length before filtering: {len(data_unique)}")
+
+    # filter discussions
+    data_unique = slash_filt(kw_filt(data_unique))
+    print(f"Data length after filtering: {len(data_unique)}")
+
+    fetched_data = [
+        _.split("/")[-1].split(".")[0] for _ in glob(str(DATA_DIR / "html" / "*.html"))
+    ]
+
+    # save the html
+    with gzip.open(DATA_DIR / "html.jsonl.gz", "wt") as gz_f:
+        for i, (k, v) in tqdm(enumerate(data_unique.items())):
+            if str(i) in fetched_data:
+                # save the html
+                with open(DATA_DIR / "html" / f"{i}.html", "r", encoding="utf-8") as f:
+                    html = f.read()
+
+                # distill the html
+                html_distilled = distill_html(html)
+
+                json.dump({"url": v, "html": html_distilled}, gz_f)
+                gz_f.write("\n")
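Review note: a small usage sketch of the decompose-then-extract approach used by distill_html above (the HTML fragment is invented); boilerplate tags are removed and only text from content tags survives:

    from bs4 import BeautifulSoup

    sample = (
        "<html><head><script>track();</script></head>"
        "<body><nav>Home | About</nav>"
        "<div><p>Proposal looks reasonable to me.</p></div>"
        "<footer>footer text</footer></body></html>"
    )
    soup = BeautifulSoup(sample, "html.parser")
    for tag in soup(["script", "style", "header", "footer", "nav", "aside"]):
        tag.decompose()
    print(soup.get_text(strip=True))  # Proposal looks reasonable to me.

One observation on the design: distill_html joins the text of every matching div/p/li/article/section, so text nested inside several matching tags (a <p> inside a <div>) can appear more than once in the distilled output.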
diff --git a/scripts/process_identify_html.py b/scripts/process_identify_html.py
new file mode 100644
index 0000000..3280572
--- /dev/null
+++ b/scripts/process_identify_html.py
@@ -0,0 +1,58 @@
+"""
+Script to identify whether the html files meet the criteria
+"""
+
+import gzip
+import json
+import math
+
+
+import tiktoken
+from tqdm import tqdm
+
+from governenv.constants import DATA_DIR
+from governenv.llm import ChatGPT
+from governenv.prompts import IDF_INSTRUCT, IDF_PROMPT
+
+tokenizer = tiktoken.encoding_for_model("gpt-4o")
+
+
+idf_dict = {}
+
+llm = ChatGPT()
+
+if __name__ == "__main__":
+
+    with gzip.open(DATA_DIR / "html.jsonl.gz", "rt") as gz_f:
+        for idx, line in tqdm(enumerate(gz_f)):
+            data = json.loads(line.strip())
+            url = data["url"]
+            html = data["html"]
+
+            try:
+                # identify if the html meets the criteria
+                idf_res = llm(
+                    instruction=IDF_INSTRUCT,
+                    message=IDF_PROMPT.format(http_response=html),
+                    logprobs=True,
+                    top_logprobs=2,
+                )
+
+                idf, prob = idf_res if isinstance(idf_res, tuple) else (idf_res, None)
+
+                first_prob = prob[0]
+                yes_prob = (
+                    math.exp(first_prob.logprob)
+                    if "Yes" in first_prob.token
+                    else 1 - math.exp(first_prob.logprob)
+                )
+
+                idf_dict[url] = {
+                    "idf": idf,
+                    "yes_prob": yes_prob,
+                }
+            except Exception as e:
+                print(f"Error processing {url}: {e}")
+
+    with open(DATA_DIR / "idf.json", "w", encoding="utf-8") as f:
+        json.dump(idf_dict, f, indent=2)
diff --git a/scripts/process_sentiment.py b/scripts/process_sentiment.py
index 851b1b2..505867e 100644
--- a/scripts/process_sentiment.py
+++ b/scripts/process_sentiment.py
@@ -8,7 +8,7 @@
 from governenv.constants import DATA_DIR, HEADERS
 from governenv.llm import ChatGPT
 from governenv.prompts import EVAL_INSTRUCT, EVAL_PROMPT, IDF_INSTRUCT, IDF_PROMPT
-
+from governenv.utils import kw_filt, slash_filt
 
 
 if __name__ == "__main__":
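Review note: the yes_prob computation in scripts/process_identify_html.py converts the first generated token's log probability back into a probability that the answer is "Yes". A worked sketch with invented numbers:

    import math

    # Hypothetical top-logprob entry for the first generated token:
    # the model emitted "Yes" with logprob -0.105, i.e. about 90% confidence.
    token, logprob = "Yes", -0.105
    yes_prob = math.exp(logprob) if "Yes" in token else 1 - math.exp(logprob)
    print(round(yes_prob, 3))  # 0.9

    # If the top token had been "No" with the same logprob, the complement is used:
    token, logprob = "No", -0.105
    yes_prob = math.exp(logprob) if "Yes" in token else 1 - math.exp(logprob)
    print(round(yes_prob, 3))  # 0.1

This mirrors the script's logic and assumes, as the script does, that the ChatGPT wrapper returns a (text, top_logprobs) tuple when logprobs=True and that prob[0] is the entry for the first token.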