
Commit 69512f8

Merge pull request #31 from run-llama/clelia/utils-refactoring-and-document-management
Refactoring utils.py and creating a document management UI
2 parents 34b4460 + ec7479c commit 69512f8

17 files changed (+778, −337 lines)

pyproject.toml
Lines changed: 2 additions & 1 deletion

@@ -1,6 +1,6 @@
 [project]
 name = "notebookllama"
-version = "0.3.1"
+version = "0.4.0"
 description = "An OSS and LlamaCloud-backed alternative to NotebookLM"
 readme = "README.md"
 requires-python = ">=3.13"
@@ -34,6 +34,7 @@ dependencies = [
     "pytest-asyncio>=1.0.0",
     "python-dotenv>=1.1.1",
     "pyvis>=0.3.2",
+    "randomname>=0.2.1",
     "streamlit>=1.46.1",
     "textual>=3.7.1"
 ]
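
Aside from the version bump to 0.4.0, the only dependency change is randomname, which the reworked Home.py below uses to seed a default document title. A minimal sketch of that call, reusing the category arguments from the diff (the output is randomly generated, so the exact value varies):

```python
import randomname

# Same categories Home.py passes when seeding the default document title;
# returns a randomly generated name built from those word categories.
title = randomname.get_name(
    adj=("music_theory", "geometry", "emotions"), noun=("cats", "food")
)
print(title)
```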

src/notebookllama/Home.py
Lines changed: 46 additions & 13 deletions

@@ -6,9 +6,11 @@
 from dotenv import load_dotenv
 import sys
 import time
+import randomname
 import streamlit.components.v1 as components

 from pathlib import Path
+from documents import ManagedDocument, DocumentManager
 from audio import PODCAST_GEN, PodcastConfig
 from typing import Tuple
 from workflow import NotebookLMWorkflow, FileInputEvent, NotebookOutputEvent
@@ -29,11 +31,13 @@
     span_exporter=span_exporter,
     debug=True,
 )
+engine_url = f"postgresql+psycopg2://{os.getenv('pgql_user')}:{os.getenv('pgql_psw')}@localhost:5432/{os.getenv('pgql_db')}"
 sql_engine = OtelTracesSqlEngine(
-    engine_url=f"postgresql+psycopg2://{os.getenv('pgql_user')}:{os.getenv('pgql_psw')}@localhost:5432/{os.getenv('pgql_db')}",
+    engine_url=engine_url,
     table_name="agent_traces",
     service_name="agent.traces",
 )
+document_manager = DocumentManager(engine_url=engine_url)

 WF = NotebookLMWorkflow(timeout=600)

@@ -44,7 +48,9 @@ def read_html_file(file_path: str) -> str:
         return f.read()


-async def run_workflow(file: io.BytesIO) -> Tuple[str, str, str, str, str]:
+async def run_workflow(
+    file: io.BytesIO, document_title: str
+) -> Tuple[str, str, str, str, str]:
     # Create temp file with proper Windows handling
     with temp.NamedTemporaryFile(suffix=".pdf", delete=False) as fl:
         content = file.getvalue()
@@ -72,6 +78,18 @@ async def run_workflow(file: io.BytesIO) -> Tuple[str, str, str, str, str]:

         end_time = int(time.time() * 1000000)
         sql_engine.to_sql_database(start_time=st_time, end_time=end_time)
+        document_manager.put_documents(
+            [
+                ManagedDocument(
+                    document_name=document_title,
+                    content=result.md_content,
+                    summary=result.summary,
+                    q_and_a=q_and_a,
+                    mindmap=mind_map,
+                    bullet_points=bullet_points,
+                )
+            ]
+        )
         return result.md_content, result.summary, q_and_a, bullet_points, mind_map

     finally:
@@ -85,7 +103,7 @@ async def run_workflow(file: io.BytesIO) -> Tuple[str, str, str, str, str]:
                 pass # Give up if still locked


-def sync_run_workflow(file: io.BytesIO):
+def sync_run_workflow(file: io.BytesIO, document_title: str):
     try:
         # Try to use existing event loop
         loop = asyncio.get_event_loop()
@@ -94,15 +112,17 @@ def sync_run_workflow(file: io.BytesIO):
             import concurrent.futures

             with concurrent.futures.ThreadPoolExecutor() as executor:
-                future = executor.submit(asyncio.run, run_workflow(file))
+                future = executor.submit(
+                    asyncio.run, run_workflow(file, document_title)
+                )
                 return future.result()
         else:
-            return loop.run_until_complete(run_workflow(file))
+            return loop.run_until_complete(run_workflow(file, document_title))
     except RuntimeError:
         # No event loop exists, create one
         if sys.platform == "win32":
             asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy())
-        return asyncio.run(run_workflow(file))
+        return asyncio.run(run_workflow(file, document_title))


 async def create_podcast(file_content: str, config: PodcastConfig = None):
@@ -132,23 +152,36 @@ def sync_create_podcast(file_content: str, config: PodcastConfig = None):
 st.markdown("---")
 st.markdown("## NotebookLlaMa - Home🦙")

-file_input = st.file_uploader(
-    label="Upload your source PDF file!", accept_multiple_files=False
+# Initialize session state BEFORE creating the text input
+if "workflow_results" not in st.session_state:
+    st.session_state.workflow_results = None
+if "document_title" not in st.session_state:
+    st.session_state.document_title = randomname.get_name(
+        adj=("music_theory", "geometry", "emotions"), noun=("cats", "food")
+    )
+
+# Use session_state as the value and update it when changed
+document_title = st.text_input(
+    label="Document Title",
+    value=st.session_state.document_title,
+    key="document_title_input",
 )

-# Add this after your existing code, before the st.title line:
+# Update session state when the input changes
+if document_title != st.session_state.document_title:
+    st.session_state.document_title = document_title

-# Initialize session state
-if "workflow_results" not in st.session_state:
-    st.session_state.workflow_results = None
+file_input = st.file_uploader(
+    label="Upload your source PDF file!", accept_multiple_files=False
+)

 if file_input is not None:
     # First button: Process Document
     if st.button("Process Document", type="primary"):
         with st.spinner("Processing document... This may take a few minutes."):
             try:
                 md_content, summary, q_and_a, bullet_points, mind_map = (
-                    sync_run_workflow(file_input)
+                    sync_run_workflow(file_input, st.session_state.document_title)
                 )
                 st.session_state.workflow_results = {
                     "md_content": md_content,

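The Home page now threads a user-editable title from the new text input through sync_run_workflow and into DocumentManager.put_documents, so every processed PDF is persisted under a stable name. The easy-to-get-wrong part is the Streamlit session-state handling around that input; here is a standalone sketch of just that pattern, with a fixed default standing in for the randomname call (the file name and default value are made up for illustration):

```python
# title_demo.py -- hypothetical file; run with `streamlit run title_demo.py`.
# Demonstrates the pattern used above: seed a default once per session, bind
# the text box to it, and write edits back so the value survives reruns.
import streamlit as st

if "document_title" not in st.session_state:
    # Only set on the first run of the session; later reruns keep the value
    st.session_state.document_title = "untitled-document"

document_title = st.text_input(
    label="Document Title",
    value=st.session_state.document_title,
    key="document_title_input",
)

# Persist edits back into session state so other widgets/pages can read them
if document_title != st.session_state.document_title:
    st.session_state.document_title = document_title

st.write("Current title:", st.session_state.document_title)
```
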
src/notebookllama/documents.py (new file)
Lines changed: 136 additions & 0 deletions

from dataclasses import dataclass
from sqlalchemy import (
    Table,
    MetaData,
    Column,
    Text,
    Integer,
    create_engine,
    Engine,
    Connection,
    insert,
    select,
)
from typing import Optional, List, cast, Union


def apply_string_correction(string: str) -> str:
    return string.replace("''", "'").replace('""', '"')


@dataclass
class ManagedDocument:
    document_name: str
    content: str
    summary: str
    q_and_a: str
    mindmap: str
    bullet_points: str


class DocumentManager:
    def __init__(
        self,
        engine: Optional[Engine] = None,
        engine_url: Optional[str] = None,
        table_name: Optional[str] = None,
        table_metadata: Optional[MetaData] = None,
    ):
        self.table_name: str = table_name or "documents"
        self._table: Optional[Table] = None
        self._connection: Optional[Connection] = None
        self.metadata: MetaData = cast(MetaData, table_metadata or MetaData())
        if engine or engine_url:
            self._engine: Union[Engine, str] = cast(
                Union[Engine, str], engine or engine_url
            )
        else:
            raise ValueError("One of engine or engine_setup_kwargs must be set")

    @property
    def connection(self) -> Connection:
        if not self._connection:
            self._connect()
        return cast(Connection, self._connection)

    @property
    def table(self) -> Table:
        if self._table is None:
            self._create_table()
        return cast(Table, self._table)

    def _connect(self) -> None:
        # move network calls outside of constructor
        if isinstance(self._engine, str):
            self._engine = create_engine(self._engine)
        self._connection = self._engine.connect()

    def _create_table(self) -> None:
        self._table = Table(
            self.table_name,
            self.metadata,
            Column("id", Integer, primary_key=True, autoincrement=True),
            Column("document_name", Text),
            Column("content", Text),
            Column("summary", Text),
            Column("q_and_a", Text),
            Column("mindmap", Text),
            Column("bullet_points", Text),
        )
        self._table.create(self.connection, checkfirst=True)

    def put_documents(self, documents: List[ManagedDocument]) -> None:
        for document in documents:
            stmt = insert(self.table).values(
                document_name=document.document_name,
                content=document.content,
                summary=document.summary,
                q_and_a=document.q_and_a,
                mindmap=document.mindmap,
                bullet_points=document.bullet_points,
            )
            self.connection.execute(stmt)
        self.connection.commit()

    def get_documents(self, names: Optional[List[str]] = None) -> List[ManagedDocument]:
        if self.table is None:
            self._create_table()
        if not names:
            stmt = select(self.table).order_by(self.table.c.id)
        else:
            stmt = (
                select(self.table)
                .where(self.table.c.document_name.in_(names))
                .order_by(self.table.c.id)
            )
        result = self.connection.execute(stmt)
        rows = result.fetchall()
        documents = []
        for row in rows:
            documents.append(
                ManagedDocument(
                    document_name=row.document_name,
                    content=row.content,
                    summary=row.summary,
                    q_and_a=row.q_and_a,
                    mindmap=row.mindmap,
                    bullet_points=row.bullet_points,
                )
            )
        return documents

    def get_names(self) -> List[str]:
        if self.table is None:
            self._create_table()
        stmt = select(self.table)
        result = self.connection.execute(stmt)
        rows = result.fetchall()
        return [row.document_name for row in rows]

    def disconnect(self) -> None:
        if not self._connection:
            raise ValueError("Engine was never connected!")
        if isinstance(self._engine, str):
            pass
        else:
            self._engine.dispose(close=True)
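
DocumentManager is deliberately lazy: the constructor only stores the engine or URL, the connection is opened on first use, and the documents table is created with checkfirst=True the first time it is touched. A hedged usage sketch against an in-memory SQLite engine (the app itself passes the Postgres engine_url from Home.py; the import assumes you run from src/notebookllama, and all field values below are made up):

```python
from sqlalchemy import create_engine

from documents import DocumentManager, ManagedDocument

# In-memory SQLite keeps the sketch self-contained; the real app uses Postgres.
manager = DocumentManager(engine=create_engine("sqlite:///:memory:"))

manager.put_documents(
    [
        ManagedDocument(
            document_name="Example Paper",
            content="# Example\nFull markdown content...",
            summary="A short summary.",
            q_and_a="Q: What is this? A: An example.",
            mindmap="<html>...</html>",
            bullet_points="- point one\n- point two",
        )
    ]
)

print(manager.get_names())  # ['Example Paper']
print(manager.get_documents(names=["Example Paper"])[0].summary)  # 'A short summary.'

manager.disconnect()
```

Calling get_documents() with no names returns every stored row ordered by id, which is what the document management page relies on.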

src/notebookllama/mindmap.py (new file)
Lines changed: 109 additions & 0 deletions

import uuid
import os
import warnings
import json
from pydantic import BaseModel, Field, model_validator
from typing_extensions import Self
from typing import List, Union

from pyvis.network import Network
from llama_index.core.llms import ChatMessage
from llama_index.llms.openai import OpenAIResponses


class Node(BaseModel):
    id: str
    content: str


class Edge(BaseModel):
    from_id: str
    to_id: str


class MindMap(BaseModel):
    nodes: List[Node] = Field(
        description="List of nodes in the mind map, each represented as a Node object with an 'id' and concise 'content' (no more than 5 words).",
        examples=[
            [
                Node(id="A", content="Fall of the Roman Empire"),
                Node(id="B", content="476 AD"),
                Node(id="C", content="Barbarian invasions"),
            ],
            [
                Node(id="A", content="Auxin is released"),
                Node(id="B", content="Travels to the roots"),
                Node(id="C", content="Root cells grow"),
            ],
        ],
    )
    edges: List[Edge] = Field(
        description="The edges connecting the nodes of the mind map, as a list of Edge objects with from_id and to_id fields representing the source and target node IDs.",
        examples=[
            [
                Edge(from_id="A", to_id="B"),
                Edge(from_id="A", to_id="C"),
                Edge(from_id="B", to_id="C"),
            ],
            [
                Edge(from_id="C", to_id="A"),
                Edge(from_id="B", to_id="C"),
                Edge(from_id="A", to_id="B"),
            ],
        ],
    )

    @model_validator(mode="after")
    def validate_mind_map(self) -> Self:
        all_nodes = [el.id for el in self.nodes]
        all_edges = [el.from_id for el in self.edges] + [el.to_id for el in self.edges]
        if set(all_nodes).issubset(set(all_edges)) and set(all_nodes) != set(all_edges):
            raise ValueError(
                "There are non-existing nodes listed as source or target in the edges"
            )
        return self


class MindMapCreationFailedWarning(Warning):
    """A warning returned if the mind map creation failed"""


if os.getenv("OPENAI_API_KEY", None):
    LLM = OpenAIResponses(model="gpt-4.1", api_key=os.getenv("OPENAI_API_KEY"))
    LLM_STRUCT = LLM.as_structured_llm(MindMap)


async def get_mind_map(summary: str, highlights: List[str]) -> Union[str, None]:
    try:
        keypoints = "\n- ".join(highlights)
        messages = [
            ChatMessage(
                role="user",
                content=f"This is the summary for my document: {summary}\n\nAnd these are the key points:\n- {keypoints}",
            )
        ]
        response = await LLM_STRUCT.achat(messages=messages)
        response_json = json.loads(response.message.content)
        net = Network(directed=True, height="750px", width="100%")
        net.set_options("""
        var options = {
            "physics": {
                "enabled": false
            }
        }
        """)
        nodes = response_json["nodes"]
        edges = response_json["edges"]
        for node in nodes:
            net.add_node(n_id=node["id"], label=node["content"])
        for edge in edges:
            net.add_edge(source=edge["from_id"], to=edge["to_id"])
        name = str(uuid.uuid4())
        net.save_graph(name + ".html")
        return name + ".html"
    except Exception as e:
        warnings.warn(
            message=f"An error occurred during the creation of the mind map: {e}",
            category=MindMapCreationFailedWarning,
        )
        return None
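
get_mind_map only works when OPENAI_API_KEY is set (otherwise LLM_STRUCT is never created, and the resulting error is caught and surfaced as a MindMapCreationFailedWarning with a None return). On success it writes a pyvis graph to a UUID-named HTML file and returns that filename. A hedged sketch of calling it directly (the summary and highlights are made-up inputs; the import assumes you run from src/notebookllama):

```python
import asyncio

from mindmap import get_mind_map

# Returns the generated "<uuid>.html" filename, or None if generation failed.
html_file = asyncio.run(
    get_mind_map(
        summary="The document explains how auxin regulates root growth.",
        highlights=[
            "Auxin is released at the shoot tip",
            "It travels toward the roots",
            "Root cells elongate in response",
        ],
    )
)
print(html_file)
```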
