Skip to content

Commit ecd9973

Browse files
committed
chore: resolve suggestions + tests
1 parent e99ea0c commit ecd9973

File tree

4 files changed

+42
-33
lines changed

4 files changed

+42
-33
lines changed

src/notebookllama/Home.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -150,28 +150,36 @@ def sync_create_podcast(file_content: str):
150150
st.markdown("---")
151151
st.markdown("## NotebookLlaMa - Home🦙")
152152

153+
# Initialize session state BEFORE creating the text input
154+
if "workflow_results" not in st.session_state:
155+
st.session_state.workflow_results = None
156+
if "document_title" not in st.session_state:
157+
st.session_state.document_title = randomname.get_name(
158+
adj=("music_theory", "geometry", "emotions"), noun=("cats", "food")
159+
)
160+
161+
# Use session_state as the value and update it when changed
153162
document_title = st.text_input(
154163
label="Document Title",
155-
value=randomname.get_name(
156-
adj=("music_theory", "geometry", "emotions"), noun=("cats", "food")
157-
),
164+
value=st.session_state.document_title,
165+
key="document_title_input",
158166
)
167+
168+
# Update session state when the input changes
169+
if document_title != st.session_state.document_title:
170+
st.session_state.document_title = document_title
171+
159172
file_input = st.file_uploader(
160173
label="Upload your source PDF file!", accept_multiple_files=False
161174
)
162175

163-
164-
# Initialize session state
165-
if "workflow_results" not in st.session_state:
166-
st.session_state.workflow_results = None
167-
168176
if file_input is not None:
169177
# First button: Process Document
170178
if st.button("Process Document", type="primary"):
171179
with st.spinner("Processing document... This may take a few minutes."):
172180
try:
173181
md_content, summary, q_and_a, bullet_points, mind_map = (
174-
sync_run_workflow(file_input, document_title)
182+
sync_run_workflow(file_input, st.session_state.document_title)
175183
)
176184
st.session_state.workflow_results = {
177185
"md_content": md_content,

src/notebookllama/documents.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,13 @@ def __init__(
3737
table_metadata: Optional[MetaData] = None,
3838
):
3939
self.table_name: str = table_name or "documents"
40-
self.table: Optional[Table] = None
40+
self._table: Optional[Table] = None
4141
self._connection: Optional[Connection] = None
42-
self.metadata: Optional[MetaData] = table_metadata or MetaData()
42+
self.metadata: MetaData = cast(MetaData, table_metadata or MetaData())
4343
if engine or engine_url:
44-
self._engine: Union[Engine, str] = engine or engine_url
44+
self._engine: Union[Engine, str] = cast(
45+
Union[Engine, str], engine or engine_url
46+
)
4547
else:
4648
raise ValueError("One of engine or engine_setup_kwargs must be set")
4749

@@ -51,14 +53,20 @@ def connection(self) -> Connection:
5153
self._connect()
5254
return cast(Connection, self._connection)
5355

56+
@property
57+
def table(self) -> Table:
58+
if not self._table:
59+
self._create_table()
60+
return cast(Table, self._table)
61+
5462
def _connect(self) -> None:
5563
# move network calls outside of constructor
5664
if isinstance(self._engine, str):
5765
self._engine = create_engine(self._engine)
5866
self._connection = self._engine.connect()
5967

6068
def _create_table(self) -> None:
61-
self.table = Table(
69+
self._table = Table(
6270
self.table_name,
6371
self.metadata,
6472
Column("id", Integer, primary_key=True, autoincrement=True),
@@ -69,11 +77,9 @@ def _create_table(self) -> None:
6977
Column("mindmap", Text),
7078
Column("bullet_points", Text),
7179
)
72-
self.table.create(self.connection, checkfirst=True)
80+
self._table.create(self.connection, checkfirst=True)
7381

7482
def put_documents(self, documents: List[ManagedDocument]) -> None:
75-
if not self.table:
76-
self._create_table()
7783
for document in documents:
7884
stmt = insert(self.table).values(
7985
document_name=document.document_name,
@@ -87,7 +93,7 @@ def put_documents(self, documents: List[ManagedDocument]) -> None:
8793
self.connection.commit()
8894

8995
def get_documents(self, names: Optional[List[str]] = None) -> List[ManagedDocument]:
90-
if not self.table_exists:
96+
if self.table is None:
9197
self._create_table()
9298
if not names:
9399
stmt = select(self.table).order_by(self.table.c.id)
@@ -114,7 +120,7 @@ def get_documents(self, names: Optional[List[str]] = None) -> List[ManagedDocume
114120
return documents
115121

116122
def get_names(self) -> List[str]:
117-
if not self.table_exists:
123+
if self.table is None:
118124
self._create_table()
119125
stmt = select(self.table)
120126
result = self.connection.execute(stmt)
@@ -124,4 +130,7 @@ def get_names(self) -> List[str]:
124130
def disconnect(self) -> None:
125131
if not self._connection:
126132
raise ValueError("Engine was never connected!")
127-
self._engine.dispose(close=True)
133+
if isinstance(self._engine, str):
134+
pass
135+
else:
136+
self._engine.dispose(close=True)

tests/test_document_management.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from typing import List
66

77
from src.notebookllama.documents import DocumentManager, ManagedDocument
8-
from sqlalchemy import text
8+
from sqlalchemy import text, Table
99

1010
ENV = load_dotenv()
1111

@@ -63,10 +63,11 @@ def documents() -> List[ManagedDocument]:
6363
def test_document_manager(documents: List[ManagedDocument]) -> None:
6464
engine_url = f"postgresql+psycopg2://{os.getenv('pgql_user')}:{os.getenv('pgql_psw')}@localhost:5432/{os.getenv('pgql_db')}"
6565
manager = DocumentManager(engine_url=engine_url, table_name="test_documents")
66-
assert not manager.table_exists
67-
manager._execute(text("DROP TABLE IF EXISTS test_documents;"))
66+
assert not manager.table
67+
manager.connection.execute(text("DROP TABLE IF EXISTS test_documents;"))
68+
manager.connection.commit()
6869
manager._create_table()
69-
assert manager.table_exists
70+
assert isinstance(manager.table, Table)
7071
manager.put_documents(documents=documents)
7172
names = manager.get_names()
7273
assert names == [doc.document_name for doc in documents]

tests/test_models.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -197,13 +197,4 @@ def test_managed_documents() -> None:
197197
mindmap="Hello -> World",
198198
bullet_points=". Hello, . World",
199199
)
200-
assert d2.summary == "Test''s child"
201-
with pytest.raises(ValidationError):
202-
ManagedDocument(
203-
document_name=1,
204-
content="This is a test",
205-
summary="Test's child",
206-
q_and_a="Hello? World.",
207-
mindmap="Hello -> World",
208-
bullet_points=". Hello, . World",
209-
)
200+
assert d2.summary == "Test's child"

0 commit comments

Comments
 (0)