Commit
launch default browser instance
itsOwen committed Aug 30, 2024
1 parent c60fc06 commit 63fcb91
Showing 4 changed files with 228 additions and 65 deletions.
14 changes: 11 additions & 3 deletions app/streamlit_web_scraper_chat.py
@@ -1,10 +1,18 @@
 import asyncio
+import streamlit as st
 from src.web_extractor import WebExtractor
+from src.scrapers.playwright_scraper import ScraperConfig
 
 class StreamlitWebScraperChat:
-    def __init__(self, model_name):
-        self.web_extractor = WebExtractor(model_name=model_name)
+    def __init__(self, model_name, scraper_config: ScraperConfig = None):
+        self.web_extractor = WebExtractor(model_name=model_name, scraper_config=scraper_config)
 
     def process_message(self, message: str) -> str:
-        return asyncio.run(self.web_extractor.process_query(message))
+        async def process_with_progress():
+            progress_placeholder = st.empty()
+            progress_placeholder.text("Connecting to browser...")
+            result = await self.web_extractor.process_query(message, progress_callback=progress_placeholder.text)
+            progress_placeholder.empty()
+            return result
+        return asyncio.run(process_with_progress())
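Note: ScraperConfig is imported from src/scrapers/playwright_scraper.py, one of the changed files whose diff did not load on this page. A minimal sketch of what the config might look like, assuming a plain dataclass with exactly the fields that main.py passes below (use_current_browser, headless, max_retries, delay_after_load, debug, wait_for):

# Hypothetical sketch; the real ScraperConfig is defined in
# src/scrapers/playwright_scraper.py, which is not shown in this view.
from dataclasses import dataclass

@dataclass
class ScraperConfig:
    use_current_browser: bool = False   # attach to an already-running browser instead of launching one
    headless: bool = True               # only relevant when launching a fresh browser
    max_retries: int = 3                # retry attempts for failed page loads
    delay_after_load: int = 5           # seconds to wait after load, for JS-heavy pages
    debug: bool = False                 # verbose scraper logging
    wait_for: str = 'domcontentloaded'  # Playwright wait_until target for page.goto()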
113 changes: 88 additions & 25 deletions main.py
@@ -13,6 +13,10 @@
 from io import BytesIO
 import re
 from src.utils.google_sheets_utils import SCOPES, get_redirect_uri, display_google_sheets_button, initiate_google_auth
+from src.scrapers.playwright_scraper import ScraperConfig
+import time
+from urllib.parse import urlparse
+import atexit
 
 def handle_oauth_callback():
     if 'code' in st.query_params:
@@ -57,9 +61,23 @@ def safe_process_message(web_scraper_chat, message):
     if message is None or message.strip() == "":
         return "I'm sorry, but I didn't receive any input. Could you please try again?"
     try:
+        progress_placeholder = st.empty()
+        progress_placeholder.text("Initializing scraper...")
+
+        start_time = time.time()
         response = web_scraper_chat.process_message(message)
+        end_time = time.time()
+
+        progress_placeholder.text(f"Scraping completed in {end_time - start_time:.2f} seconds.")
+
+        st.write("Debug: Response type:", type(response))
+
+        if isinstance(response, str):
+            if "Error:" in response:
+                st.error(response)
+            else:
+                st.write("Debug: Response content:", response[:500] + "..." if len(response) > 500 else response)
+
         if isinstance(response, tuple):
+            st.write("Debug: Response is a tuple")
             if len(response) == 2 and isinstance(response[1], pd.DataFrame):
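The progress reporting added in this hunk hinges on a single st.empty() placeholder whose contents are overwritten in place rather than appended. A standalone sketch of the pattern (step names are illustrative, not from this commit):

import time
import streamlit as st

placeholder = st.empty()  # one slot in the page, rewritten on each update

start = time.time()
for step in ("Connecting to browser...", "Loading page...", "Extracting content..."):
    placeholder.text(step)  # replaces the previous status line
    time.sleep(1)           # stand-in for real scraping work

placeholder.text(f"Scraping completed in {time.time() - start:.2f} seconds.")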
@@ -69,7 +87,7 @@ def safe_process_message(web_scraper_chat, message):
                 st.code(csv_string, language="csv")
                 st.text("Interactive Table:")
                 st.dataframe(df)
 
                 csv_buffer = BytesIO()
                 df.to_csv(csv_buffer, index=False)
                 csv_buffer.seek(0)
@@ -79,46 +97,57 @@ def safe_process_message(web_scraper_chat, message):
                     file_name="data.csv",
                     mime="text/csv"
                 )
 
                 return csv_string
             elif len(response) == 2 and isinstance(response[0], BytesIO):
+                st.write("Debug: Excel data detected")
                 excel_buffer, df = response
                 st.text("Excel Data:")
                 st.dataframe(df)
 
                 excel_buffer.seek(0)
                 st.download_button(
                     label="Download Original Excel file",
                     data=excel_buffer,
                     file_name="data_original.xlsx",
                     mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
                 )
 
                 excel_data = BytesIO()
                 with pd.ExcelWriter(excel_data, engine='xlsxwriter') as writer:
                     df.to_excel(writer, index=False, sheet_name='Sheet1')
                 excel_data.seek(0)
 
                 st.download_button(
                     label="Download Excel (from DataFrame)",
                     data=excel_data,
                     file_name="data_from_df.xlsx",
                     mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
                 )
 
                 return ("Excel data displayed and available for download.", excel_buffer)
         elif isinstance(response, pd.DataFrame):
+            st.write("Debug: Response is a DataFrame")
             st.text("Data:")
             st.dataframe(response)
 
             csv_buffer = BytesIO()
             response.to_csv(csv_buffer, index=False)
             csv_buffer.seek(0)
             st.download_button(
                 label="Download CSV",
                 data=csv_buffer,
                 file_name="data.csv",
                 mime="text/csv"
             )
 
             return "DataFrame displayed and available for download as CSV."
         else:
-            st.write("Debug: Response is not a tuple")
+            st.write("Debug: Response is not a tuple or DataFrame")
 
             return response
     except AttributeError as e:
         if "'NoneType' object has no attribute 'lower'" in str(e):
             return "I encountered an issue while processing your request. It seems like I received an unexpected empty value. Could you please try rephrasing your input?"
         else:
             raise e
     except Exception as e:
+        st.write("Debug: Exception occurred:", str(e))
+        st.error(f"An error occurred during scraping: {str(e)}")
         return f"An unexpected error occurred: {str(e)}. Please try again or contact support if the issue persists."
 
 def get_date_group(date_str):
@@ -144,9 +173,23 @@ def initialize_web_scraper_chat(url=None):
         model = OllamaModel(st.session_state.selected_model[7:])
     else:
         model = st.session_state.selected_model
-    web_scraper_chat = StreamlitWebScraperChat(model_name=model)
+
+    scraper_config = ScraperConfig(
+        use_current_browser=st.session_state.use_current_browser,
+        headless=not st.session_state.use_current_browser,
+        max_retries=3,
+        delay_after_load=5,
+        debug=True,
+        wait_for='domcontentloaded'
+    )
+
+    web_scraper_chat = StreamlitWebScraperChat(model_name=model, scraper_config=scraper_config)
     if url:
         web_scraper_chat.process_message(url)
+
+        website_name = get_website_name(url)
+        st.session_state.chat_history[st.session_state.current_chat_id]["name"] = website_name
+
     return web_scraper_chat
 
 async def list_ollama_models():
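The commit title, "launch default browser instance", is carried by the use_current_browser flag above; the Playwright side that honors it lives in src/scrapers/playwright_scraper.py, which did not load on this page. One plausible implementation, assuming Playwright's connect_over_cdp API and a Chrome instance started with --remote-debugging-port=9222:

# Hypothetical sketch of the branch playwright_scraper.py might take on
# config.use_current_browser; the actual file is not shown in this view.
from playwright.async_api import async_playwright

async def get_browser(config):
    p = await async_playwright().start()
    if config.use_current_browser:
        # Attach to an already-running Chrome started with:
        #   chrome --remote-debugging-port=9222
        # Reusing the user's real browser (cookies, sessions) helps when a
        # site blocks fresh automated browsers; it cannot work inside Docker,
        # where the container cannot reach the host's browser by default.
        return await p.chromium.connect_over_cdp("http://localhost:9222")
    # Otherwise launch a fresh instance, headless unless configured otherwise.
    return await p.chromium.launch(headless=config.headless)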
@@ -164,6 +207,13 @@ def get_image_base64(image_path):
     with open(image_path, "rb") as image_file:
         return base64.b64encode(image_file.read()).decode()
 
+def get_website_name(url: str) -> str:
+    parsed_url = urlparse(url)
+    domain = parsed_url.netloc
+    if domain.startswith('www.'):
+        domain = domain[4:]
+    return domain.split('.')[0].capitalize()
+
 def render_message(role, content, avatar_path):
     message_class = "user-message" if role == "user" else "assistant-message"
     avatar_base64 = get_image_base64(avatar_path)
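For reference, how the new get_website_name helper behaves on a few illustrative URLs (example inputs, not from the commit):

get_website_name("https://www.example.com/path")   # -> "Example"
get_website_name("https://news.ycombinator.com")   # -> "News" (first label wins, not the registrable domain)
get_website_name("http://github.com/itsOwen")      # -> "Github"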
@@ -227,6 +277,12 @@ def display_message_with_sheets_upload(message, message_index):
     else:
         st.markdown(str(content))
 
+def cleanup():
+    if 'web_scraper_chat' in st.session_state and st.session_state.web_scraper_chat:
+        del st.session_state.web_scraper_chat
+
+atexit.register(cleanup)
+
 def main():
 
     st.set_page_config(
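The atexit hook above runs when the Streamlit server process exits; deleting the session-state entry drops the last reference to the scraper so its resources can be reclaimed. A minimal illustration of the registration pattern (FakeScraper is a stand-in, not part of this codebase):

import atexit

class FakeScraper:
    def close(self):
        print("browser closed")

state = {"web_scraper_chat": FakeScraper()}

def cleanup():
    # Runs once, at interpreter shutdown.
    if "web_scraper_chat" in state and state["web_scraper_chat"]:
        state["web_scraper_chat"].close()
        del state["web_scraper_chat"]

atexit.register(cleanup)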
@@ -270,12 +326,14 @@ def main():
         ollama_models = st.session_state.get('ollama_models', [])
         all_models = default_models + [f"ollama:{model}" for model in ollama_models]
         selected_model = st.selectbox("Choose a model", all_models, index=all_models.index(st.session_state.selected_model) if st.session_state.selected_model in all_models else 0)
 
         if selected_model != st.session_state.selected_model:
             st.session_state.selected_model = selected_model
             st.session_state.web_scraper_chat = None
             st.rerun()
 
+        st.session_state.use_current_browser = st.checkbox("Use Current Browser (No Docker)", value=False, help="Works natively; doesn't work with Docker. If a website is blocking your browser, use this option to attach to your current browser instead of opening a new one.")
+
         if st.button("Refresh Ollama Models"):
             with st.spinner("Fetching Ollama models..."):
                 st.session_state.ollama_models = asyncio.run(list_ollama_models())
@@ -304,22 +362,19 @@ def main():
         for date_group, chats in grouped_chats.items():
             st.markdown(f"<div class='date-group'>{date_group}</div>", unsafe_allow_html=True)
             for chat_id, chat_data in chats:
-                messages = chat_data['messages']
-                if messages:
-                    button_label = chat_data.get('name', f"{messages[0]['content'][:25]}...")
-                else:
-                    button_label = chat_data.get('name', "🗨️ Empty Chat")
+                button_label = chat_data.get('name', "🗨️ Unnamed Chat")
 
                 col1, col2 = st.columns([0.85, 0.15])
 
                 with col1:
                     if st.button(button_label, key=f"history_{chat_id}", use_container_width=True):
                         st.session_state.current_chat_id = chat_id
+                        messages = chat_data['messages']
                         last_url = get_last_url_from_chat(messages)
                         if last_url and not st.session_state.web_scraper_chat:
                             st.session_state.web_scraper_chat = initialize_web_scraper_chat(last_url)
                         st.rerun()
 
                 with col2:
                     if st.button("🗑️", key=f"delete_{chat_id}"):
                         del st.session_state.chat_history[chat_id]
@@ -370,17 +425,25 @@ def main():
 
     if prompt:
         st.session_state.chat_history[st.session_state.current_chat_id]["messages"].append({"role": "user", "content": prompt})
 
         if not st.session_state.web_scraper_chat:
             st.session_state.web_scraper_chat = initialize_web_scraper_chat()
 
+        if prompt.lower().startswith("http"):
+            website_name = get_website_name(prompt)
+            st.session_state.chat_history[st.session_state.current_chat_id]["name"] = website_name
+            st.info(f"Scraping {website_name}... This may take a moment.")
+
         with st.chat_message("assistant"):
             try:
                 full_response = loading_animation(
                     safe_process_message,
                     st.session_state.web_scraper_chat,
                     prompt
                 )
+                if isinstance(full_response, str) and not full_response.startswith("Error:"):
+                    st.success("Scraping completed successfully!")
+
                 st.write("Debug: Full response type:", type(full_response))
                 if full_response is not None:
                     if isinstance(full_response, tuple) and len(full_response) == 2 and isinstance(full_response[1], BytesIO):
@@ -390,7 +453,7 @@ def main():
                         save_chat_history(st.session_state.chat_history)
             except Exception as e:
                 st.error(f"An unexpected error occurred: {str(e)}")
 
         save_chat_history(st.session_state.chat_history)
         st.rerun()
 
     st.markdown(
[Diffs for the 2 remaining changed files did not load in this view.]
