diff --git a/Cargo.lock b/Cargo.lock index 56814d3..c01be13 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -152,6 +152,17 @@ dependencies = [ "syn", ] +[[package]] +name = "async-trait" +version = "0.1.83" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "721cae7de5c34fbb2acd27e21e6d2cf7b886dce0c27388d46c4e6c47ea4318dd" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "autocfg" version = "1.4.0" @@ -650,6 +661,17 @@ version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6" +[[package]] +name = "futures-macro" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "futures-sink" version = "0.3.31" @@ -671,6 +693,7 @@ dependencies = [ "futures-channel", "futures-core", "futures-io", + "futures-macro", "futures-sink", "futures-task", "memchr", @@ -1081,16 +1104,16 @@ checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" [[package]] name = "learner" -version = "0.6.0" +version = "0.7.0" dependencies = [ "anyhow", + "async-trait", "chrono", - "crossterm", "dirs", + "futures", "lazy_static", "lopdf", "quick-xml", - "ratatui", "regex", "reqwest", "rusqlite", diff --git a/Cargo.toml b/Cargo.toml index c01eaf5..b9cb1d8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,7 +12,7 @@ repository = "https://github.com/autoparallel/learner" [workspace.dependencies] # local -learner = { path = "crates/learner", version = "=0.6.0" } +learner = { path = "crates/learner", version = "=0.7.0" } # shared dependencies serde = { version = "1.0", features = ["derive"] } @@ -28,8 +28,10 @@ tokio = { version = "1.41", features = [ tracing = { version = "0.1" } # learner dependencies +async-trait = { version = "0.1" } chrono = { 
version = "0.4", features = ["serde"] } dirs = { version = "5.0" } +futures = { version = "0.3.31" } lazy_static = { version = "1.5" } lopdf = { version = "0.34" } quick-xml = { version = "0.37", features = ["serialize"] } diff --git a/crates/learner/Cargo.toml b/crates/learner/Cargo.toml index 2f7ae3d..0abdbbb 100644 --- a/crates/learner/Cargo.toml +++ b/crates/learner/Cargo.toml @@ -8,15 +8,13 @@ license.workspace = true name = "learner" readme.workspace = true repository.workspace = true -version = "0.6.0" - -[features] -default = [] -tui = ["dep:ratatui", "dep:crossterm"] +version = "0.7.0" [dependencies] +async-trait = { workspace = true } chrono = { workspace = true } dirs = { workspace = true } +futures = { workspace = true } lazy_static = { workspace = true } lopdf = { workspace = true } quick-xml = { workspace = true } @@ -30,10 +28,6 @@ tokio = { workspace = true } tokio-rusqlite = { workspace = true } tracing = { workspace = true } -# TUI dependencies (optional) -crossterm = { workspace = true, optional = true } -ratatui = { workspace = true, optional = true } - [dev-dependencies] anyhow = { workspace = true } tempfile = { workspace = true } diff --git a/crates/learner/src/database.rs b/crates/learner/src/database.rs deleted file mode 100644 index fd16292..0000000 --- a/crates/learner/src/database.rs +++ /dev/null @@ -1,762 +0,0 @@ -//! Local SQLite database management for storing and retrieving papers. -//! -//! This module provides functionality to persist paper metadata in a local SQLite database. -//! It supports: -//! - Paper metadata storage and retrieval -//! - Author information management -//! - Full-text search across papers -//! - Source-specific identifier lookups -//! -//! The database schema is automatically initialized when opening a database, and includes -//! tables for papers, authors, and full-text search indexes. -//! -//! # Examples -//! -//! ```no_run -//! # async fn example() -> Result<(), Box> { -//! 
// Open or create a database -//! let db = learner::database::Database::open("papers.db").await?; -//! -//! // Fetch and save a paper -//! let paper = learner::paper::Paper::new("2301.07041").await?; -//! let id = db.save_paper(&paper).await?; -//! -//! // Search for papers -//! let results = db.search_papers("neural networks").await?; -//! for paper in results { -//! println!("Found: {}", paper.title); -//! } -//! # Ok(()) -//! # } -//! ``` - -use rusqlite::params; -use tokio_rusqlite::Connection; - -use super::*; - -/// Handle for interacting with the paper database. -/// -/// This struct manages an async connection to a SQLite database and provides -/// methods for storing and retrieving paper metadata. It uses SQLite's full-text -/// search capabilities for efficient paper discovery. -/// -/// The database is automatically initialized with the required schema when opened. -/// If the database file doesn't exist, it will be created. -pub struct Database { - /// Async SQLite connection handle - pub conn: Connection, -} - -impl Database { - /// Opens an existing database or creates a new one at the specified path. - /// - /// This method will: - /// 1. Create the database file if it doesn't exist - /// 2. Initialize the schema using migrations - /// 3. 
Set up full-text search indexes - /// - /// # Arguments - /// - /// * `path` - Path where the database file should be created or opened - /// - /// # Returns - /// - /// Returns a [`Result`] containing either: - /// - A [`Database`] handle for database operations - /// - A [`LearnerError`] if database creation or initialization fails - /// - /// # Examples - /// - /// ```no_run - /// # use learner::database::Database; - /// # async fn example() -> Result<(), Box> { - /// // Open in a specific location - /// let db = Database::open("papers.db").await?; - /// - /// // Or use the default location - /// let db = Database::open(Database::default_path()).await?; - /// # Ok(()) - /// # } - /// ``` - pub async fn open(path: impl AsRef) -> Result { - let conn = Connection::open(path.as_ref()).await?; - - // Initialize schema - conn - .call(|conn| { - conn.execute_batch(include_str!(concat!( - env!("CARGO_MANIFEST_DIR"), - "/migrations/init.sql" - )))?; - Ok(()) - }) - .await?; - - Ok(Self { conn }) - } - - /// Returns the default path for the database file. - /// - /// The path is constructed as follows: - /// - On Unix: `~/.local/share/learner/learner.db` - /// - On macOS: `~/Library/Application Support/learner/learner.db` - /// - On Windows: `%APPDATA%\learner\learner.db` - /// - Fallback: `./learner.db` in the current directory - /// - /// # Examples - /// - /// ```no_run - /// let path = learner::database::Database::default_path(); - /// println!("Database will be stored at: {}", path.display()); - /// ``` - pub fn default_path() -> PathBuf { - dirs::data_dir().unwrap_or_else(|| PathBuf::from(".")).join("learner").join("learner.db") - } - - /// Saves a paper and its authors to the database. - /// - /// This method will: - /// 1. Insert the paper's metadata into the papers table - /// 2. Insert all authors into the authors table - /// 3. Update the full-text search index - /// - /// The operation is performed in a transaction to ensure data consistency. 
- /// - /// # Arguments - /// - /// * `paper` - The paper to save - /// - /// # Returns - /// - /// Returns a [`Result`] containing either: - /// - The database ID of the saved paper - /// - A [`LearnerError`] if the save operation fails - /// - /// # Examples - /// - /// ```no_run - /// # use learner::{database::Database, paper::Paper}; - /// # async fn example() -> Result<(), Box> { - /// let db = Database::open("papers.db").await?; - /// let paper = Paper::new("2301.07041").await?; - /// let id = db.save_paper(&paper).await?; - /// println!("Saved paper with ID: {}", id); - /// # Ok(()) - /// # } - /// ``` - pub async fn save_paper(&self, paper: &Paper) -> Result { - let paper = paper.clone(); - self - .conn - .call(move |conn| { - let tx = conn.transaction()?; - - // Insert paper - let paper_id = { - let mut stmt = tx.prepare_cached( - "INSERT INTO papers ( - title, abstract_text, publication_date, - source, source_identifier, pdf_url, doi - ) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7) - RETURNING id", - )?; - - stmt.query_row( - params![ - &paper.title, - &paper.abstract_text, - &paper.publication_date, - paper.source.to_string(), - &paper.source_identifier, - &paper.pdf_url, - &paper.doi, - ], - |row| row.get::<_, i64>(0), - )? - }; - - // Insert authors - { - let mut stmt = tx.prepare_cached( - "INSERT INTO authors (paper_id, name, affiliation, email) - VALUES (?1, ?2, ?3, ?4)", - )?; - - for author in &paper.authors { - stmt.execute(params![paper_id, &author.name, &author.affiliation, &author.email,])?; - } - } - - tx.commit()?; - Ok(paper_id) - }) - .await - .map_err(LearnerError::from) - } - - /// Retrieves a paper using its source and identifier. - /// - /// This method looks up a paper based on its origin (e.g., arXiv, DOI) - /// and its source-specific identifier. It also fetches all associated - /// author information. 
- /// - /// # Arguments - /// - /// * `source` - The paper's source system (arXiv, IACR, DOI) - /// * `source_id` - The source-specific identifier - /// - /// # Returns - /// - /// Returns a [`Result`] containing either: - /// - `Some(Paper)` if found - /// - `None` if no matching paper exists - /// - A [`LearnerError`] if the query fails - /// - /// # Examples - /// - /// ```no_run - /// # use learner::{database::Database, paper::Source}; - /// # async fn example() -> Result<(), Box> { - /// let db = Database::open("papers.db").await?; - /// if let Some(paper) = db.get_paper_by_source_id(&Source::Arxiv, "2301.07041").await? { - /// println!("Found paper: {}", paper.title); - /// } - /// # Ok(()) - /// # } - /// ``` - pub async fn get_paper_by_source_id( - &self, - source: &Source, - source_id: &str, - ) -> Result> { - // Clone the values before moving into the async closure - let source = source.to_string(); - let source_id = source_id.to_string(); - - self - .conn - .call(move |conn| { - let mut paper_stmt = conn.prepare_cached( - "SELECT id, title, abstract_text, publication_date, source, - source_identifier, pdf_url, doi - FROM papers - WHERE source = ?1 AND source_identifier = ?2", - )?; - - let mut author_stmt = conn.prepare_cached( - "SELECT name, affiliation, email - FROM authors - WHERE paper_id = ?", - )?; - - let paper_result = paper_stmt.query_row(params![source, source_id], |row| { - Ok(Paper { - title: row.get(1)?, - abstract_text: row.get(2)?, - publication_date: row.get(3)?, - source: Source::from_str(&row.get::<_, String>(4)?).map_err(|e| { - rusqlite::Error::FromSqlConversionFailure(4, rusqlite::types::Type::Text, Box::new(e)) - })?, - source_identifier: row.get(5)?, - pdf_url: row.get(6)?, - doi: row.get(7)?, - authors: Vec::new(), // Filled in below - }) - }); - - match paper_result { - Ok(mut paper) => { - let paper_id: i64 = - paper_stmt.query_row(params![source, source_id], |row| row.get(0))?; - - let authors = 
author_stmt.query_map([paper_id], |row| { - Ok(Author { - name: row.get(0)?, - affiliation: row.get(1)?, - email: row.get(2)?, - }) - })?; - - paper.authors = authors.collect::, _>>()?; - Ok(Some(paper)) - }, - Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None), - Err(e) => Err(e.into()), - } - }) - .await - .map_err(LearnerError::from) - } - - /// Searches for papers using full-text search. - /// - /// This method uses SQLite's FTS5 module to perform full-text search across: - /// - Paper titles - /// - Paper abstracts - /// - /// Results are ordered by relevance using FTS5's built-in ranking algorithm. - /// - /// # Arguments - /// - /// * `query` - The search query using FTS5 syntax - /// - /// # Returns - /// - /// Returns a [`Result`] containing either: - /// - A vector of matching papers - /// - A [`LearnerError`] if the search fails - /// - /// # Examples - /// - /// ```no_run - /// # async fn example() -> Result<(), Box> { - /// let db = learner::database::Database::open("papers.db").await?; - /// - /// // Simple word search - /// let papers = db.search_papers("quantum").await?; - /// - /// // Phrase search - /// let papers = db.search_papers("\"neural networks\"").await?; - /// - /// // Complex query - /// let papers = db.search_papers("machine learning NOT regression").await?; - /// # Ok(()) - /// # } - /// ``` - pub async fn search_papers(&self, query: &str) -> Result> { - let query = query.to_lowercase(); // Make search case-insensitive - - self - .conn - .call(move |conn| { - // First get all paper IDs matching the search - let mut id_stmt = conn.prepare_cached( - "SELECT p.id - FROM papers p - JOIN papers_fts f ON p.id = f.rowid - WHERE papers_fts MATCH ?1 - ORDER BY rank", - )?; - - // Collect matching IDs first - let paper_ids: Vec = id_stmt - .query_map([&query], |row| row.get(0))? 
- .collect::, _>>()?; - - let mut papers = Vec::new(); - - // Now fetch complete paper data for each ID - for paper_id in paper_ids { - // Get paper details - let mut paper_stmt = conn.prepare_cached( - "SELECT title, abstract_text, publication_date, - source, source_identifier, pdf_url, doi - FROM papers - WHERE id = ?", - )?; - - let paper = paper_stmt.query_row([paper_id], |row| { - Ok(Paper { - title: row.get(0)?, - abstract_text: row.get(1)?, - publication_date: row.get(2)?, - source: Source::from_str(&row.get::<_, String>(3)?).map_err(|e| { - rusqlite::Error::FromSqlConversionFailure( - 3, - rusqlite::types::Type::Text, - Box::new(e), - ) - })?, - source_identifier: row.get(4)?, - pdf_url: row.get(5)?, - doi: row.get(6)?, - authors: Vec::new(), - }) - })?; - - // Get authors for this paper - let mut author_stmt = conn.prepare_cached( - "SELECT name, affiliation, email - FROM authors - WHERE paper_id = ?", - )?; - - let authors = author_stmt - .query_map([paper_id], |row| { - Ok(Author { - name: row.get(0)?, - affiliation: row.get(1)?, - email: row.get(2)?, - }) - })? - .collect::, _>>()?; - - // Create the complete paper with authors - let mut paper = paper; - paper.authors = authors; - papers.push(paper); - } - - Ok(papers) - }) - .await - .map_err(LearnerError::from) - } - - /// Returns the default path for PDF storage. - /// - /// The path is constructed as follows: - /// - On Unix: `~/Documents/learner/papers` - /// - On macOS: `~/Documents/learner/papers` - /// - On Windows: `Documents\learner\papers` - /// - Fallback: `./papers` in the current directory - /// - /// # Examples - /// - /// ```no_run - /// let path = learner::database::Database::default_pdf_path(); - /// println!("PDFs will be stored at: {}", path.display()); - /// ``` - pub fn default_pdf_path() -> PathBuf { - dirs::document_dir().unwrap_or_else(|| PathBuf::from(".")).join("learner").join("papers") - } - - /// Sets a configuration value in the database. 
- /// - /// # Arguments - /// - /// * `key` - The configuration key - /// * `value` - The value to store - /// - /// # Returns - /// - /// Returns a [`Result`] indicating success or failure - pub async fn set_config(&self, key: &str, value: &str) -> Result<()> { - let key = key.to_string(); - let value = value.to_string(); - self - .conn - .call(move |conn| { - Ok( - conn - .execute("INSERT OR REPLACE INTO config (key, value) VALUES (?1, ?2)", params![ - key, value - ]) - .map(|_| ()), - ) - }) - .await? - .map_err(LearnerError::from) - } - - /// Gets a configuration value from the database. - /// - /// # Arguments - /// - /// * `key` - The configuration key to retrieve - /// - /// # Returns - /// - /// Returns a [`Result`] containing either: - /// - Some(String) with the configuration value - /// - None if the key doesn't exist - pub async fn get_config(&self, key: &str) -> Result> { - let key = key.to_string(); - self - .conn - .call(move |conn| { - let mut stmt = conn.prepare_cached("SELECT value FROM config WHERE key = ?1")?; - - let result = stmt.query_row([key], |row| row.get::<_, String>(0)); - - match result { - Ok(value) => Ok(Some(value)), - Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None), - Err(e) => Err(e.into()), - } - }) - .await - .map_err(LearnerError::from) - } - - /// Records a PDF file location and status for a paper. 
- /// - /// # Arguments - /// - /// * `paper_id` - The database ID of the paper - /// * `path` - Full path to the file - /// * `filename` - The filename - /// * `status` - Download status ('success', 'failed', 'pending') - /// * `error` - Optional error message if download failed - /// - /// # Returns - /// - /// Returns a [`Result`] containing the file ID on success - pub async fn record_pdf( - &self, - paper_id: i64, - path: PathBuf, - filename: String, - status: &str, - error: Option, - ) -> Result { - let path_str = path.to_string_lossy().to_string(); - let status = status.to_string(); - - self - .conn - .call(move |conn| { - let tx = conn.transaction()?; - - let id = tx.query_row( - "INSERT OR REPLACE INTO files ( - paper_id, path, filename, download_status, error_message - ) VALUES (?1, ?2, ?3, ?4, ?5) - RETURNING id", - params![paper_id, path_str, filename, status, error], - |row| row.get(0), - )?; - - tx.commit()?; - Ok(id) - }) - .await - .map_err(LearnerError::from) - } - - /// Gets the PDF status for a paper. 
- /// - /// # Arguments - /// - /// * `paper_id` - The database ID of the paper - /// - /// # Returns - /// - /// Returns a [`Result`] containing either: - /// - Some((PathBuf, String, String, Option)) with the path, filename, status, and error - /// - None if no PDF entry exists - pub async fn get_pdf_status( - &self, - paper_id: i64, - ) -> Result)>> { - self - .conn - .call(move |conn| { - let mut stmt = conn.prepare_cached( - "SELECT path, filename, download_status, error_message FROM files - WHERE paper_id = ?1", - )?; - - let result = stmt.query_row([paper_id], |row| { - Ok(( - PathBuf::from(row.get::<_, String>(0)?), - row.get::<_, String>(1)?, - row.get::<_, String>(2)?, - row.get::<_, Option>(3)?, - )) - }); - - match result { - Ok(info) => Ok(Some(info)), - Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None), - Err(e) => Err(e.into()), - } - }) - .await - .map_err(LearnerError::from) - } - - // TODO (autoparallel): Things like this should have some kind of enum to select, not `&str` - /// Lists all papers with optional ordering. - /// - /// Retrieves all papers from the database with configurable sorting. - /// The papers are returned with their complete metadata including authors. - /// - /// # Arguments - /// - /// * `order_by` - Field to order by ("title", "date", "source") - /// * `desc` - Whether to sort in descending order - /// - /// # Returns - /// - /// Returns a Result containing a vector of papers ordered as specified. 
- /// - /// # Examples - /// - /// ```no_run - /// # use learner::database::Database; - /// # async fn example() -> Result<(), Box> { - /// let db = Database::open("papers.db").await?; - /// - /// // Get papers ordered by title - /// let papers = db.list_papers("title", false).await?; - /// - /// // Get papers ordered by date, newest first - /// let papers = db.list_papers("date", true).await?; - /// # Ok(()) - /// # } - /// ``` - pub async fn list_papers(&self, order_by: &str, desc: bool) -> Result> { - // Validate order_by field - let order_clause = match order_by.to_lowercase().as_str() { - "title" => "p.title", - "date" => "p.publication_date", - "source" => "p.source, p.source_identifier", - _ => return Err(LearnerError::Database("Invalid order_by field".into())), - }; - - let direction = if desc { "DESC" } else { "ASC" }; - let query = format!( - "SELECT p.id, p.title, p.abstract_text, p.publication_date, - p.source, p.source_identifier, p.pdf_url, p.doi - FROM papers p - ORDER BY {} {}", - order_clause, direction - ); - - self - .conn - .call(move |conn| { - let mut papers = Vec::new(); - let mut paper_stmt = conn.prepare(&query)?; - let mut author_stmt = conn.prepare_cached( - "SELECT name, affiliation, email - FROM authors - WHERE paper_id = ?", - )?; - - let paper_rows = paper_stmt.query_map([], |row| { - Ok(( - row.get::<_, i64>(0)?, // Get paper_id - Paper { - title: row.get(1)?, - abstract_text: row.get(2)?, - publication_date: row.get(3)?, - source: Source::from_str(&row.get::<_, String>(4)?).map_err(|e| { - rusqlite::Error::FromSqlConversionFailure( - 4, - rusqlite::types::Type::Text, - Box::new(e), - ) - })?, - source_identifier: row.get(5)?, - pdf_url: row.get(6)?, - doi: row.get(7)?, - authors: Vec::new(), - }, - )) - })?; - - for paper_result in paper_rows { - let (paper_id, mut paper) = paper_result?; - - // Get authors for this paper - let authors = author_stmt - .query_map([paper_id], |row| { - Ok(Author { - name: row.get(0)?, - affiliation: 
row.get(1)?, - email: row.get(2)?, - }) - })? - .collect::, _>>()?; - - paper.authors = authors; - papers.push(paper); - } - - Ok(papers) - }) - .await - .map_err(LearnerError::from) - } -} - -#[cfg(test)] -mod tests { - - use super::*; - - /// Helper function to set up a test database - async fn setup_test_db() -> (Database, PathBuf, tempfile::TempDir) { - let dir = tempdir().unwrap(); - let path = dir.path().join("test.db"); - let db = Database::open(&path).await.unwrap(); - (db, path, dir) - } - - #[traced_test] - #[tokio::test] - async fn test_database_creation() { - let (_db, path, _dir) = setup_test_db().await; - - // Check that file exists - assert!(path.exists()); - } - - #[traced_test] - #[tokio::test] - async fn test_get_nonexistent_paper() { - let (db, _path, _dir) = setup_test_db().await; - - let result = db.get_paper_by_source_id(&Source::Arxiv, "nonexistent").await.unwrap(); - - assert!(result.is_none()); - } - - #[traced_test] - #[tokio::test] - async fn test_default_pdf_path() { - let path = Database::default_pdf_path(); - - // Should end with learner/papers - assert!(path.ends_with("learner/papers") || path.ends_with("learner\\papers")); - - // Should be rooted in a valid directory - assert!(path - .parent() - .unwrap() - .starts_with(dirs::document_dir().unwrap_or_else(|| PathBuf::from(".")))); - } - - #[traced_test] - #[tokio::test] - async fn test_config_operations() { - let (db, _path, _dir) = setup_test_db().await; - - // Test setting and getting a config value - db.set_config("test_key", "test_value").await.unwrap(); - let value = db.get_config("test_key").await.unwrap(); - assert_eq!(value, Some("test_value".to_string())); - - // Test getting non-existent config - let missing = db.get_config("nonexistent").await.unwrap(); - assert_eq!(missing, None); - - // Test updating existing config - db.set_config("test_key", "new_value").await.unwrap(); - let updated = db.get_config("test_key").await.unwrap(); - assert_eq!(updated, 
Some("new_value".to_string())); - } - - #[traced_test] - #[tokio::test] - async fn test_config_persistence() { - let dir = tempdir().unwrap(); - let db_path = dir.path().join("test.db"); - - // Create database and set config - { - let db = Database::open(&db_path).await.unwrap(); - db.set_config("pdf_dir", "/test/path").await.unwrap(); - } - - // Reopen database and verify config persists - { - let db = Database::open(&db_path).await.unwrap(); - let value = db.get_config("pdf_dir").await.unwrap(); - assert_eq!(value, Some("/test/path".to_string())); - } - } -} diff --git a/crates/learner/src/database/instruction/add.rs b/crates/learner/src/database/instruction/add.rs new file mode 100644 index 0000000..c818df5 --- /dev/null +++ b/crates/learner/src/database/instruction/add.rs @@ -0,0 +1,414 @@ +//! Database instruction implementation for adding papers and documents. +//! +//! This module provides functionality for adding papers and their associated documents +//! to the database. It supports several addition patterns: +//! +//! - Adding paper metadata only +//! - Adding complete papers with documents +//! - Batch addition of documents for existing papers +//! +//! The implementation emphasizes: +//! - Atomic transactions for data consistency +//! - Efficient batch processing +//! - Concurrent document downloads +//! - Duplicate handling +//! +//! # Examples +//! +//! ```no_run +//! use learner::{ +//! database::{Add, Database, Query}, +//! paper::Paper, +//! prelude::*, +//! }; +//! +//! # async fn example() -> Result<(), Box> { +//! let mut db = Database::open(Database::default_path()).await?; +//! +//! // Add just paper metadata +//! let paper = Paper::new("2301.07041").await?; +//! Add::paper(&paper).execute(&mut db).await?; +//! +//! // Add paper with document +//! Add::complete(&paper).execute(&mut db).await?; +//! +//! // Add documents for papers matching a query +//! let query = Query::by_author("Alice Researcher"); +//! 
Add::documents(query).execute(&mut db).await?; +//! # Ok(()) +//! # } +//! ``` + +use std::collections::HashSet; + +use futures::future::try_join_all; + +use super::*; + +// TODO (autoparallel): Would be good to have `Papers` and `Documents` and `Completes` instead, +// possibly, and just have a simple API for single paper calls that just dumps into the 3 variants. +/// Represents different types of additions to the database. +/// +/// This enum defines the supported addition operations, each handling a different +/// aspect of paper and document management: +/// +/// - Metadata-only additions +/// - Complete paper additions (metadata + document) +/// - Batch document additions for existing papers +#[derive(Debug)] +pub enum Addition<'a> { + /// Add just the paper metadata without associated documents + Paper(&'a Paper), + /// Add both paper metadata and download its associated document + Complete(&'a Paper), + /// Add documents for papers matching a specified query + Documents(Query<'a>), +} + +/// Database instruction for adding papers and documents. +/// +/// This struct implements the [`DatabaseInstruction`] trait to provide +/// paper and document addition functionality. It handles: +/// +/// - Paper metadata insertion +/// - Author information management +/// - Document downloading and storage +/// - Batch processing for multiple papers +/// +/// Operations are performed atomically using database transactions to +/// ensure consistency. +pub struct Add<'a> { + /// The type of addition operation to perform + addition: Addition<'a>, +} + +impl<'a> Add<'a> { + /// Creates an instruction to add paper metadata only. + /// + /// This method creates an addition that will store the paper's metadata + /// in the database without downloading or storing its associated document. 
+ /// + /// # Arguments + /// + /// * `paper` - Reference to the paper to add + /// + /// # Examples + /// + /// ```no_run + /// # use learner::database::Add; + /// # use learner::paper::Paper; + /// # async fn example() -> Result<(), Box> { + /// let paper = Paper::new("2301.07041").await?; + /// let instruction = Add::paper(&paper); + /// # Ok(()) + /// # } + /// ``` + pub fn paper(paper: &'a Paper) -> Self { Self { addition: Addition::Paper(paper) } } + + /// Creates an instruction to add a complete paper with its document. + /// + /// This method creates an addition that will: + /// 1. Store the paper's metadata + /// 2. Download the paper's document + /// 3. Store the document in the configured storage location + /// + /// # Arguments + /// + /// * `paper` - Reference to the paper to add + /// + /// # Examples + /// + /// ```no_run + /// # use learner::database::Add; + /// # use learner::paper::Paper; + /// # async fn example() -> Result<(), Box> { + /// let paper = Paper::new("2301.07041").await?; + /// let instruction = Add::complete(&paper); + /// # Ok(()) + /// # } + /// ``` + pub fn complete(paper: &'a Paper) -> Self { Self { addition: Addition::Complete(paper) } } + + /// Creates an instruction to add documents for papers matching a query. + /// + /// This method supports batch document addition by: + /// 1. Finding papers matching the query + /// 2. Filtering out papers that already have documents + /// 3. Concurrently downloading missing documents + /// 4. 
Storing documents in the configured location + /// + /// # Arguments + /// + /// * `query` - Query to identify papers needing documents + /// + /// # Examples + /// + /// ```no_run + /// # use learner::database::{Add, Query}; + /// # async fn example() -> Result<(), Box> { + /// // Add documents for all papers by an author + /// let query = Query::by_author("Alice Researcher"); + /// let instruction = Add::documents(query); + /// + /// // Or add documents for papers matching a search + /// let query = Query::text("quantum computing"); + /// let instruction = Add::documents(query); + /// # Ok(()) + /// # } + /// ``` + pub fn documents(query: Query<'a>) -> Self { Self { addition: Addition::Documents(query) } } + + /// Converts a paper-only addition to a complete addition. + /// + /// This method allows for fluent conversion of a paper metadata addition + /// to include document download and storage. + /// + /// # Examples + /// + /// ```no_run + /// # use learner::database::Add; + /// # use learner::paper::Paper; + /// # async fn example() -> Result<(), Box> { + /// let paper = Paper::new("2301.07041").await?; + /// let instruction = Add::paper(&paper).with_document(); + /// # Ok(()) + /// # } + /// ``` + pub fn with_document(self) -> Self { + match self.addition { + Addition::Paper(paper) => Self { addition: Addition::Complete(paper) }, + _ => self, + } + } + + /// Builds the SQL for inserting paper metadata. + fn build_paper_sql(paper: &Paper) -> (String, Vec>) { + ( + "INSERT INTO papers ( + title, abstract_text, publication_date, + source, source_identifier, pdf_url, doi + ) VALUES (?, ?, ?, ?, ?, ?, ?)" + .to_string(), + vec![ + Some(paper.title.clone()), + Some(paper.abstract_text.clone()), + Some(paper.publication_date.to_rfc3339()), + Some(paper.source.to_string()), + Some(paper.source_identifier.clone()), + paper.pdf_url.clone(), + paper.doi.clone(), + ], + ) + } + + /// Builds the SQL for inserting author information. 
+ fn build_author_sql(author: &Author, paper: &Paper) -> (String, Vec>) { + ( + "INSERT INTO authors (paper_id, name, affiliation, email) + SELECT id, ?, ?, ? + FROM papers + WHERE source = ? AND source_identifier = ?" + .to_string(), + vec![ + Some(author.name.clone()), + author.affiliation.clone(), + author.email.clone(), + Some(paper.source.to_string()), + Some(paper.source_identifier.clone()), + ], + ) + } + + /// Builds the SQL for recording document storage information. + fn build_document_sql( + paper: &Paper, + storage_path: &Path, + filename: &Path, + ) -> (String, Vec>) { + ( + "INSERT INTO files (paper_id, path, filename, download_status) + SELECT p.id, ?, ?, 'Success' + FROM papers p + WHERE p.source = ? AND p.source_identifier = ?" + .to_string(), + vec![ + Some(storage_path.to_string_lossy().to_string()), + Some(filename.to_string_lossy().to_string()), + Some(paper.source.to_string()), + Some(paper.source_identifier.clone()), + ], + ) + } + + /// Builds the SQL for checking existing document records. + fn build_existing_docs_sql(papers: &[&Paper]) -> (String, Vec>) { + let mut params = Vec::new(); + let mut param_placeholders = Vec::new(); + + for paper in papers { + params.push(Some(paper.source.to_string())); + params.push(Some(paper.source_identifier.clone())); + param_placeholders.push("(? = p.source AND ? = p.source_identifier)"); + } + + ( + format!( + "SELECT p.source, p.source_identifier + FROM files f + JOIN papers p ON p.id = f.paper_id + WHERE f.download_status = 'Success' + AND ({})", + param_placeholders.join(" OR ") + ), + params, + ) + } +} + +#[async_trait] +impl DatabaseInstruction for Add<'_> { + type Output = Vec; + + async fn execute(&self, db: &mut Database) -> Result { + match &self.addition { + Addition::Paper(paper) => { + // Check for existing paper + if Query::by_source(paper.source, &paper.source_identifier) + .execute(db) + .await? 
+ .into_iter() + .next() + .is_some() + { + return Err(LearnerError::DatabaseDuplicatePaper(paper.title.clone())); + } + + let (paper_sql, paper_params) = Self::build_paper_sql(paper); + let author_statements: Vec<_> = + paper.authors.iter().map(|author| Self::build_author_sql(author, paper)).collect(); + + db.conn + .call(move |conn| { + let tx = conn.transaction()?; + tx.execute(&paper_sql, params_from_iter(paper_params))?; + + for (author_sql, author_params) in author_statements { + tx.execute(&author_sql, params_from_iter(author_params))?; + } + + tx.commit()?; + Ok(()) + }) + .await?; + + Ok(vec![(*paper).clone()]) + }, + + Addition::Complete(paper) => { + // Add paper first + if let Err(LearnerError::DatabaseDuplicatePaper(_)) = Add::paper(paper).execute(db).await { + warn!( + "Tried to add complete paper when paper existed in database already, attempting to \ + add only the document!" + ) + }; + + // Add document + let storage_path = db.get_storage_path().await?; + let filename = paper.download_pdf(&storage_path).await?; + + let (doc_sql, doc_params) = Self::build_document_sql(paper, &storage_path, &filename); + + db.conn + .call(move |conn| { + let tx = conn.transaction()?; + tx.execute(&doc_sql, params_from_iter(doc_params))?; + tx.commit()?; + Ok(()) + }) + .await?; + + Ok(vec![(*paper).clone()]) + }, + + Addition::Documents(query) => { + let papers = query.execute(db).await?; + if papers.is_empty() { + return Ok(Vec::new()); + } + + let storage_path = db.get_storage_path().await?; + let mut added = Vec::new(); + + // Process papers in batches + for chunk in papers.chunks(10) { + // Check which papers already have documents + let paper_refs: Vec<_> = chunk.iter().collect(); + let (check_sql, check_params) = Self::build_existing_docs_sql(&paper_refs); + + let existing_docs: HashSet<(String, String)> = db + .conn + .call(move |conn| { + let mut docs = HashSet::new(); + let mut stmt = conn.prepare_cached(&check_sql)?; + let mut rows = 
stmt.query(params_from_iter(check_params))?; + + while let Some(row) = rows.next()? { + docs.insert((row.get::<_, String>(0)?, row.get::<_, String>(1)?)); + } + Ok(docs) + }) + .await?; + + // Create future for each paper that needs downloading + let download_futures: Vec<_> = chunk + .iter() + .filter(|paper| { + let key = (paper.source.to_string(), paper.source_identifier.clone()); + !existing_docs.contains(&key) + }) + .map(|paper| { + let paper = paper.clone(); + let storage_path = storage_path.clone(); + async move { paper.download_pdf(&storage_path).await.map(|f| (paper, f)) } + }) + .collect(); + + if download_futures.is_empty() { + continue; + } + + // Download PDFs concurrently and collect results + let results = try_join_all(download_futures).await?; + + // Prepare batch insert for successful downloads + let mut insert_sqls = Vec::new(); + let mut insert_params = Vec::new(); + + for (paper, filename) in results { + let (sql, params) = Self::build_document_sql(&paper, &storage_path, &filename); + insert_sqls.push(sql); + insert_params.extend(params); + added.push(paper); + } + + if !insert_sqls.is_empty() { + // Execute batch insert + db.conn + .call(move |conn| { + let tx = conn.transaction()?; + for (sql, params) in insert_sqls.iter().zip(insert_params.chunks(4)) { + tx.execute(sql, params_from_iter(params))?; + } + tx.commit()?; + Ok(()) + }) + .await?; + } + } + + Ok(added) + }, + } + } +} diff --git a/crates/learner/src/database/instruction/mod.rs b/crates/learner/src/database/instruction/mod.rs new file mode 100644 index 0000000..c0bc956 --- /dev/null +++ b/crates/learner/src/database/instruction/mod.rs @@ -0,0 +1,168 @@ +//! Database instruction implementations for structured database operations. +//! +//! This module provides a trait-based abstraction for database operations using the +//! Command pattern. This design allows for: +//! +//! - Type-safe database operations +//! - Composable and reusable commands +//! 
- Clear separation of operation logic +//! - Consistent error handling +//! +//! # Architecture +//! +//! The module is organized around three main operation types: +//! +//! - [`query`] - Read operations for searching and retrieving papers +//! - [`add`] - Write operations for adding papers and documents +//! - [`remove`] - Delete operations for removing papers from the database +//! +//! Each operation type implements the [`DatabaseInstruction`] trait, providing +//! a consistent interface while allowing for operation-specific behavior. +//! +//! # Usage +//! +//! Operations are constructed as instructions and then executed against a database: +//! +//! ```no_run +//! use learner::{ +//! database::{Add, Database, Query, Remove}, +//! paper::Paper, +//! prelude::*, +//! }; +//! +//! # async fn example() -> Result<(), Box> { +//! let mut db = Database::open("papers.db").await?; +//! +//! // Query papers +//! let papers = Query::text("quantum computing").execute(&mut db).await?; +//! +//! // Add a new paper +//! let paper = Paper::new("2301.07041").await?; +//! Add::paper(&paper).execute(&mut db).await?; +//! +//! // Remove papers by author +//! Remove::by_author("Alice Researcher").execute(&mut db).await?; +//! # Ok(()) +//! # } +//! ``` + +use super::*; + +pub mod add; +pub mod query; +pub mod remove; + +use async_trait::async_trait; +use rusqlite::{params_from_iter, ToSql}; + +use self::query::Query; + +/// Trait for implementing type-safe database operations. +/// +/// This trait defines the core interface for the Command pattern used in database +/// operations. Each implementation represents a specific operation (like querying, +/// adding, or removing papers) and encapsulates its own: +/// +/// - SQL generation and execution +/// - Parameter handling +/// - Result type specification +/// - Error handling +/// +/// The trait is async to support non-blocking database operations while maintaining +/// proper connection management. 
+/// +/// # Type Parameters +/// +/// * `Output` - The type returned by executing this instruction. Common types include: +/// - `Vec` for query operations +/// - `()` for operations that don't return data +/// - Custom types for specialized operations +/// +/// # Implementation Notes +/// +/// When implementing this trait: +/// - Keep SQL generation and execution within the implementation +/// - Use proper parameter binding for SQL injection prevention +/// - Handle errors appropriately and convert to [`LearnerError`] +/// - Consider optimizing repeated operations with prepared statements +/// +/// # Examples +/// +/// Querying papers with different criteria: +/// +/// ```no_run +/// # use learner::database::{Database, DatabaseInstruction, Query}; +/// # async fn example() -> Result<(), Box> { +/// let mut db = Database::open("papers.db").await?; +/// +/// // Full-text search +/// let papers = Query::text("neural networks").execute(&mut db).await?; +/// +/// // Search by author +/// let papers = Query::by_author("Alice Researcher").execute(&mut db).await?; +/// +/// // Search by publication date +/// use chrono::{DateTime, Utc}; +/// let papers = +/// Query::before_date(DateTime::parse_from_rfc3339("2024-01-01T00:00:00Z")?.with_timezone(&Utc)) +/// .execute(&mut db) +/// .await?; +/// # Ok(()) +/// # } +/// ``` +/// +/// Implementing a custom instruction: +/// +/// ```no_run +/// # use learner::{database::{Database, DatabaseInstruction}, error::LearnerError}; +/// # use async_trait::async_trait; +/// struct CountPapers; +/// +/// #[async_trait] +/// impl DatabaseInstruction for CountPapers { +/// type Output = i64; +/// +/// async fn execute(&self, db: &mut Database) -> std::result::Result { +/// Ok( +/// db.conn +/// .call(|conn| { +/// conn.query_row("SELECT COUNT(*) FROM papers", [], |row| row.get(0)).map_err(Into::into) +/// }) +/// .await?, +/// ) +/// } +/// } +/// # type Result = std::result::Result>; +/// ``` +#[async_trait] +pub trait DatabaseInstruction 
{ + /// The type returned by executing this instruction. + type Output; + + // TODO (autoparallel): It may honestly be worth having two traits -- one that takes &mut db and + // another that takes &db so you don't need to have shared mutability access + /// Executes the instruction against a database connection. + /// + /// This method performs the actual database operation, managing: + /// - SQL execution + /// - Parameter binding + /// - Result processing + /// - Error handling + /// + /// # Arguments + /// + /// * `db` - Mutable reference to the database connection + /// + /// # Returns + /// + /// Returns a `Result` containing either: + /// - The operation's output of type `Self::Output` + /// - A [`LearnerError`] if the operation fails + /// + /// # Notes + /// + /// The mutable database reference is required for operations that modify + /// the database. A future enhancement might split this into separate traits + /// for read-only and read-write operations. + async fn execute(&self, db: &mut Database) -> Result; +} diff --git a/crates/learner/src/database/instruction/query.rs b/crates/learner/src/database/instruction/query.rs new file mode 100644 index 0000000..98795ad --- /dev/null +++ b/crates/learner/src/database/instruction/query.rs @@ -0,0 +1,428 @@ +//! Query instruction implementation for retrieving papers from the database. +//! +//! This module provides a flexible query system for searching and retrieving papers +//! using various criteria. It supports: +//! +//! - Full-text search across titles and abstracts +//! - Source-specific identifier lookups +//! - Author name searches +//! - Publication date filtering +//! - Custom result ordering +//! +//! The implementation prioritizes: +//! - Efficient query execution using prepared statements +//! - SQLite full-text search integration +//! - Type-safe query construction +//! - Flexible result ordering +//! +//! # Examples +//! +//! ```no_run +//! use learner::{ +//! 
database::{Database, OrderField, Query}, +//! paper::Source, +//! prelude::*, +//! }; +//! +//! # async fn example() -> Result<(), Box> { +//! let mut db = Database::open("papers.db").await?; +//! +//! // Full-text search +//! let papers = Query::text("quantum computing") +//! .order_by(OrderField::PublicationDate) +//! .descending() +//! .execute(&mut db) +//! .await?; +//! +//! // Search by author +//! let papers = Query::by_author("Alice Researcher").execute(&mut db).await?; +//! +//! // Lookup by source identifier +//! let papers = Query::by_source(Source::Arxiv, "2301.07041").execute(&mut db).await?; +//! # Ok(()) +//! # } +//! ``` + +use super::*; + +/// Represents different ways to query papers in the database. +/// +/// This enum defines the supported search criteria for paper queries, +/// each providing different ways to locate papers in the database: +/// +/// - Text-based searching using SQLite FTS +/// - Direct lookups by source identifiers +/// - Author-based searches +/// - Publication date filtering +/// - Complete collection retrieval +#[derive(Debug)] +pub enum QueryCriteria<'a> { + /// Full-text search across titles and abstracts using SQLite FTS + Text(&'a str), + /// Direct lookup by source system and identifier + SourceId { + /// The source system (e.g., arXiv, DOI) + source: Source, + /// The source-specific identifier + identifier: &'a str, + }, + /// Search by author name with partial matching + Author(&'a str), + /// Retrieve the complete paper collection + All, + /// Filter papers by publication date + BeforeDate(DateTime), +} + +/// Available fields for ordering query results. +/// +/// This enum defines the paper attributes that can be used for +/// sorting query results. Each field maps to specific database +/// columns and handles appropriate comparison logic. 
+#[derive(Debug, Clone, Copy)] +pub enum OrderField { + /// Order alphabetically by paper title + Title, + /// Order chronologically by publication date + PublicationDate, + /// Order by source system and identifier + Source, +} + +impl OrderField { + /// Converts the ordering field to its SQL representation. + /// + /// Returns the appropriate SQL column names for ORDER BY clauses, + /// handling both single-column and multi-column ordering. + fn as_sql_str(&self) -> &'static str { + match self { + OrderField::Title => "title", + OrderField::PublicationDate => "publication_date", + OrderField::Source => "source, source_identifier", + } + } +} + +/// A query builder for retrieving papers from the database. +/// +/// This struct provides a fluent interface for constructing paper queries, +/// supporting various search criteria and result ordering options. It handles: +/// +/// - Query criteria specification +/// - Result ordering configuration +/// - SQL generation and execution +/// - Paper reconstruction from rows +#[derive(Debug)] +pub struct Query<'a> { + /// The search criteria to apply + criteria: QueryCriteria<'a>, + /// Optional field to sort results by + order_by: Option, + /// Whether to sort in descending order + descending: bool, +} + +impl<'a> Query<'a> { + /// Creates a new query with the given criteria. + /// + /// # Arguments + /// + /// * `criteria` - The search criteria to use + /// + /// # Examples + /// + /// ```no_run + /// # use learner::database::{Query, QueryCriteria}; + /// let query = Query::new(QueryCriteria::All); + /// ``` + pub fn new(criteria: QueryCriteria<'a>) -> Self { + Self { criteria, order_by: None, descending: false } + } + + /// Creates a full-text search query. + /// + /// Searches through paper titles and abstracts using SQLite's FTS5 + /// full-text search engine with wildcard matching. 
+ /// + /// # Arguments + /// + /// * `query` - The text to search for + /// + /// # Examples + /// + /// ```no_run + /// # use learner::database::Query; + /// let query = Query::text("quantum computing"); + /// ``` + pub fn text(query: &'a str) -> Self { Self::new(QueryCriteria::Text(query)) } + + /// Creates a query to find a specific paper. + /// + /// # Arguments + /// + /// * `paper` - The paper whose source and identifier should be matched + /// + /// # Examples + /// + /// ```no_run + /// # use learner::database::Query; + /// # use learner::paper::Paper; + /// # async fn example() -> Result<(), Box> { + /// let paper = Paper::new("2301.07041").await?; + /// let query = Query::by_paper(&paper); + /// # Ok(()) + /// # } + /// ``` + pub fn by_paper(paper: &'a Paper) -> Self { + Self::new(QueryCriteria::SourceId { + source: paper.source, + identifier: &paper.source_identifier, + }) + } + + /// Creates a query to find a paper by its source and identifier. + /// + /// # Arguments + /// + /// * `source` - The paper source (arXiv, DOI, etc.) + /// * `identifier` - The source-specific identifier + /// + /// # Examples + /// + /// ```no_run + /// # use learner::database::Query; + /// # use learner::paper::Source; + /// let query = Query::by_source(Source::Arxiv, "2301.07041"); + /// ``` + pub fn by_source(source: Source, identifier: &'a str) -> Self { + Self::new(QueryCriteria::SourceId { source, identifier }) + } + + /// Creates a query to find papers by author name. + /// + /// Performs a partial match on author names, allowing for flexible + /// name searches. + /// + /// # Arguments + /// + /// * `name` - The author name to search for + /// + /// # Examples + /// + /// ```no_run + /// # use learner::database::Query; + /// let query = Query::by_author("Alice Researcher"); + /// ``` + pub fn by_author(name: &'a str) -> Self { Self::new(QueryCriteria::Author(name)) } + + /// Creates a query that returns all papers. 
+ /// + /// # Examples + /// + /// ```no_run + /// # use learner::database::Query; + /// let query = Query::list_all(); + /// ``` + pub fn list_all() -> Self { Self::new(QueryCriteria::All) } + + /// Creates a query for papers published before a specific date. + /// + /// # Arguments + /// + /// * `date` - The cutoff date for publication + /// + /// # Examples + /// + /// ```no_run + /// # use learner::database::Query; + /// # use chrono::{DateTime, Utc}; + /// let date = DateTime::parse_from_rfc3339("2024-01-01T00:00:00Z").unwrap().with_timezone(&Utc); + /// let query = Query::before_date(date); + /// ``` + pub fn before_date(date: DateTime) -> Self { Self::new(QueryCriteria::BeforeDate(date)) } + + /// Sets the field to order results by. + /// + /// # Arguments + /// + /// * `field` - The field to sort by + /// + /// # Examples + /// + /// ```no_run + /// # use learner::database::{Query, OrderField}; + /// let query = Query::list_all().order_by(OrderField::PublicationDate); + /// ``` + pub fn order_by(mut self, field: OrderField) -> Self { + self.order_by = Some(field); + self + } + + /// Sets the order to descending (default is ascending). + /// + /// # Examples + /// + /// ```no_run + /// # use learner::database::{Query, OrderField}; + /// let query = Query::list_all().order_by(OrderField::PublicationDate).descending(); + /// ``` + pub fn descending(mut self) -> Self { + self.descending = true; + self + } + + /// Builds the SQL for retrieving paper IDs based on search criteria. 
+ fn build_criteria_sql(&self) -> (String, Vec) { + match &self.criteria { + QueryCriteria::Text(query) => ( + "SELECT p.id + FROM papers p + JOIN papers_fts f ON p.id = f.rowid + WHERE papers_fts MATCH ?1 || '*' + ORDER BY rank" + .into(), + vec![(*query).to_string()], + ), + QueryCriteria::SourceId { source, identifier } => ( + "SELECT id FROM papers + WHERE source = ?1 AND source_identifier = ?2" + .into(), + vec![source.to_string(), (*identifier).to_string()], + ), + QueryCriteria::Author(name) => ( + "SELECT DISTINCT p.id + FROM papers p + JOIN authors a ON p.id = a.paper_id + WHERE a.name LIKE ?1" + .into(), + vec![format!("%{}%", name)], + ), + QueryCriteria::All => ("SELECT id FROM papers".into(), Vec::new()), + QueryCriteria::BeforeDate(date) => ( + "SELECT id FROM papers + WHERE publication_date < ?1" + .into(), + vec![date.to_rfc3339()], + ), + } + } + + /// Builds the SQL for retrieving complete paper data. + fn build_paper_sql(&self) -> String { + let base = "SELECT title, abstract_text, publication_date, + source, source_identifier, pdf_url, doi + FROM papers + WHERE id = ?1"; + + if let Some(order_field) = &self.order_by { + let direction = if self.descending { "DESC" } else { "ASC" }; + format!("{} ORDER BY {} {}", base, order_field.as_sql_str(), direction) + } else { + base.to_string() + } + } +} + +#[async_trait] +impl DatabaseInstruction for Query<'_> { + type Output = Vec; + + async fn execute(&self, db: &mut Database) -> Result { + let (criteria_sql, params) = self.build_criteria_sql(); + let paper_sql = self.build_paper_sql(); + let order_by = self.order_by; + let descending = self.descending; + + let papers = db + .conn + .call(move |conn| { + let mut papers = Vec::new(); + let tx = conn.transaction()?; + + // Get paper IDs based on search criteria + let paper_ids = { + let mut stmt = tx.prepare_cached(&criteria_sql)?; + let mut rows = stmt.query(params_from_iter(params))?; + let mut ids = Vec::new(); + while let Some(row) = rows.next()? 
{ + ids.push(row.get::<_, i64>(0)?); + } + ids + }; + + // Fetch complete paper data for each ID + for paper_id in paper_ids { + let mut paper_stmt = tx.prepare_cached(&paper_sql)?; + let paper = paper_stmt.query_row([paper_id], |row| { + Ok(Paper { + title: row.get(0)?, + abstract_text: row.get(1)?, + publication_date: DateTime::parse_from_rfc3339(&row.get::<_, String>(2)?) + .map(|dt| dt.with_timezone(&Utc)) + .map_err(|e| { + rusqlite::Error::FromSqlConversionFailure( + 2, + rusqlite::types::Type::Text, + Box::new(e), + ) + })?, + source: Source::from_str(&row.get::<_, String>(3)?).map_err(|e| { + rusqlite::Error::FromSqlConversionFailure( + 3, + rusqlite::types::Type::Text, + Box::new(e), + ) + })?, + source_identifier: row.get(4)?, + pdf_url: row.get(5)?, + doi: row.get(6)?, + authors: Vec::new(), + }) + })?; + + // Get authors for this paper + let mut author_stmt = tx.prepare_cached( + "SELECT name, affiliation, email + FROM authors + WHERE paper_id = ?", + )?; + + let authors = author_stmt + .query_map([paper_id], |row| { + Ok(Author { + name: row.get(0)?, + affiliation: row.get(1)?, + email: row.get(2)?, + }) + })? + .collect::>>()?; + + let mut paper = paper; + paper.authors = authors; + papers.push(paper); + } + + // Sort if needed + if let Some(order_field) = order_by { + papers.sort_by(|a, b| { + let cmp = match order_field { + OrderField::Title => a.title.cmp(&b.title), + OrderField::PublicationDate => a.publication_date.cmp(&b.publication_date), + OrderField::Source => (a.source.to_string(), &a.source_identifier) + .cmp(&(b.source.to_string(), &b.source_identifier)), + }; + if descending { + cmp.reverse() + } else { + cmp + } + }); + } + + Ok(papers) + }) + .await?; + + Ok(papers) + } +} diff --git a/crates/learner/src/database/instruction/remove.rs b/crates/learner/src/database/instruction/remove.rs new file mode 100644 index 0000000..d063e28 --- /dev/null +++ b/crates/learner/src/database/instruction/remove.rs @@ -0,0 +1,250 @@ +//! 
Remove instruction implementation for paper deletion from the database. +//! +//! This module provides functionality for safely removing papers and their associated +//! data from the database. It supports: +//! +//! - Query-based paper removal +//! - Dry run simulation +//! - Cascade deletion of related data +//! - Atomic transactions +//! +//! The implementation emphasizes: +//! - Safe deletion with transaction support +//! - Cascading removals across related tables +//! - Validation before deletion +//! - Preview capabilities through dry runs +//! +//! # Examples +//! +//! ```no_run +//! use learner::{ +//! database::{Database, Query, Remove}, +//! paper::Source, +//! prelude::*, +//! }; +//! +//! # async fn example() -> Result<(), Box> { +//! let mut db = Database::open("papers.db").await?; +//! +//! // Remove a specific paper +//! Remove::by_source(Source::Arxiv, "2301.07041").execute(&mut db).await?; +//! +//! // Preview deletion with dry run +//! let papers = Remove::by_author("Alice Researcher").dry_run().execute(&mut db).await?; +//! +//! println!("Would remove {} papers", papers.len()); +//! # Ok(()) +//! # } +//! ``` + +use super::*; + +/// Configuration options for paper removal operations. +/// +/// This struct allows customization of how the remove operation +/// behaves, particularly useful for validation and testing. +#[derive(Default)] +pub struct RemoveOptions { + /// When true, simulates the removal operation without modifying the database. + /// + /// This is useful for: + /// - Previewing which papers would be removed + /// - Validating removal queries + /// - Testing removal logic safely + pub dry_run: bool, +} + +/// Instruction for removing papers from the database. +/// +/// This struct implements the [`DatabaseInstruction`] trait to provide +/// paper removal functionality. 
It handles: +/// +/// - Paper identification through queries +/// - Related data cleanup (authors, files) +/// - Transaction management +/// - Dry run simulation +pub struct Remove<'a> { + /// The query identifying papers to remove + query: Query<'a>, + /// Configuration options for the removal + options: RemoveOptions, +} + +impl<'a> Remove<'a> { + /// Creates a remove instruction from an existing query. + /// + /// This method allows any query to be converted into a remove operation, + /// providing maximum flexibility in identifying papers to remove. + /// + /// # Arguments + /// + /// * `query` - The query that identifies papers to remove + /// + /// # Examples + /// + /// ```no_run + /// # use learner::database::{Remove, Query}; + /// // Remove papers matching a text search + /// let query = Query::text("quantum computing"); + /// let remove = Remove::from_query(query); + /// + /// // Remove papers before a date + /// use chrono::{DateTime, Utc}; + /// let date = DateTime::parse_from_rfc3339("2020-01-01T00:00:00Z").unwrap().with_timezone(&Utc); + /// let query = Query::before_date(date); + /// let remove = Remove::from_query(query); + /// ``` + pub fn from_query(query: Query<'a>) -> Self { Self { query, options: RemoveOptions::default() } } + + /// Creates a remove instruction for a specific paper by its source and identifier. + /// + /// This is a convenience method for the common case of removing + /// a single paper identified by its source system and ID. + /// + /// # Arguments + /// + /// * `source` - The paper's source system (arXiv, DOI, etc.) 
+ /// * `identifier` - The source-specific identifier + /// + /// # Examples + /// + /// ```no_run + /// # use learner::database::Remove; + /// # use learner::paper::Source; + /// // Remove an arXiv paper + /// let remove = Remove::by_source(Source::Arxiv, "2301.07041"); + /// + /// // Remove a DOI paper + /// let remove = Remove::by_source(Source::DOI, "10.1145/1327452.1327492"); + /// ``` + pub fn by_source(source: Source, identifier: &'a str) -> Self { + Self::from_query(Query::by_source(source, identifier)) + } + + /// Creates a remove instruction for all papers by a specific author. + /// + /// This method provides a way to remove all papers associated with + /// a particular author name. It performs partial matching on the name. + /// + /// # Arguments + /// + /// * `name` - The author name to match + /// + /// # Examples + /// + /// ```no_run + /// # use learner::database::Remove; + /// // Remove all papers by an author + /// let remove = Remove::by_author("Alice Researcher"); + /// ``` + pub fn by_author(name: &'a str) -> Self { Self::from_query(Query::by_author(name)) } + + /// Enables dry run mode for the remove operation. 
+ /// + /// In dry run mode, the operation will: + /// - Query papers that would be removed + /// - Return the list of papers + /// - Not modify the database + /// + /// This is useful for: + /// - Previewing removal operations + /// - Validating queries + /// - Testing removal logic + /// + /// # Examples + /// + /// ```no_run + /// # use learner::{database::Remove, prelude::*}; + /// # use learner::paper::Source; + /// # async fn example() -> Result<(), Box> { + /// # let mut db = learner::database::Database::open("papers.db").await?; + /// // Preview papers that would be removed + /// let papers = Remove::by_author("Alice Researcher").dry_run().execute(&mut db).await?; + /// + /// println!("Would remove {} papers", papers.len()); + /// # Ok(()) + /// # } + /// ``` + pub fn dry_run(mut self) -> Self { + self.options.dry_run = true; + self + } + + /// Builds SQL to retrieve paper IDs for removal. + /// + /// Generates the SQL and parameters needed to find database IDs + /// for papers matching the removal criteria. + fn build_paper_ids_sql(paper: &Paper) -> (String, Vec>) { + ("SELECT id FROM papers WHERE source = ? AND source_identifier = ?".to_string(), vec![ + Some(paper.source.to_string()), + Some(paper.source_identifier.clone()), + ]) + } + + /// Builds SQL to remove papers and all related data. + /// + /// Generates cascading DELETE statements to remove papers and their + /// associated data (authors, files) in the correct order to maintain + /// referential integrity. 
+ fn build_remove_sql(ids: &[i64]) -> (String, Vec>) { + let ids_str = ids.iter().map(|id| id.to_string()).collect::>().join(","); + + ( + format!( + "DELETE FROM authors WHERE paper_id IN ({0}); + DELETE FROM files WHERE paper_id IN ({0}); + DELETE FROM papers WHERE id IN ({0});", + ids_str + ), + Vec::new(), // No params needed since IDs are embedded in SQL + ) + } +} + +#[async_trait] +impl DatabaseInstruction for Remove<'_> { + type Output = Vec; + + async fn execute(&self, db: &mut Database) -> Result { + // Use Query to find the papers to remove + let papers = self.query.execute(db).await?; + + if !self.options.dry_run && !papers.is_empty() { + // Collect all paper IDs + let papers_clone = papers.clone(); + let ids: Vec = db + .conn + .call(move |conn| { + let mut ids = Vec::new(); + let tx = conn.transaction()?; + + for paper in &papers_clone { + let (sql, params) = Self::build_paper_ids_sql(paper); + if let Ok(id) = tx.query_row(&sql, params_from_iter(params), |row| row.get(0)) { + ids.push(id); + } + } + + tx.commit()?; + Ok(ids) + }) + .await?; + + if !ids.is_empty() { + // Remove the papers and their related data + let (remove_sql, _) = Self::build_remove_sql(&ids); + + db.conn + .call(move |conn| { + let tx = conn.transaction()?; + tx.execute_batch(&remove_sql)?; + tx.commit()?; + Ok(()) + }) + .await?; + } + } + + Ok(papers) + } +} diff --git a/crates/learner/src/database/mod.rs b/crates/learner/src/database/mod.rs new file mode 100644 index 0000000..f934485 --- /dev/null +++ b/crates/learner/src/database/mod.rs @@ -0,0 +1,384 @@ +//! Database management and operations for academic paper metadata. +//! +//! This module provides a flexible SQLite-based storage system for managing academic paper +//! metadata and references while allowing users to maintain control over how and where their +//! documents are stored. The database tracks: +//! +//! - Paper metadata (title, authors, abstract, publication date) +//! 
- Source information (arXiv, DOI, IACR) +//! - Document storage locations +//! - Full-text search capabilities +//! +//! The design emphasizes: +//! - User control over data storage locations +//! - Flexible integration with external PDF viewers and tools +//! - Efficient querying and organization of paper metadata +//! - Separation of metadata from document storage +//! +//! # Architecture +//! +//! The database module uses a command pattern through the [`DatabaseInstruction`] trait, +//! allowing for type-safe and composable database operations. Common operations are +//! implemented as distinct instruction types: +//! +//! - [`Query`] - For searching and retrieving papers +//! - [`Add`] - For adding new papers and documents +//! - [`Remove`] - For removing papers from the database +//! +//! # Examples +//! +//! ```no_run +//! use learner::{ +//! database::{Add, Database, Query}, +//! paper::Paper, +//! prelude::*, +//! }; +//! +//! # async fn example() -> Result<(), Box> { +//! // Open database at default location +//! let mut db = Database::open(Database::default_path()).await?; +//! +//! // Add a paper +//! let paper = Paper::new("2301.07041").await?; +//! Add::paper(&paper).execute(&mut db).await?; +//! +//! // Search for papers about neural networks +//! let papers = Query::text("neural networks").execute(&mut db).await?; +//! +//! // Customize document storage location +//! db.set_storage_path("~/Documents/research/papers").await?; +//! # Ok(()) +//! # } +//! ``` + +use tokio_rusqlite::Connection; + +use super::*; + +mod instruction; +// pub mod models; +#[cfg(test)] mod tests; + +pub use self::instruction::{ + add::Add, + query::{OrderField, Query, QueryCriteria}, + remove::Remove, + DatabaseInstruction, +}; + +/// Main database connection handler for the paper management system. +/// +/// The `Database` struct provides the primary interface for interacting with the SQLite +/// database that stores paper metadata and document references. 
It handles: +/// +/// - Database initialization and schema management +/// - Storage path configuration for documents +/// - Connection management for async database operations +/// +/// The database is designed to separate metadata storage (managed by this system) +/// from document storage (which can be managed by external tools), allowing users +/// to maintain their preferred document organization while benefiting from the +/// metadata management features. +pub struct Database { + /// Active connection to the SQLite database + pub conn: Connection, +} + +impl Database { + /// Opens an existing database or creates a new one at the specified path. + /// + /// This method performs complete database initialization: + /// 1. Creates parent directories if they don't exist + /// 2. Initializes the SQLite database file + /// 3. Applies schema migrations + /// 4. Sets up full-text search indexes for paper metadata + /// 5. Configures default storage paths if not already set + /// + /// # Arguments + /// + /// * `path` - Path where the database file should be created or opened. 
This can be: + /// - An absolute path to a specific location + /// - A relative path from the current directory + /// - The result of [`Database::default_path()`] for platform-specific default location + /// + /// # Returns + /// + /// Returns a [`Result`] containing either: + /// - A [`Database`] handle ready for operations + /// - A [`LearnerError`] if initialization fails + /// + /// # Examples + /// + /// ```no_run + /// # use learner::database::Database; + /// # async fn example() -> Result<(), Box> { + /// // Use platform-specific default location + /// let db = Database::open(Database::default_path()).await?; + /// + /// // Or specify a custom location + /// let db = Database::open("/path/to/papers.db").await?; + /// # Ok(()) + /// # } + /// ``` + pub async fn open(path: impl AsRef) -> Result { + // Create parent directories if needed + if let Some(parent) = path.as_ref().parent() { + std::fs::create_dir_all(parent)?; + } + + let conn = Connection::open(path.as_ref()).await?; + + // Initialize schema + conn + .call(|conn| { + Ok(conn.execute_batch(include_str!(concat!( + env!("CARGO_MANIFEST_DIR"), + "/migrations/init.sql" + )))?) + }) + .await?; + + let db = Self { conn }; + + // Check if storage path is set, if not, set default + if db.get_storage_path().await.is_err() { + db.set_storage_path(Self::default_storage_path()).await?; + } + + Ok(db) + } + + /// Gets the configured storage path for document files. + /// + /// The storage path determines where document files (like PDFs) will be saved + /// when downloaded through the system. This path is stored in the database + /// configuration and can be modified using [`Database::set_storage_path()`]. 
+ /// + /// # Returns + /// + /// Returns a `Result` containing either: + /// - The configured [`PathBuf`] for document storage + /// - A [`LearnerError`] if the path cannot be retrieved + /// + /// # Examples + /// + /// ```no_run + /// # use learner::database::Database; + /// # async fn example() -> Result<(), Box> { + /// let db = Database::open(Database::default_path()).await?; + /// let storage_path = db.get_storage_path().await?; + /// println!("Documents are stored in: {}", storage_path.display()); + /// # Ok(()) + /// # } + /// ``` + pub async fn get_storage_path(&self) -> Result { + Ok( + self + .conn + .call(|conn| { + Ok( + conn + .prepare_cached("SELECT value FROM config WHERE key = 'storage_path'")? + .query_row([], |row| Ok(PathBuf::from(row.get::<_, String>(0)?)))?, + ) + }) + .await?, + ) + } + + /// Sets the storage path for document files, validating that the path is usable. + /// + /// This method configures where document files (like PDFs) will be stored when + /// downloaded through the system. It performs extensive validation to ensure the + /// path is usable and accessible: + /// + /// - Verifies the path exists or can be created + /// - Confirms the filesystem is writable + /// - Validates sufficient permissions exist + /// - Ensures the path is absolute for reliability + /// + /// When changing the storage path, existing documents are not automatically moved. + /// Users should manually migrate their documents if needed. + /// + /// # Arguments + /// + /// * `path` - The path where document files should be stored. Must be an absolute path. 
+ /// + /// # Returns + /// + /// Returns a `Result` containing: + /// - `Ok(())` if the path is valid and has been configured + /// - `Err(LearnerError)` if the path is invalid or cannot be used + /// + /// # Errors + /// + /// This function will return an error if: + /// - The path is not absolute + /// - The path cannot be created + /// - The filesystem is read-only + /// - Insufficient permissions exist + /// + /// # Examples + /// + /// ```no_run + /// # use learner::database::Database; + /// # async fn example() -> Result<(), Box> { + /// let db = Database::open(Database::default_path()).await?; + /// + /// // Set custom storage location + /// db.set_storage_path("/data/papers").await?; + /// + /// // Or use home directory + /// db.set_storage_path("~/Documents/papers").await?; + /// # Ok(()) + /// # } + /// ``` + pub async fn set_storage_path(&self, path: impl AsRef) -> Result<()> { + let original_path_result = self.get_storage_path().await; + let path = path.as_ref(); + + // Convert relative paths to absolute using current working directory + let absolute_path = + if !path.is_absolute() { std::env::current_dir()?.join(path) } else { path.to_path_buf() }; + + // Create a test file to verify write permissions + let test_file = absolute_path.join(".learner_write_test"); + + // First try to create the directory structure + match std::fs::create_dir_all(&absolute_path) { + Ok(_) => { + // Rest of the code remains the same, but use absolute_path instead of path + match std::fs::write(&test_file, b"test") { + Ok(_) => { + // Clean up test file + let _ = std::fs::remove_file(&test_file); + }, + Err(e) => { + return Err(match e.kind() { + std::io::ErrorKind::PermissionDenied => LearnerError::Path(std::io::Error::new( + std::io::ErrorKind::PermissionDenied, + "Insufficient permissions to write to storage directory", + )), + std::io::ErrorKind::ReadOnlyFilesystem => LearnerError::Path(std::io::Error::new( + std::io::ErrorKind::ReadOnlyFilesystem, + "Storage location 
is on a read-only filesystem", + )), + _ => LearnerError::Path(e), + }); + }, + } + }, + Err(e) => { + return Err(LearnerError::Path(std::io::Error::new( + e.kind(), + format!("Failed to create storage directory: {}", e), + ))); + }, + } + + // If we get here, the path is valid and writable + let path_str = absolute_path.to_string_lossy().to_string(); + + self + .conn + .call(move |conn| { + Ok( + conn + .execute("INSERT OR REPLACE INTO config (key, value) VALUES ('storage_path', ?1)", [ + path_str, + ])?, + ) + }) + .await?; + + if let Ok(original_path) = original_path_result { + warn!( + "Original storage path was {:?}, set a new path to {:?}. Please be careful to check that \ + your documents have been moved or that you intended to do this operation!", + original_path, absolute_path + ); + } + + Ok(()) + } + + /// Returns the platform-specific default path for the database file. + /// + /// This method provides a sensible default location for the database file + /// following platform conventions: + /// + /// - Unix: `~/.local/share/learner/learner.db` + /// - macOS: `~/Library/Application Support/learner/learner.db` + /// - Windows: `%APPDATA%\learner\learner.db` + /// - Fallback: `./learner.db` in the current directory + /// + /// # Returns + /// + /// Returns a [`PathBuf`] pointing to the default database location for the + /// current platform. + /// + /// # Examples + /// + /// ```no_run + /// use learner::database::Database; + /// + /// let path = Database::default_path(); + /// println!("Default database location: {}", path.display()); + /// ``` + pub fn default_path() -> PathBuf { + dirs::data_dir().unwrap_or_else(|| PathBuf::from(".")).join("learner").join("learner.db") + } + + /// Returns the platform-specific default path for document storage. + /// + /// This method provides a sensible default location for storing document files + /// following platform conventions. 
The returned path is always absolute and follows + /// these patterns: + /// + /// - Unix: `~/Documents/learner/papers` + /// - macOS: `~/Documents/learner/papers` + /// - Windows: `Documents\learner\papers` + /// - Fallback: the current working directory joined with `learner/papers` + /// + /// The method ensures the path is absolute by: + /// - Using platform-specific document directories when available + /// - Falling back to the current working directory when needed + /// - Resolving all relative components + /// + /// Users can override this default using [`Database::set_storage_path()`]. + /// + /// # Returns + /// + /// Returns an absolute [`PathBuf`] pointing to the default document storage + /// location for the current platform. + /// + /// # Examples + /// + /// ```no_run + /// use learner::database::Database; + /// + /// let path = Database::default_storage_path(); + /// assert!(path.is_absolute()); + /// println!("Default document storage: {}", path.display()); + /// + /// // On Unix-like systems, might print something like: + /// // "/home/user/Documents/learner/papers" + /// + /// // On Windows, might print something like: + /// // "C:\Users\user\Documents\learner\papers" + /// ``` + /// + /// Note that while the base directory may vary by platform, the returned path + /// is guaranteed to be absolute and usable for document storage. 
+ pub fn default_storage_path() -> PathBuf { + let base_path = dirs::document_dir().unwrap_or_else(|| PathBuf::from(".")); + // Make sure we return an absolute path + if !base_path.is_absolute() { + std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")).join(base_path) + } else { + base_path + } + .join("learner") + .join("papers") + } +} diff --git a/crates/learner/src/database/tests.rs b/crates/learner/src/database/tests.rs new file mode 100644 index 0000000..63ee619 --- /dev/null +++ b/crates/learner/src/database/tests.rs @@ -0,0 +1,149 @@ +use super::*; + +/// Helper function to set up a test database +async fn setup_test_db() -> (Database, PathBuf, tempfile::TempDir) { + let dir = tempdir().unwrap(); + let path = dir.path().join("test.db"); + let db = Database::open(&path).await.unwrap(); + (db, path, dir) +} + +#[traced_test] +#[tokio::test] +async fn test_database_creation() { + let (_db, path, _dir) = setup_test_db().await; + + // Check that file exists + assert!(path.exists()); +} + +#[traced_test] +#[test] +fn test_default_path() { + let path = Database::default_path(); + + // Should end with learner/learner.db + assert!(path.ends_with("learner/learner.db") || path.ends_with("learner\\learner.db")); + + // Should be rooted in a valid directory + assert!(path + .parent() + .unwrap() + .starts_with(dirs::data_dir().unwrap_or_else(|| PathBuf::from(".")))); +} + +#[traced_test] +#[test] +fn test_default_storage_path() { + let path = Database::default_storage_path(); + + // Should end with learner/papers + assert!(path.ends_with("learner/papers") || path.ends_with("learner\\papers")); + + // In CI, we might get a path based on current directory instead of user directories + // so we should check both possibilities + if let Some(doc_dir) = dirs::document_dir() { + assert!( + path.parent().unwrap().starts_with(&doc_dir) + || path.parent().unwrap().starts_with(std::env::current_dir().unwrap()) + ); + } else { + // If no document directory is available (like
in CI), + // path should be based on current directory + assert!(path.parent().unwrap().starts_with(std::env::current_dir().unwrap())); + } +} + +#[traced_test] +#[tokio::test] +async fn test_new_db_uses_default_storage() { + let (db, _path, _dir) = setup_test_db().await; + + let storage_path = db.get_storage_path().await.expect("Storage path should be set"); + assert_eq!(storage_path, Database::default_storage_path()); +} + +#[traced_test] +#[tokio::test] +async fn test_storage_path_persistence() { + let (db, db_path, _dir) = setup_test_db().await; + + // Set custom storage path + let custom_path = PathBuf::from("/tmp/custom/storage"); + db.set_storage_path(&custom_path).await.unwrap(); + + // Reopen database and check path + drop(db); + let db = Database::open(db_path).await.unwrap(); + let storage_path = db.get_storage_path().await.expect("Storage path should be set"); + assert_eq!(storage_path, custom_path); +} + +#[traced_test] +#[tokio::test] +async fn test_storage_path_creates_directory() { + let (db, _path, dir) = setup_test_db().await; + + let custom_path = dir.path().join("custom_storage"); + db.set_storage_path(&custom_path).await.unwrap(); + + assert!(custom_path.exists()); + assert!(custom_path.is_dir()); +} + +#[traced_test] +#[tokio::test] +async fn test_storage_path_valid() -> Result<()> { + let (db, _path, dir) = setup_test_db().await; + + // Create an absolute path without requiring existence + let test_path = dir.path().join("storage"); + + // Set the storage path (this will create the directory) + db.set_storage_path(&test_path).await?; + + // Get and verify the stored path + let stored_path = db.get_storage_path().await?; + assert_eq!(stored_path, test_path); + assert!(test_path.exists()); + + // Verify we can write to the directory + let test_file = test_path.join("test.txt"); + std::fs::write(&test_file, b"test")?; + assert!(test_file.exists()); + + Ok(()) +} + +#[traced_test] +#[tokio::test] +async fn test_storage_path_relative() { + let 
(db, _path, _dir) = setup_test_db().await; + let storage_path = "relative/path"; + db.set_storage_path(storage_path).await.unwrap(); + assert_eq!( + std::env::current_dir().unwrap().join(storage_path).to_str().unwrap(), + db.get_storage_path().await.unwrap().to_str().unwrap() + ); +} + +#[cfg(unix)] +#[traced_test] +#[tokio::test] +async fn test_storage_path_readonly() -> Result<()> { + use std::os::unix::fs::PermissionsExt; + + let (db, _path, dir) = setup_test_db().await; + let test_path = dir.path().join("readonly"); + std::fs::create_dir(&test_path)?; + + // Make directory read-only + std::fs::set_permissions(&test_path, std::fs::Permissions::from_mode(0o444))?; + + let result = db.set_storage_path(&test_path).await; + assert!(matches!( + result, + Err(LearnerError::Path(e)) if e.kind() == std::io::ErrorKind::PermissionDenied + )); + Ok(()) +} diff --git a/crates/learner/src/error.rs b/crates/learner/src/error.rs index 6176ddf..99ea26f 100644 --- a/crates/learner/src/error.rs +++ b/crates/learner/src/error.rs @@ -161,46 +161,14 @@ pub enum LearnerError { #[error("No messages were supplied to send to the LLM.")] LLMMissingMessage, - // TODO (autoparallel): This can be gotten rid of if we use an enum to handle the ways to sort - // data in the database instead of a string. - /// An error working with the database. - #[error("{0}")] - Database(String), -} - -impl LearnerError { - /// Checks if this error represents a duplicate entry in the database. - /// - /// This helper method checks for SQLite's unique constraint violation, which - /// occurs when trying to insert a paper that already exists in the database - /// (matching source and source_identifier). 
- /// - /// # Examples - /// - /// ``` - /// use learner::error::LearnerError; - /// - /// # async fn example() -> Result<(), Box> { - /// let db = learner::database::Database::open("papers.db").await?; - /// let paper = learner::paper::Paper::new("2301.07041").await?; - /// - /// match paper.save(&db).await { - /// Ok(id) => println!("Saved paper with ID: {}", id), - /// Err(e) if e.is_duplicate_error() => println!("Paper already exists!"), - /// Err(e) => return Err(e.into()), - /// } - /// # Ok(()) - /// # } - /// ``` - /// - /// This is particularly useful for providing friendly error messages when - /// attempting to add papers that are already in the database. - pub fn is_duplicate_error(&self) -> bool { - matches!( - self, - LearnerError::AsyncSqlite(tokio_rusqlite::Error::Rusqlite( - rusqlite::Error::SqliteFailure(error, _) - )) if error.code == rusqlite::ErrorCode::ConstraintViolation - ) - } + /// Indicates an attempt to add a paper that already exists in the database. + /// + /// This error occurs during paper addition operations when the database + /// already contains a paper with the same source and identifier. This helps + /// prevent duplicate entries and maintains database integrity. + /// + /// The error includes the paper's title to help users identify which paper + /// caused the conflict. + #[error("Tried to add a paper titled \"{0}\" that was already in the database.")] + DatabaseDuplicatePaper(String), } diff --git a/crates/learner/src/lib.rs b/crates/learner/src/lib.rs index db4f2b8..a164718 100644 --- a/crates/learner/src/lib.rs +++ b/crates/learner/src/lib.rs @@ -1,19 +1,69 @@ -//! A library for fetching academic papers and their metadata from various sources -//! including arXiv, IACR, and DOI-based repositories. +//! Academic paper management and metadata retrieval library. +//! +//! `learner` is a library for managing academic papers, providing: +//! +//! - Paper metadata retrieval from multiple sources +//! 
- Local document management +//! - Database storage and querying +//! - Full-text search capabilities +//! - Flexible document organization +//! +//! # Features +//! +//! - **Multi-source support**: Fetch papers from: +//! - arXiv (with support for both new and old-style identifiers) +//! - IACR (International Association for Cryptologic Research) +//! - DOI (Digital Object Identifier) +//! - **Flexible storage**: Choose where and how to store documents +//! - **Rich metadata**: Track authors, abstracts, and publication dates +//! - **Database operations**: Type-safe queries and modifications +//! - **Command pattern**: Composable database operations +//! +//! # Getting Started //! -//! # Example //! ```no_run -//! use learner::paper::{Paper, Source}; +//! use learner::{ +//! database::{Add, Database, Query}, +//! paper::{Paper, Source}, +//! prelude::*, +//! }; //! //! #[tokio::main] //! async fn main() -> Result<(), Box> { -//! // Fetch from arXiv +//! // Create or open a database +//! let mut db = Database::open(Database::default_path()).await?; +//! +//! // Fetch a paper from arXiv //! let paper = Paper::new("2301.07041").await?; //! println!("Title: {}", paper.title); //! +//! // Add to database with document +//! Add::complete(&paper).execute(&mut db).await?; +//! +//! // Search for related papers +//! let papers = Query::text("quantum computing").execute(&mut db).await?; +//! //! Ok(()) //! } //! ``` +//! +//! # Module Organization +//! +//! - [`paper`]: Core paper types and metadata handling +//! - [`database`]: Database operations and storage management +//! - [`clients`]: Source-specific API clients +//! - [`llm`]: Language model integration for paper analysis +//! - [`pdf`]: PDF document handling and text extraction +//! - [`prelude`]: Common traits and types for ergonomic imports +//! +//! # Design Philosophy +//! +//! This library emphasizes: +//! - User control over document storage and organization +//! 
- Separation of metadata from document management +//! - Type-safe database operations +//! - Extensible command pattern for operations +//! - Clear error handling and propagation #![warn(missing_docs, clippy::missing_docs_in_private_items)] #![feature(str_from_utf16_endian)] @@ -36,10 +86,53 @@ use {tempfile::tempdir, tracing_test::traced_test}; pub mod clients; pub mod database; + pub mod error; pub mod format; pub mod llm; pub mod paper; pub mod pdf; -use crate::{clients::*, database::*, error::*}; +use crate::{clients::*, error::*}; + +/// Common traits and types for ergonomic imports. +/// +/// This module provides a convenient way to import frequently used traits +/// and types with a single glob import. It includes: +/// +/// - Database operation traits +/// - Error types and common `Result` type +/// - Commonly used trait implementations +/// +/// # Usage +/// +/// ```no_run +/// use learner::{ +/// database::{Add, Database}, +/// paper::Paper, +/// prelude::*, +/// }; +/// +/// async fn example() -> Result<(), LearnerError> { +/// // Now you can use both `DatabaseInstruction` and our `LearnerError` type +/// let paper = Paper::new("2301.07041").await?; +/// let mut db = Database::open(Database::default_path()).await?; +/// Add::paper(&paper).execute(&mut db).await?; +/// Ok(()) +/// } +/// ``` +/// +/// # Contents +/// +/// Currently exports: +/// - [`DatabaseInstruction`]: Trait for implementing database operations +/// - [`LearnerError`]: Core error type for the library +/// +/// Future additions may include: +/// - Additional trait implementations +/// - Common type aliases +/// - Builder pattern traits +/// - Conversion traits +pub mod prelude { + pub use crate::{database::DatabaseInstruction, error::LearnerError}; +} diff --git a/crates/learner/src/paper.rs b/crates/learner/src/paper.rs index f8f219a..caf9157 100644 --- a/crates/learner/src/paper.rs +++ b/crates/learner/src/paper.rs @@ -1,161 +1,204 @@ -//! 
Paper management and metadata types for the learner library. +//! Core paper management and metadata types for academic paper handling. //! -//! This module provides types and functionality for working with academic papers from -//! various sources including arXiv, IACR, and DOI-based repositories. It handles paper -//! metadata, author information, and source-specific identifier parsing. +//! This module provides the fundamental types and functionality for working with +//! academic papers from various sources. It handles: +//! +//! - Paper metadata management +//! - Multi-source identifier parsing +//! - Author information +//! - Document downloading +//! - Source-specific identifier formats +//! +//! The implementation supports papers from: +//! - arXiv (both new-style and old-style identifiers) +//! - IACR (International Association for Cryptologic Research) +//! - DOI (Digital Object Identifier) //! //! # Examples //! +//! Creating papers from different sources: +//! //! ```no_run //! use learner::paper::Paper; //! -//! # async fn run() -> Result<(), Box> { -//! // Create a paper from an arXiv URL +//! # async fn example() -> Result<(), Box> { +//! // From arXiv URL //! let paper = Paper::new("https://arxiv.org/abs/2301.07041").await?; //! println!("Title: {}", paper.title); //! -//! // Or from a DOI +//! // From DOI //! let paper = Paper::new("10.1145/1327452.1327492").await?; //! -//! // Save to database -//! let db = learner::database::Database::open("papers.db").await?; -//! paper.save(&db).await?; +//! // From IACR +//! let paper = Paper::new("2023/123").await?; +//! +//! // Download associated PDF +//! use std::path::PathBuf; +//! let storage = PathBuf::from("papers"); +//! paper.download_pdf(&storage).await?; //! # Ok(()) //! # } //! ``` use super::*; -/// The source repository or system from which a paper originates. +/// Complete representation of an academic paper with metadata. 
/// -/// This enum represents the supported academic paper sources, each with its own -/// identifier format and access patterns. -#[derive(Debug, Clone, Eq, PartialEq, Serialize, Deserialize)] -pub enum Source { - /// Papers from arxiv.org, using either new-style (2301.07041) or - /// old-style (math.AG/0601001) identifiers - Arxiv, - /// Papers from the International Association for Cryptologic Research (eprint.iacr.org) - IACR, - /// Papers identified by a Digital Object Identifier (DOI) - DOI, -} - -impl Display for Source { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Source::Arxiv => write!(f, "Arxiv"), - Source::IACR => write!(f, "IACR"), - Source::DOI => write!(f, "DOI"), - } - } -} - -impl FromStr for Source { - type Err = LearnerError; - - fn from_str(s: &str) -> Result { - match &s.to_lowercase() as &str { - "arxiv" => Ok(Source::Arxiv), - "iacr" => Ok(Source::IACR), - "doi" => Ok(Source::DOI), - s => Err(LearnerError::InvalidSource(s.to_owned())), - } - } -} - -/// Represents an author of an academic paper. +/// This struct serves as the core data type for paper management, containing +/// all relevant metadata and document references. It supports papers from +/// multiple sources while maintaining a consistent interface for: /// -/// Contains the author's name and optional affiliation and contact information. -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct Author { - /// The author's full name - pub name: String, - /// The author's institutional affiliation, if available - pub affiliation: Option, - /// The author's email address, if available - pub email: Option, -} - -/// A complete academic paper with its metadata. 
+/// - Basic metadata (title, abstract, dates) +/// - Author information +/// - Source-specific identifiers +/// - Document access /// -/// This struct represents a paper from any supported source (arXiv, IACR, DOI) -/// along with its metadata including title, authors, abstract, and identifiers. +/// Papers can be created from various identifier formats and URLs, with the +/// appropriate source being automatically detected. /// /// # Examples /// +/// Creating and using papers: +/// /// ```no_run +/// # use learner::paper::Paper; /// # async fn example() -> Result<(), Box> { -/// // Fetch a paper from arXiv -/// let paper = learner::paper::Paper::new("2301.07041").await?; +/// // Create from identifier +/// let paper = Paper::new("2301.07041").await?; /// /// // Access metadata /// println!("Title: {}", paper.title); /// println!("Authors: {}", paper.authors.len()); /// println!("Abstract: {}", paper.abstract_text); /// -/// // Download the PDF if available -/// if let Some(pdf_url) = &paper.pdf_url { -/// paper.download_pdf("paper.pdf".into()).await?; +/// // Handle documents +/// if let Some(url) = &paper.pdf_url { +/// println!("PDF available at: {}", url); /// } /// # Ok(()) /// # } /// ``` #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Paper { - /// The paper's title + /// The paper's full title pub title: String, - /// List of the paper's authors + /// Complete list of paper authors with affiliations pub authors: Vec, - /// The paper's abstract text + /// Full abstract or summary text pub abstract_text: String, - /// When the paper was published or last updated + /// Publication or last update timestamp pub publication_date: DateTime, - /// The source system (arXiv, IACR, DOI) + /// Source repository or system (arXiv, IACR, DOI) pub source: Source, - /// The source-specific identifier (e.g., arXiv ID, DOI) + /// Source-specific paper identifier pub source_identifier: String, - /// URL to the paper's PDF, if available + /// Optional URL to PDF 
document pub pdf_url: Option, - /// The paper's DOI, if available + /// Optional DOI reference pub doi: Option, } +/// Author information for academic papers. +/// +/// Represents a single author of a paper, including their name and optional +/// institutional details. This struct supports varying levels of author +/// information availability across different sources. +/// +/// # Examples +/// +/// ``` +/// use learner::paper::Author; +/// +/// let author = Author { +/// name: "Alice Researcher".to_string(), +/// affiliation: Some("Example University".to_string()), +/// email: Some("alice@example.edu".to_string()), +/// }; +/// ``` +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Author { + /// Author's full name + pub name: String, + /// Optional institutional affiliation + pub affiliation: Option, + /// Optional contact email + pub email: Option, +} + +/// Paper source system or repository. +/// +/// Represents the different systems from which papers can be retrieved, +/// each with its own identifier format and access patterns. The enum +/// supports: +/// +/// - arXiv: Both new (2301.07041) and old (math.AG/0601001) formats +/// - IACR: Cryptology ePrint Archive format (2023/123) +/// - DOI: Standard DOI format (10.1145/1327452.1327492) +/// +/// # Examples +/// +/// ``` +/// use std::str::FromStr; +/// +/// use learner::paper::Source; +/// +/// let arxiv = Source::from_str("arxiv").unwrap(); +/// let doi = Source::from_str("doi").unwrap(); +/// ``` +#[derive(Debug, Clone, Copy, Eq, PartialEq, Serialize, Deserialize)] +pub enum Source { + /// arXiv.org papers (e.g., "2301.07041" or "math.AG/0601001") + Arxiv, + /// IACR Cryptology ePrint Archive papers (e.g., "2023/123") + IACR, + /// Papers with Digital Object Identifiers + DOI, +} + impl Paper { - /// Create a new paper from a URL, identifier, or DOI. + // TODO (autoparallel): This should probably be a `new_from_url` or just `from_url` or something. 
+ /// Creates a new paper from various identifier formats. + /// + /// This method serves as the primary entry point for paper creation, + /// supporting multiple input formats and automatically determining the + /// appropriate source handler. It accepts: /// - /// This method accepts various formats for paper identification and automatically - /// determines the appropriate source and fetches the paper's metadata. + /// - Full URLs from supported repositories + /// - Direct identifiers (arXiv ID, DOI, IACR ID) + /// - Both new and legacy identifier formats + /// + /// The method will fetch metadata from the appropriate source and + /// construct a complete Paper instance. /// /// # Arguments /// - /// * `input` - One of the following: - /// - An arXiv URL (e.g., "https://arxiv.org/abs/2301.07041") - /// - An arXiv ID (e.g., "2301.07041" or "math.AG/0601001") - /// - An IACR URL (e.g., "https://eprint.iacr.org/2016/260") - /// - An IACR ID (e.g., "2023/123") - /// - A DOI URL (e.g., "https://doi.org/10.1145/1327452.1327492") - /// - A DOI (e.g., "10.1145/1327452.1327492") + /// * `input` - Paper identifier in any supported format: + /// - arXiv URLs: "https://arxiv.org/abs/2301.07041" + /// - arXiv IDs: "2301.07041" or "math.AG/0601001" + /// - IACR URLs: "https://eprint.iacr.org/2016/260" + /// - IACR IDs: "2023/123" + /// - DOI URLs: "https://doi.org/10.1145/1327452.1327492" + /// - DOIs: "10.1145/1327452.1327492" /// /// # Returns /// /// Returns a `Result` which is: - /// - `Ok(Paper)` - Successfully fetched paper with metadata - /// - `Err(LearnerError)` - Failed to parse input or fetch paper + /// - `Ok(Paper)` - Successfully created paper with metadata + /// - `Err(LearnerError)` - Failed to parse input or fetch metadata /// /// # Examples /// /// ```no_run /// # use learner::paper::Paper; /// # async fn example() -> Result<(), Box> { - /// // From arXiv URL - /// let paper1 = Paper::new("https://arxiv.org/abs/2301.07041").await?; + /// // From URL + /// 
let paper = Paper::new("https://arxiv.org/abs/2301.07041").await?; /// - /// // From arXiv ID - /// let paper2 = Paper::new("2301.07041").await?; + /// // From identifier + /// let paper = Paper::new("2301.07041").await?; /// /// // From DOI - /// let paper3 = Paper::new("10.1145/1327452.1327492").await?; + /// let paper = Paper::new("10.1145/1327452.1327492").await?; /// # Ok(()) /// # } /// ``` @@ -208,61 +251,114 @@ impl Paper { } } - /// Download the paper's PDF to a specified path. + /// Downloads the paper's PDF to the specified directory. + /// + /// This method handles the retrieval and storage of the paper's PDF + /// document, if available. It will: + /// + /// 1. Check for PDF availability + /// 2. Download the document + /// 3. Store it with a formatted filename + /// 4. Handle network and storage errors /// /// # Arguments /// - /// * `path` - The filesystem path where the PDF should be saved + /// * `dir` - Target directory for PDF storage + /// + /// # Returns /// - /// # Errors + /// Returns a `Result` containing: + /// - `Ok(PathBuf)` - Path to the stored PDF file + /// - `Err(LearnerError)` - If download or storage fails /// - /// Returns `LearnerError` if: - /// - The paper has no PDF URL available - /// - The download fails - /// - Writing to the specified path fails - pub async fn download_pdf(&self, dir: PathBuf) -> Result<()> { - // unimplemented!("Work in progress -- needs integrated with `Database`"); + /// # Examples + /// + /// ```no_run + /// # use learner::paper::Paper; + /// # use std::path::PathBuf; + /// # async fn example() -> Result<(), Box> { + /// let paper = Paper::new("2301.07041").await?; + /// let dir = PathBuf::from("papers"); + /// let pdf_path = paper.download_pdf(&dir).await?; + /// println!("PDF stored at: {}", pdf_path.display()); + /// # Ok(()) + /// # } + /// ``` + pub async fn download_pdf(&self, dir: &Path) -> Result { let Some(pdf_url) = &self.pdf_url else { return Err(LearnerError::ApiError("No PDF URL 
available".into())); }; let response = reqwest::get(pdf_url).await?; - trace!("{} pdf_url response: {response:?}", self.source); - let bytes = response.bytes().await?; - // TODO (autoparallel): uses a fixed max output filename length, should make this configurable - // in the future. - let formatted_title = format::format_title(&self.title, Some(50)); - let path = dir.join(format!("{}.pdf", formatted_title)); - debug!("Writing PDF to path: {path:?}"); - std::fs::write(path, bytes)?; - Ok(()) + // Check the status code of the response + if response.status().is_success() { + let bytes = response.bytes().await?; + let path = dir.join(self.filename()); + debug!("Writing PDF to path: {path:?}"); + std::fs::write(path, bytes)?; + Ok(self.filename()) + } else { + // Handle non-successful status codes + trace!("{} pdf_url response: {response:?}", self.source); + Err(LearnerError::ApiError(format!("Failed to download PDF: {}", response.status()))) + } } - /// Save the paper to a database. - /// - /// # Arguments + /// Generates a standardized filename for the paper's PDF. /// - /// * `db` - Reference to an open database connection + /// Creates a filesystem-safe filename based on the paper's title, + /// suitable for PDF storage. The filename is: + /// - Truncated to a reasonable length + /// - Cleaned of problematic characters + /// - Suffixed with ".pdf" /// /// # Returns /// - /// Returns the database ID of the saved paper on success. + /// Returns a [`PathBuf`] containing the formatted filename. 
/// /// # Examples /// /// ```no_run + /// # use learner::paper::Paper; /// # async fn example() -> Result<(), Box> { - /// let paper = learner::paper::Paper::new("2301.07041").await?; - /// let db = learner::database::Database::open("papers.db").await?; - /// let id = paper.save(&db).await?; - /// println!("Saved paper with ID: {}", id); + /// let paper = Paper::new("2301.07041").await?; + /// let filename = paper.filename(); + /// println!("Suggested filename: {}", filename.display()); /// # Ok(()) /// # } /// ``` - pub async fn save(&self, db: &Database) -> Result { db.save_paper(self).await } + pub fn filename(&self) -> PathBuf { + let formatted_title = format::format_title(&self.title, Some(50)); + PathBuf::from(format!("{}.pdf", formatted_title)) + } } +impl Display for Source { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Source::Arxiv => write!(f, "Arxiv"), + Source::IACR => write!(f, "IACR"), + Source::DOI => write!(f, "DOI"), + } + } +} + +impl FromStr for Source { + type Err = LearnerError; + + fn from_str(s: &str) -> Result { + match &s.to_lowercase() as &str { + "arxiv" => Ok(Source::Arxiv), + "iacr" => Ok(Source::IACR), + "doi" => Ok(Source::DOI), + s => Err(LearnerError::InvalidSource(s.to_owned())), + } + } +} + +// TODO (autoparallel): These three functions should really be some simple generic alongside the +// rest of the stuff we have in here /// Extracts the arXiv identifier from a URL. /// /// Parses URLs like "https://arxiv.org/abs/2301.07041" to extract "2301.07041". 
@@ -319,67 +415,61 @@ mod tests { #[traced_test] #[tokio::test] - async fn test_iacr_paper_from_id() -> anyhow::Result<()> { - let paper = Paper::new("2016/260").await?; + async fn test_iacr_paper_from_id() { + let paper = Paper::new("2016/260").await.unwrap(); assert!(!paper.title.is_empty()); assert!(!paper.authors.is_empty()); assert_eq!(paper.source, Source::IACR); - Ok(()) } #[traced_test] #[tokio::test] - async fn test_iacr_paper_from_url() -> anyhow::Result<()> { - let paper = Paper::new("https://eprint.iacr.org/2016/260").await?; + async fn test_iacr_paper_from_url() { + let paper = Paper::new("https://eprint.iacr.org/2016/260").await.unwrap(); assert!(!paper.title.is_empty()); assert!(!paper.authors.is_empty()); assert_eq!(paper.source, Source::IACR); - Ok(()) } #[traced_test] #[tokio::test] - async fn test_doi_paper_from_id() -> anyhow::Result<()> { - let paper = Paper::new("10.1145/1327452.1327492").await?; + async fn test_doi_paper_from_id() { + let paper = Paper::new("10.1145/1327452.1327492").await.unwrap(); assert!(!paper.title.is_empty()); assert!(!paper.authors.is_empty()); assert_eq!(paper.source, Source::DOI); - Ok(()) } #[traced_test] #[tokio::test] - async fn test_doi_paper_from_url() -> anyhow::Result<()> { - let paper = Paper::new("https://doi.org/10.1145/1327452.1327492").await?; + async fn test_doi_paper_from_url() { + let paper = Paper::new("https://doi.org/10.1145/1327452.1327492").await.unwrap(); assert!(!paper.title.is_empty()); assert!(!paper.authors.is_empty()); assert_eq!(paper.source, Source::DOI); - Ok(()) } #[traced_test] #[tokio::test] - async fn test_arxiv_pdf_from_paper() -> anyhow::Result<()> { + async fn test_arxiv_pdf_from_paper() { let paper = Paper::new("https://arxiv.org/abs/2301.07041").await.unwrap(); let dir = tempdir().unwrap(); - paper.download_pdf(dir.path().to_path_buf()).await.unwrap(); + paper.download_pdf(dir.path()).await.unwrap(); let formatted_title = format::format_title("Verifiable Fully Homomorphic 
Encryption", Some(50)); let path = dir.into_path().join(format!("{}.pdf", formatted_title)); assert!(path.exists()); - Ok(()) } #[traced_test] #[tokio::test] - async fn test_iacr_pdf_from_paper() -> anyhow::Result<()> { + async fn test_iacr_pdf_from_paper() { let paper = Paper::new("https://eprint.iacr.org/2016/260").await.unwrap(); let dir = tempdir().unwrap(); - paper.download_pdf(dir.path().to_path_buf()).await.unwrap(); + paper.download_pdf(dir.path()).await.unwrap(); let formatted_title = format::format_title("On the Size of Pairing-based Non-interactive Arguments", Some(50)); let path = dir.into_path().join(format!("{}.pdf", formatted_title)); assert!(path.exists()); - Ok(()) } // TODO (autoparallel): This technically passes, but it is not actually getting a PDF from this @@ -387,21 +477,24 @@ mod tests { #[ignore] #[traced_test] #[tokio::test] - async fn test_doi_pdf_from_paper() -> anyhow::Result<()> { + async fn test_doi_pdf_from_paper() { let paper = Paper::new("https://doi.org/10.1145/1327452.1327492").await.unwrap(); dbg!(&paper); let dir = tempdir().unwrap(); - paper.download_pdf(dir.path().to_path_buf()).await.unwrap(); - let formatted_title = - format::format_title("MapReduce: simplified data processing on large clusters", Some(50)); - let path = dir.into_path().join(format!("{}.pdf", formatted_title)); + paper.download_pdf(dir.path()).await.unwrap(); + let path = dir.into_path().join(paper.filename()); assert!(path.exists()); - Ok(()) + } + + #[traced_test] + #[tokio::test] + async fn test_broken_api_link() { + assert!(Paper::new("https://arxiv.org/abs/2401.00000").await.is_err()); } // TODO (autoparallel): Convenient entrypoint to try seeing if the PDF comes out correct. 
What I // have tried now is using a `reqwest` client with ``` - // let _ = client.get("https://dl.acm.org/").send().await?; + // let _ = client.get("https://dl.acm.org/").send().await.unwrap(); // // let response = client // .get(pdf_url) diff --git a/crates/learner/tests/database/find.rs b/crates/learner/tests/database/find.rs deleted file mode 100644 index 4189b8b..0000000 --- a/crates/learner/tests/database/find.rs +++ /dev/null @@ -1,96 +0,0 @@ -use super::*; - -#[traced_test] -#[tokio::test] -async fn test_full_text_search() { - let (db, _dir) = setup_test_db().await; - - // Save a few papers - let mut paper1 = create_test_paper(); - paper1.title = "Neural Networks in Machine Learning".to_string(); - paper1.abstract_text = "This paper discusses deep learning".to_string(); - paper1.source_identifier = "2401.00001".to_string(); - - let mut paper2 = create_test_paper(); - paper2.title = "Advanced Algorithms".to_string(); - paper2.abstract_text = "Classical computer science topics".to_string(); - paper2.source_identifier = "2401.00002".to_string(); - - db.save_paper(&paper1).await.unwrap(); - db.save_paper(&paper2).await.unwrap(); - - // Search for papers - let results = db.search_papers("neural").await.unwrap(); - assert_eq!(results.len(), 1); - assert_eq!(results[0].title, paper1.title); - - let results = db.search_papers("learning").await.unwrap(); - assert_eq!(results.len(), 1); - assert_eq!(results[0].source_identifier, paper1.source_identifier); - - let results = db.search_papers("algorithms").await.unwrap(); - assert_eq!(results.len(), 1); - assert_eq!(results[0].title, paper2.title); -} - -#[traced_test] -#[tokio::test] -async fn test_pdf_status_nonexistent() { - let (db, _dir) = setup_test_db().await; - let paper = create_test_paper(); - - // Save paper first to get an ID - let paper_id = db.save_paper(&paper).await.unwrap(); - - // Test getting status for paper with no PDF record - let status = db.get_pdf_status(paper_id).await.unwrap(); - 
assert_eq!(status, None); -} - -#[traced_test] -#[tokio::test] -async fn test_pdf_status_update() { - let (db, _dir) = setup_test_db().await; - let paper = create_test_paper(); - - // Save paper first to get an ID - let paper_id = db.save_paper(&paper).await.unwrap(); - - let path = PathBuf::from("/test/path/paper.pdf"); - let filename = "paper.pdf".to_string(); - - // First record as pending - db.record_pdf(paper_id, path.clone(), filename.clone(), "pending", None).await.unwrap(); - - // Then update to success - db.record_pdf(paper_id, path.clone(), filename.clone(), "success", None).await.unwrap(); - - // Verify final status - let status = db.get_pdf_status(paper_id).await.unwrap(); - let (_, _, stored_status, _) = status.unwrap(); - assert_eq!(stored_status, "success"); -} - -#[traced_test] -#[tokio::test] -async fn test_pdf_list_papers() -> Result<(), Box> { - let (db, _dir) = setup_test_db().await; - - // Save a few papers - let mut paper1 = create_test_paper(); - paper1.title = "Neural Networks in Machine Learning".to_string(); - paper1.abstract_text = "This paper discusses deep learning".to_string(); - paper1.source_identifier = "2401.00001".to_string(); - - let mut paper2 = create_test_paper(); - paper2.title = "Advanced Algorithms".to_string(); - paper2.abstract_text = "Classical computer science topics".to_string(); - paper2.source_identifier = "2401.00002".to_string(); - - db.save_paper(&paper1).await.unwrap(); - db.save_paper(&paper2).await.unwrap(); - - let papers = db.list_papers("title", true).await?; - assert_eq!(papers.len(), 2); - Ok(()) -} diff --git a/crates/learner/tests/database/instruction/add.rs b/crates/learner/tests/database/instruction/add.rs new file mode 100644 index 0000000..44addc2 --- /dev/null +++ b/crates/learner/tests/database/instruction/add.rs @@ -0,0 +1,213 @@ +use super::*; + +/// Basic paper addition tests +mod basic_operations { + + use super::*; + + #[traced_test] + #[tokio::test] + async fn test_add_paper() -> 
TestResult<()> { + let (mut db, _dir) = setup_test_db().await; + let paper = create_test_paper(); + + let papers = Add::paper(&paper).execute(&mut db).await?; + assert_eq!(papers.len(), 1); + assert_eq!(papers[0].title, paper.title); + + // Verify paper exists in database + let stored = Query::by_source(paper.source, &paper.source_identifier).execute(&mut db).await?; + assert_eq!(stored.len(), 1); + assert_eq!(stored[0].title, paper.title); + + Ok(()) + } + + #[traced_test] + #[tokio::test] + async fn test_add_paper_twice() -> TestResult<()> { + let (mut db, _dir) = setup_test_db().await; + let paper = create_test_paper(); + + Add::paper(&paper).execute(&mut db).await?; + let err = Add::paper(&paper).execute(&mut db).await.unwrap_err(); + + assert!(matches!(err, LearnerError::DatabaseDuplicatePaper(_))); + + // Verify only one copy exists + let stored = Query::list_all().execute(&mut db).await?; + assert_eq!(stored.len(), 1); + + Ok(()) + } + + #[traced_test] + #[tokio::test] + async fn test_add_paper_with_authors() -> TestResult<()> { + let (mut db, _dir) = setup_test_db().await; + let mut paper = create_test_paper(); + paper.authors = vec![ + Author { + name: "Test Author 1".into(), + affiliation: Some("University 1".into()), + email: Some("email1@test.com".into()), + }, + Author { name: "Test Author 2".into(), affiliation: None, email: None }, + ]; + + Add::paper(&paper).execute(&mut db).await?; + + // Verify authors were stored + let stored = Query::by_author("Test Author 1").execute(&mut db).await?; + assert_eq!(stored.len(), 1); + assert_eq!(stored[0].authors.len(), 2); + assert_eq!(stored[0].authors[0].affiliation, Some("University 1".into())); + assert_eq!(stored[0].authors[1].name, "Test Author 2"); + + Ok(()) + } +} + +/// Tests for paper addition with documents +mod document_operations { + use learner::paper::Paper; + + use super::*; + + #[traced_test] + #[tokio::test] + async fn test_add_complete_paper() -> TestResult<()> { + let (mut db, _dir) = 
setup_test_db().await; + let paper = Paper::new("https://arxiv.org/abs/2301.07041").await?; + + let papers = Add::complete(&paper).execute(&mut db).await?; + assert_eq!(papers.len(), 1); + + // Verify both paper and document were added + let stored = Query::by_source(paper.source, &paper.source_identifier).execute(&mut db).await?; + assert_eq!(stored.len(), 1); + + // Verify PDF exists in storage location + let storage_path = db.get_storage_path().await?; + let pdf_path = storage_path.join(paper.filename()); + assert!(pdf_path.exists(), "PDF file should exist at {:?}", pdf_path); + + Ok(()) + } + + #[traced_test] + #[tokio::test] + async fn test_add_paper_then_document() -> TestResult<()> { + let (mut db, _dir) = setup_test_db().await; + let paper = Paper::new("https://arxiv.org/abs/2301.07041").await?; + + // First add paper only + Add::paper(&paper).execute(&mut db).await?; + + // Then add with document + let papers = Add::complete(&paper).execute(&mut db).await?; + assert_eq!(papers.len(), 1); + + // Verify PDF exists + let storage_path = db.get_storage_path().await?; + let pdf_path = storage_path.join(paper.filename()); + assert!(pdf_path.exists()); + + assert!(logs_contain( + "WARN test_add_paper_then_document: learner::database::instruction::add: Tried to add \ + complete paper when paper existed in database already, attempting to add only the document!" 
+ )); + Ok(()) + } + + #[traced_test] + #[tokio::test] + async fn test_chain_document_addition() -> TestResult<()> { + let (mut db, _dir) = setup_test_db().await; + let paper = Paper::new("https://arxiv.org/abs/2301.07041").await?; + + let papers = Add::paper(&paper).with_document().execute(&mut db).await?; + assert_eq!(papers.len(), 1); + + // Verify PDF exists + let storage_path = db.get_storage_path().await?; + let pdf_path = storage_path.join(paper.filename()); + assert!(pdf_path.exists()); + + Ok(()) + } + + #[traced_test] + #[tokio::test] + async fn test_add_documents_by_query() -> TestResult<()> { + let (mut db, _dir) = setup_test_db().await; + + // Add multiple papers without documents + let paper1 = Paper::new("https://arxiv.org/abs/2301.07041").await?; + let paper2 = Paper::new("https://eprint.iacr.org/2016/260").await?; + Add::paper(&paper1).execute(&mut db).await?; + Add::paper(&paper2).execute(&mut db).await?; + + // Add documents for all papers + let papers = Add::documents(Query::list_all()).execute(&mut db).await?; + assert_eq!(papers.len(), 2); + + // Verify PDFs exist + let storage_path = db.get_storage_path().await?; + for paper in papers { + let pdf_path = storage_path.join(paper.filename()); + assert!(pdf_path.exists(), "PDF should exist for {}", paper.source_identifier); + } + + Ok(()) + } +} + +/// Edge case tests +mod edge_cases { + use super::*; + + #[traced_test] + #[tokio::test] + async fn test_add_paper_with_special_characters() -> TestResult<()> { + let (mut db, _dir) = setup_test_db().await; + let mut paper = create_test_paper(); + paper.title = "Test & Paper: A Study!".into(); + paper.abstract_text = "Abstract with & and other symbols: @#$%".into(); + + let papers = Add::paper(&paper).execute(&mut db).await?; + assert_eq!(papers.len(), 1); + assert_eq!(papers[0].title, paper.title); + + Ok(()) + } + + #[traced_test] + #[tokio::test] + async fn test_add_empty_author_list() -> TestResult<()> { + let (mut db, _dir) = 
setup_test_db().await; + let mut paper = create_test_paper(); + paper.authors.clear(); + + let papers = Add::paper(&paper).execute(&mut db).await?; + assert_eq!(papers.len(), 1); + assert!(papers[0].authors.is_empty()); + + Ok(()) + } + + #[traced_test] + #[tokio::test] + async fn test_add_paper_with_optional_fields() -> TestResult<()> { + let (mut db, _dir) = setup_test_db().await; + let mut paper = create_test_paper(); + paper.doi = Some("10.1234/test".into()); + paper.pdf_url = Some("https://example.com/paper.pdf".into()); + + let papers = Add::paper(&paper).execute(&mut db).await?; + assert_eq!(papers[0].doi, Some("10.1234/test".into())); + assert_eq!(papers[0].pdf_url, Some("https://example.com/paper.pdf".into())); + + Ok(()) + } +} diff --git a/crates/learner/tests/database/instruction/mod.rs b/crates/learner/tests/database/instruction/mod.rs new file mode 100644 index 0000000..a0309b7 --- /dev/null +++ b/crates/learner/tests/database/instruction/mod.rs @@ -0,0 +1,69 @@ +use chrono::{TimeZone, Utc}; +use learner::{ + database::{Add, Database, OrderField, Query, Remove}, + paper::Source, +}; + +use super::*; + +mod add; +mod query; +mod remove; + +async fn setup_test_db() -> (Database, TempDir) { + let dir = tempdir().unwrap(); + let db_path = dir.path().join("test.db"); + let db = Database::open(&db_path).await.unwrap(); + db.set_storage_path(dir.path()).await.unwrap(); + (db, dir) +} + +/// Helper function to create a test paper +fn create_test_paper() -> Paper { + Paper { + title: "Test Paper".to_string(), + abstract_text: "This is a test abstract".to_string(), + publication_date: chrono::TimeZone::with_ymd_and_hms(&chrono::Utc, 2023, 1, 1, 0, 0, 0) + .unwrap(), + source: Source::Arxiv, + source_identifier: "2301.00000".to_string(), + pdf_url: Some("https://arxiv.org/pdf/2301.00000".to_string()), + doi: Some("10.0000/test.123".to_string()), + authors: vec![ + Author { + name: "John Doe".to_string(), + affiliation: Some("Test University".to_string()), + 
email: Some("john@test.edu".to_string()), + }, + Author { name: "Jane Smith".to_string(), affiliation: None, email: None }, + ], + } +} + +fn create_second_test_paper() -> Paper { + Paper { + title: "Test Paper: Two".to_string(), + abstract_text: "This is a test abstract, but again!".to_string(), + publication_date: chrono::TimeZone::with_ymd_and_hms(&chrono::Utc, 2024, 1, 1, 0, 0, 0) + .unwrap(), + source: Source::Arxiv, + source_identifier: "2401.00000".to_string(), + pdf_url: Some("https://arxiv.org/pdf/2401.00000".to_string()), + doi: Some("10.1000/test.1234".to_string()), + authors: vec![ + Author { + name: "Alice Scientist".to_string(), + affiliation: Some("Test State University".to_string()), + email: Some("john@test.edu".to_string()), + }, + Author { name: "Bob Researcher".to_string(), affiliation: None, email: None }, + ], + } +} + +#[tokio::test] +#[traced_test] +async fn test_download_test_paper_is_404() { + let paper = create_test_paper(); + assert!(paper.download_pdf(&PathBuf::from_str(".").unwrap()).await.is_err()); +} diff --git a/crates/learner/tests/database/instruction/query.rs b/crates/learner/tests/database/instruction/query.rs new file mode 100644 index 0000000..75920cd --- /dev/null +++ b/crates/learner/tests/database/instruction/query.rs @@ -0,0 +1,402 @@ +use chrono::Datelike; + +use super::*; + +/// Basic paper search functionality +mod paper_search { + use super::*; + + #[tokio::test] + #[traced_test] + async fn test_basic_paper_search() -> TestResult<()> { + let (mut db, _dir) = setup_test_db().await; + let paper = create_test_paper(); + Add::paper(&paper).execute(&mut db).await?; + + let results = Query::by_paper(&paper).execute(&mut db).await?; + assert_eq!(results.len(), 1); + assert_eq!(results[0].title, "Test Paper"); + Ok(()) + } +} + +/// Basic text search functionality +mod text_search { + use super::*; + + #[tokio::test] + #[traced_test] + async fn test_basic_text_search() -> TestResult<()> { + let (mut db, _dir) = 
setup_test_db().await; + let paper = create_test_paper(); + Add::paper(&paper).execute(&mut db).await?; + + let results = Query::text("test paper").execute(&mut db).await?; + assert_eq!(results.len(), 1); + assert_eq!(results[0].title, "Test Paper"); + Ok(()) + } + + #[traced_test] + #[tokio::test] + async fn test_case_insensitive_search() -> TestResult<()> { + let (mut db, _dir) = setup_test_db().await; + let paper = create_test_paper(); + Add::paper(&paper).execute(&mut db).await?; + + let results = Query::text("TEST PAPER").execute(&mut db).await?; + assert_eq!(results.len(), 1); + assert_eq!(results[0].title, "Test Paper"); + + Ok(()) + } + + #[traced_test] + #[tokio::test] + async fn test_word_boundaries() -> TestResult<()> { + let (mut db, _dir) = setup_test_db().await; + + let mut paper = create_test_paper(); + paper.title = "Testing Paper".to_string(); + Add::paper(&paper).execute(&mut db).await?; + + let results = Query::text("test").execute(&mut db).await?; + assert_eq!(results.len(), 1); + + let results = Query::text("testing").execute(&mut db).await?; + assert_eq!(results.len(), 1); + + let results = Query::text("est").execute(&mut db).await?; + assert_eq!(results.len(), 0, "Partial word match should not work"); + + Ok(()) + } + + #[traced_test] + #[tokio::test] + async fn test_abstract_search() -> TestResult<()> { + let (mut db, _dir) = setup_test_db().await; + let mut paper = create_test_paper(); + paper.abstract_text = "This is a unique phrase in the abstract".to_string(); + Add::paper(&paper).execute(&mut db).await?; + + // Search should only match title by default since that's what we indexed + let results = Query::text("unique phrase").execute(&mut db).await?; + assert_eq!(results.len(), 0); + + // Search for title instead + let results = Query::text("test paper").execute(&mut db).await?; + assert_eq!(results.len(), 1); + + Ok(()) + } + + #[traced_test] + #[tokio::test] + async fn test_multiple_term_search() -> TestResult<()> { + let (mut db, 
_dir) = setup_test_db().await; + let mut paper = create_test_paper(); + paper.title = "Machine Learning Research".to_string(); + paper.abstract_text = "A study about neural networks".to_string(); + Add::paper(&paper).execute(&mut db).await?; + + // Each term should be searched independently in title + let results = Query::text("machine").execute(&mut db).await?; + assert_eq!(results.len(), 1, "Should match single term in title"); + + let results = Query::text("learning research").execute(&mut db).await?; + assert_eq!(results.len(), 1, "Should match multiple terms in title"); + + // Abstract text isn't searched + let results = Query::text("neural").execute(&mut db).await?; + assert_eq!(results.len(), 0, "Should not match terms in abstract"); + + Ok(()) + } +} + +/// Author search functionality +mod author_search { + use super::*; + + #[tokio::test] + async fn test_exact_author_match() -> TestResult<()> { + let (mut db, _dir) = setup_test_db().await; + let paper = create_test_paper(); + Add::paper(&paper).execute(&mut db).await?; + + let results = Query::by_author("John Doe").execute(&mut db).await?; + assert_eq!(results.len(), 1); + Ok(()) + } + + #[traced_test] + #[tokio::test] + async fn test_partial_author_name() -> TestResult<()> { + let (mut db, _dir) = setup_test_db().await; + let mut paper = create_test_paper(); + paper.authors = + vec![Author { name: "John Smith".to_string(), affiliation: None, email: None }, Author { + name: "Jane Smith".to_string(), + affiliation: None, + email: None, + }]; + Add::paper(&paper).execute(&mut db).await?; + + let results = Query::by_author("Smith").execute(&mut db).await?; + assert_eq!(results.len(), 1); + assert_eq!(results[0].authors.len(), 2); + + let results = Query::by_author("SMITH").execute(&mut db).await?; + assert_eq!(results.len(), 1, "Author search should be case insensitive"); + + Ok(()) + } + + #[traced_test] + #[tokio::test] + async fn test_multiple_papers_same_author() -> TestResult<()> { + let (mut db, _dir) = 
setup_test_db().await; + + let mut paper1 = create_test_paper(); + let mut paper2 = create_second_test_paper(); + + // Give both papers the same author + let author = + Author { name: "Shared Author".to_string(), affiliation: None, email: None }; + paper1.authors = vec![author.clone()]; + paper2.authors = vec![author]; + + Add::paper(&paper1).execute(&mut db).await?; + Add::paper(&paper2).execute(&mut db).await?; + + let results = Query::by_author("Shared Author").execute(&mut db).await?; + assert_eq!(results.len(), 2); + + Ok(()) + } + + #[traced_test] + #[tokio::test] + async fn test_author_with_affiliation() -> TestResult<()> { + let (mut db, _dir) = setup_test_db().await; + + let mut paper = create_test_paper(); + paper.authors = vec![Author { + name: "John Doe".to_string(), + affiliation: Some("Test University".to_string()), + email: Some("john@test.edu".to_string()), + }]; + + Add::paper(&paper).execute(&mut db).await?; + + let results = Query::by_author("John Doe").execute(&mut db).await?; + assert_eq!(results[0].authors[0].affiliation, Some("Test University".to_string())); + assert_eq!(results[0].authors[0].email, Some("john@test.edu".to_string())); + + Ok(()) + } +} + +/// Source-based search functionality +mod source_search { + use super::*; + + #[traced_test] + #[tokio::test] + async fn test_basic_source_search() -> TestResult<()> { + let (mut db, _dir) = setup_test_db().await; + let paper = create_test_paper(); + Add::paper(&paper).execute(&mut db).await?; + + let results = Query::by_source(Source::Arxiv, "2301.00000").execute(&mut db).await?; + assert_eq!(results.len(), 1); + assert_eq!(results[0].source_identifier, "2301.00000"); + Ok(()) + } + + #[traced_test] + #[tokio::test] + async fn test_multiple_sources() -> TestResult<()> { + let (mut db, _dir) = setup_test_db().await; + + let mut paper1 = create_test_paper(); + let mut paper2 = create_second_test_paper(); + paper1.source = Source::Arxiv; + paper2.source = Source::DOI; + + 
Add::paper(&paper1).execute(&mut db).await?; + Add::paper(&paper2).execute(&mut db).await?; + + let results = Query::list_all().order_by(OrderField::Source).execute(&mut db).await?; + assert_eq!(results.len(), 2); + assert!(results.iter().any(|p| p.source == Source::Arxiv)); + assert!(results.iter().any(|p| p.source == Source::DOI)); + + Ok(()) + } + + #[traced_test] + #[tokio::test] + async fn test_source_with_doi() -> TestResult<()> { + let (mut db, _dir) = setup_test_db().await; + + let mut paper = create_test_paper(); + paper.source = Source::DOI; + paper.source_identifier = "10.1234/test".to_string(); + paper.doi = Some(paper.source_identifier.clone()); + + Add::paper(&paper).execute(&mut db).await?; + + let results = Query::by_source(Source::DOI, "10.1234/test").execute(&mut db).await?; + assert_eq!(results.len(), 1); + assert_eq!(results[0].doi, Some("10.1234/test".to_string())); + + Ok(()) + } +} + +/// Ordering and pagination tests +mod ordering { + use super::*; + + #[traced_test] + #[tokio::test] + async fn test_date_ordering() -> TestResult<()> { + let (mut db, _dir) = setup_test_db().await; + + let paper1 = create_test_paper(); // 2023 + let paper2 = create_second_test_paper(); // 2024 + + Add::paper(&paper1).execute(&mut db).await?; + Add::paper(&paper2).execute(&mut db).await?; + + let results = Query::list_all().order_by(OrderField::PublicationDate).execute(&mut db).await?; + assert_eq!(results[0].publication_date.year(), 2023); + assert_eq!(results[1].publication_date.year(), 2024); + + let results = + Query::list_all().order_by(OrderField::PublicationDate).descending().execute(&mut db).await?; + assert_eq!(results[0].publication_date.year(), 2024); + assert_eq!(results[1].publication_date.year(), 2023); + + Ok(()) + } + + #[traced_test] + #[tokio::test] + async fn test_title_ordering() -> TestResult<()> { + let (mut db, _dir) = setup_test_db().await; + + let mut paper1 = create_test_paper(); + let mut paper2 = create_second_test_paper(); + 
paper1.title = "Beta Paper".to_string(); + paper2.title = "Alpha Paper".to_string(); + + Add::paper(&paper1).execute(&mut db).await?; + Add::paper(&paper2).execute(&mut db).await?; + + let results = Query::list_all().order_by(OrderField::Title).execute(&mut db).await?; + assert_eq!(results[0].title, "Alpha Paper"); + assert_eq!(results[1].title, "Beta Paper"); + + Ok(()) + } +} + +/// Edge cases and special conditions +mod edge_cases { + use super::*; + + #[traced_test] + #[tokio::test] + async fn test_empty_database() -> TestResult<()> { + let (mut db, _dir) = setup_test_db().await; + + let results = Query::list_all().execute(&mut db).await?; + assert_eq!(results.len(), 0); + + let results = Query::text("any text").execute(&mut db).await?; + assert_eq!(results.len(), 0); + + let results = Query::by_author("any author").execute(&mut db).await?; + assert_eq!(results.len(), 0); + + Ok(()) + } + + #[traced_test] + #[tokio::test] + async fn test_special_characters() -> TestResult<()> { + let (mut db, _dir) = setup_test_db().await; + + let mut paper = create_test_paper(); + // Use simpler special characters that FTS5 can handle + paper.title = "Test Paper: A Study".to_string(); + paper.authors = vec![Author { + name: "O'Connor Smith".to_string(), + affiliation: None, + email: None, + }]; + + Add::paper(&paper).execute(&mut db).await?; + + // Search with and without special characters + let results = Query::text("Test Paper").execute(&mut db).await?; + assert_eq!(results.len(), 1); + + let results = Query::text("Test").execute(&mut db).await?; + assert_eq!(results.len(), 1); + + // Author search should still work with apostrophe + let results = Query::by_author("O'Connor").execute(&mut db).await?; + assert_eq!(results.len(), 1); + + Ok(()) + } + + #[traced_test] + #[tokio::test] + async fn test_very_long_text() -> TestResult<()> { + let (mut db, _dir) = setup_test_db().await; + + let mut paper = create_test_paper(); + paper.title = "A ".repeat(500) + "unique marker"; + + 
Add::paper(&paper).execute(&mut db).await?; + + let results = Query::text("unique marker").execute(&mut db).await?; + assert_eq!(results.len(), 1); + + Ok(()) + } +} + +#[traced_test] +#[tokio::test] +async fn test_fts_behavior() -> TestResult<()> { + let (mut db, _dir) = setup_test_db().await; + + let mut paper = create_test_paper(); + paper.title = "Testing: Advanced Search & Queries".to_string(); + paper.abstract_text = "This is a complex abstract with many terms".to_string(); + Add::paper(&paper).execute(&mut db).await?; + + // Basic word search works + let results = Query::text("Testing").execute(&mut db).await?; + assert_eq!(results.len(), 1); + + // Words are tokenized properly + let results = Query::text("Advanced Search").execute(&mut db).await?; + assert_eq!(results.len(), 1); + + // Special characters are treated as word boundaries + let results = Query::text("Queries").execute(&mut db).await?; + assert_eq!(results.len(), 1); + + // Only title is searchable + let results = Query::text("complex abstract").execute(&mut db).await?; + assert_eq!(results.len(), 0); + + Ok(()) +} diff --git a/crates/learner/tests/database/instruction/remove.rs b/crates/learner/tests/database/instruction/remove.rs new file mode 100644 index 0000000..1a5ae45 --- /dev/null +++ b/crates/learner/tests/database/instruction/remove.rs @@ -0,0 +1,407 @@ +use super::*; + +/// Basic removal functionality tests +mod basic_operations { + + use super::*; + + #[tokio::test] + #[traced_test] + async fn test_remove_existing_paper() -> TestResult<()> { + let (mut db, _dir) = setup_test_db().await; + + let paper = create_test_paper(); + Add::paper(&paper).execute(&mut db).await?; + + let removed_papers = + Remove::by_source(paper.source, &paper.source_identifier).execute(&mut db).await?; + + assert_eq!(removed_papers.len(), 1); + assert_eq!(removed_papers[0].title, paper.title); + assert_eq!(removed_papers[0].authors.len(), paper.authors.len()); + + let results = Query::by_source(paper.source, 
&paper.source_identifier).execute(&mut db).await?; + assert_eq!(results.len(), 0); + + Ok(()) + } + + #[tokio::test] + #[traced_test] + async fn test_remove_nonexistent_paper() -> TestResult<()> { + let (mut db, _dir) = setup_test_db().await; + + let removed = Remove::by_source(Source::Arxiv, "nonexistent").execute(&mut db).await?; + assert!(removed.is_empty()); + + Ok(()) + } + + #[tokio::test] + #[traced_test] + async fn test_remove_cascades_to_authors() -> TestResult<()> { + let (mut db, _dir) = setup_test_db().await; + + let paper = create_test_paper(); + Add::paper(&paper).execute(&mut db).await?; + + Remove::from_query(Query::text("test")).execute(&mut db).await?; + let authors = Query::by_author("").execute(&mut db).await?; + + assert_eq!(authors.len(), 0); + Ok(()) + } + + #[tokio::test] + #[traced_test] + async fn test_remove_complete_paper() -> TestResult<()> { + let (mut db, _dir) = setup_test_db().await; + + let paper = Paper::new("https://arxiv.org/abs/2301.07041").await?; + // Add paper with document + Add::complete(&paper).execute(&mut db).await?; + + // Remove it + Remove::by_source(paper.source, &paper.source_identifier).execute(&mut db).await?; + + // Verify paper is gone + let results = Query::by_source(paper.source, &paper.source_identifier).execute(&mut db).await?; + assert_eq!(results.len(), 0); + + Ok(()) + } +} + +/// Dry run functionality tests +mod dry_run { + use super::*; + + #[tokio::test] + #[traced_test] + async fn test_dry_run_basic() -> TestResult<()> { + let (mut db, _dir) = setup_test_db().await; + + let paper = create_test_paper(); + Add::paper(&paper).execute(&mut db).await?; + + let would_remove = + Remove::by_source(paper.source, &paper.source_identifier).dry_run().execute(&mut db).await?; + + assert_eq!(would_remove.len(), 1); + assert_eq!(would_remove[0].title, paper.title); + + let results = Query::by_source(paper.source, &paper.source_identifier).execute(&mut db).await?; + assert_eq!(results.len(), 1); + + Ok(()) + } + + 
#[tokio::test] + #[traced_test] + async fn test_dry_run_returns_complete_paper() -> TestResult<()> { + let (mut db, _dir) = setup_test_db().await; + + let paper = create_test_paper(); + Add::paper(&paper).execute(&mut db).await?; + + let would_remove = Remove::from_query(Query::text("test")).dry_run().execute(&mut db).await?; + + assert_eq!(would_remove.len(), 1); + let removed = &would_remove[0]; + + // Verify all fields + assert_eq!(removed.title, paper.title); + assert_eq!(removed.abstract_text, paper.abstract_text); + assert_eq!(removed.publication_date, paper.publication_date); + assert_eq!(removed.source, paper.source); + assert_eq!(removed.source_identifier, paper.source_identifier); + assert_eq!(removed.pdf_url, paper.pdf_url); + assert_eq!(removed.doi, paper.doi); + assert_eq!(removed.authors.len(), paper.authors.len()); + + for (removed_author, original_author) in removed.authors.iter().zip(paper.authors.iter()) { + assert_eq!(removed_author.name, original_author.name); + assert_eq!(removed_author.affiliation, original_author.affiliation); + assert_eq!(removed_author.email, original_author.email); + } + + Ok(()) + } + + #[tokio::test] + #[traced_test] + async fn test_dry_run_with_complete_paper() -> TestResult<()> { + let (mut db, _dir) = setup_test_db().await; + + let paper = Paper::new("https://arxiv.org/abs/2301.07041").await?; + Add::complete(&paper).execute(&mut db).await?; + + let would_remove = + Remove::by_source(paper.source, &paper.source_identifier).dry_run().execute(&mut db).await?; + + // Verify paper would be removed + assert_eq!(would_remove.len(), 1); + + // But verify it's still in the database + let results = Query::by_source(paper.source, &paper.source_identifier).execute(&mut db).await?; + assert_eq!(results.len(), 1); + + Ok(()) + } +} + +/// Query-based removal tests +mod query_based_removal { + + use super::*; + + #[tokio::test] + #[traced_test] + async fn test_remove_by_text_search() -> TestResult<()> { + let (mut db, _dir) = 
setup_test_db().await; + + Add::paper(&create_test_paper()).execute(&mut db).await?; + Add::paper(&create_second_test_paper()).execute(&mut db).await?; + + let removed = Remove::from_query(Query::text("two")).execute(&mut db).await?; + assert_eq!(removed.len(), 1); + assert_eq!(removed[0].title, "Test Paper: Two"); + + Ok(()) + } + + #[tokio::test] + #[traced_test] + async fn test_remove_by_author() -> TestResult<()> { + let (mut db, _dir) = setup_test_db().await; + + Add::paper(&create_test_paper()).execute(&mut db).await?; + Add::paper(&create_second_test_paper()).execute(&mut db).await?; + + let removed = Remove::from_query(Query::by_author("John Doe")).execute(&mut db).await?; + assert_eq!(removed.len(), 1); + assert_eq!(removed[0].authors[0].name, "John Doe"); + + // Verify only the matching paper was removed + let remaining = Query::list_all().execute(&mut db).await?; + assert_eq!(remaining.len(), 1); + + Ok(()) + } + + #[tokio::test] + #[traced_test] + async fn test_remove_with_ordering() -> TestResult<()> { + let (mut db, _dir) = setup_test_db().await; + + Add::paper(&create_test_paper()).execute(&mut db).await?; + Add::paper(&create_second_test_paper()).execute(&mut db).await?; + + let removed = + Remove::from_query(Query::text("test").order_by(OrderField::PublicationDate).descending()) + .execute(&mut db) + .await?; + + assert_eq!(removed.len(), 2); + assert_eq!(removed[0].title, "Test Paper: Two"); // More recent + assert_eq!(removed[1].title, "Test Paper"); + + Ok(()) + } + + #[tokio::test] + #[traced_test] + async fn test_remove_by_date_range() -> TestResult<()> { + let (mut db, _dir) = setup_test_db().await; + + Add::paper(&create_test_paper()).execute(&mut db).await?; + Add::paper(&create_second_test_paper()).execute(&mut db).await?; + + let cutoff_date = Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(); + let removed = + Remove::from_query(Query::before_date(cutoff_date).order_by(OrderField::PublicationDate)) + .execute(&mut db) + .await?; + + 
assert_eq!(removed.len(), 1); + assert_eq!(removed[0].title, "Test Paper"); + + Ok(()) + } + + #[tokio::test] + #[traced_test] + async fn test_remove_multiple_papers_by_source() -> TestResult<()> { + let (mut db, _dir) = setup_test_db().await; + + // Add multiple papers from same source + let paper1 = create_test_paper(); + let paper2 = create_second_test_paper(); + Add::paper(&paper1).execute(&mut db).await?; + Add::paper(&paper2).execute(&mut db).await?; + + // Use a text search that will match all papers from this source + // alternatively we could use Query::list_all() with a source filter + let removed = Remove::from_query(Query::text("test")).execute(&mut db).await?; + assert_eq!(removed.len(), 2); + assert!(removed.iter().all(|p| p.source == Source::Arxiv)); + + // Verify all papers are gone + let remaining = Query::text("test").execute(&mut db).await?; + assert!(remaining.is_empty()); + + Ok(()) + } + + // Alternative version using list_all + #[tokio::test] + #[traced_test] + async fn test_remove_multiple_papers_from_source() -> TestResult<()> { + let (mut db, _dir) = setup_test_db().await; + + // Add papers from different sources + let paper1 = create_test_paper(); + let paper2 = create_second_test_paper(); + Add::paper(&paper1).execute(&mut db).await?; + Add::paper(&paper2).execute(&mut db).await?; + + // Verify we have papers from our source before removal + let initial = Query::list_all().execute(&mut db).await?; + assert!(initial.iter().any(|p| p.source == Source::Arxiv)); + + // Remove all papers using list_all and checking source + let removed = + Remove::from_query(Query::list_all().order_by(OrderField::Source)).execute(&mut db).await?; + + // Count papers from our source + let arxiv_count = removed.iter().filter(|p| p.source == Source::Arxiv).count(); + assert_eq!(arxiv_count, 2); + + // Verify no papers remain from that source + let remaining = Query::list_all().execute(&mut db).await?; + assert!(!remaining.iter().any(|p| p.source == 
Source::Arxiv)); + + Ok(()) + } +} + +/// Recovery and data integrity tests +mod recovery { + use super::*; + + #[tokio::test] + #[traced_test] + async fn test_remove_papers_can_be_readded() -> TestResult<()> { + let (mut db, _dir) = setup_test_db().await; + + let paper = create_test_paper(); + Add::paper(&paper).execute(&mut db).await?; + + let removed_papers = Remove::from_query(Query::text("test")).execute(&mut db).await?; + assert_eq!(removed_papers.len(), 1); + + Add::paper(&removed_papers[0]).execute(&mut db).await?; + + let results = Query::text("test").execute(&mut db).await?; + assert_eq!(results.len(), 1); + + Ok(()) + } + + #[tokio::test] + #[traced_test] + async fn test_bulk_remove_and_readd() -> TestResult<()> { + let (mut db, _dir) = setup_test_db().await; + + Add::paper(&create_test_paper()).execute(&mut db).await?; + Add::paper(&create_second_test_paper()).execute(&mut db).await?; + + let removed = Remove::from_query(Query::text("test")).execute(&mut db).await?; + assert_eq!(removed.len(), 2); + + for paper in &removed { + Add::paper(paper).execute(&mut db).await?; + } + + let results = Query::text("test").execute(&mut db).await?; + assert_eq!(results.len(), 2); + + // Verify order is preserved + assert_eq!(results[0].title, removed[0].title); + assert_eq!(results[1].title, removed[1].title); + + Ok(()) + } + + #[tokio::test] + #[traced_test] + async fn test_readd_with_different_completion() -> TestResult<()> { + let (mut db, _dir) = setup_test_db().await; + + // Add paper without document + let paper = Paper::new("https://arxiv.org/abs/2301.07041").await?; + Add::paper(&paper).execute(&mut db).await?; + + // Remove it + let removed = + Remove::by_source(paper.source, &paper.source_identifier).execute(&mut db).await?; + + // Readd with document + Add::complete(&removed[0]).execute(&mut db).await?; + + // Verify paper exists with updated data + let results = Query::by_source(paper.source, &paper.source_identifier).execute(&mut db).await?; + 
assert_eq!(results.len(), 1); + + Ok(()) + } + + #[tokio::test] + #[traced_test] + async fn test_remove_and_readd_preserves_metadata() -> TestResult<()> { + let (mut db, _dir) = setup_test_db().await; + + let mut paper = create_test_paper(); + // Add some optional fields + paper.doi = Some("10.1234/test".to_string()); + paper.pdf_url = Some("https://example.com/test.pdf".to_string()); + + Add::paper(&paper).execute(&mut db).await?; + + let removed = + Remove::by_source(paper.source, &paper.source_identifier).execute(&mut db).await?; + Add::paper(&removed[0]).execute(&mut db).await?; + + let results = Query::by_source(paper.source, &paper.source_identifier).execute(&mut db).await?; + assert_eq!(results[0].doi, paper.doi); + assert_eq!(results[0].pdf_url, paper.pdf_url); + + Ok(()) + } + + #[tokio::test] + #[traced_test] + async fn test_remove_readd_with_updated_data() -> TestResult<()> { + let (mut db, _dir) = setup_test_db().await; + + let paper = create_test_paper(); + Add::paper(&paper).execute(&mut db).await?; + + let mut removed = + Remove::by_source(paper.source, &paper.source_identifier).execute(&mut db).await?; + + // Modify the removed paper + let mut updated_paper = removed.remove(0); + updated_paper.abstract_text = "Updated abstract".to_string(); + updated_paper.doi = Some("10.1234/new".to_string()); + + // Readd with changes + Add::paper(&updated_paper).execute(&mut db).await?; + + let results = Query::by_source(paper.source, &paper.source_identifier).execute(&mut db).await?; + assert_eq!(results[0].abstract_text, "Updated abstract"); + assert_eq!(results[0].doi, Some("10.1234/new".to_string())); + + Ok(()) + } +} diff --git a/crates/learner/tests/database/mod.rs b/crates/learner/tests/database/mod.rs index 1c97d23..5f4b14e 100644 --- a/crates/learner/tests/database/mod.rs +++ b/crates/learner/tests/database/mod.rs @@ -1,4 +1,3 @@ use super::*; -mod find; -mod save; +mod instruction; diff --git a/crates/learner/tests/database/save.rs 
b/crates/learner/tests/database/save.rs deleted file mode 100644 index ebe2776..0000000 --- a/crates/learner/tests/database/save.rs +++ /dev/null @@ -1,110 +0,0 @@ -use super::*; - -#[traced_test] -#[tokio::test] -async fn test_save_and_retrieve_paper() { - let (db, _dir) = setup_test_db().await; - let paper = create_test_paper(); - - // Save paper - let paper_id = db.save_paper(&paper).await.unwrap(); - assert!(paper_id > 0); - - // Retrieve paper - let retrieved = db - .get_paper_by_source_id(&paper.source, &paper.source_identifier) - .await - .unwrap() - .expect("Paper should exist"); - - // Verify paper data - assert_eq!(retrieved.title, paper.title); - assert_eq!(retrieved.abstract_text, paper.abstract_text); - assert_eq!(retrieved.publication_date, paper.publication_date); - assert_eq!(retrieved.source, paper.source); - assert_eq!(retrieved.source_identifier, paper.source_identifier); - assert_eq!(retrieved.pdf_url, paper.pdf_url); - assert_eq!(retrieved.doi, paper.doi); - - // Verify authors - assert_eq!(retrieved.authors.len(), paper.authors.len()); - assert_eq!(retrieved.authors[0].name, paper.authors[0].name); - assert_eq!(retrieved.authors[0].affiliation, paper.authors[0].affiliation); - assert_eq!(retrieved.authors[0].email, paper.authors[0].email); - assert_eq!(retrieved.authors[1].name, paper.authors[1].name); - assert_eq!(retrieved.authors[1].affiliation, None); - assert_eq!(retrieved.authors[1].email, None); -} - -#[traced_test] -#[tokio::test] -async fn test_duplicate_paper_handling() { - let (db, _dir) = setup_test_db().await; - let paper = create_test_paper(); - - // Save paper first time - let result1 = db.save_paper(&paper).await; - assert!(result1.is_ok()); - - // Try to save the same paper again - let result2 = db.save_paper(&paper).await; - assert!(result2.is_err()); // Should fail due to UNIQUE constraint -} - -#[traced_test] -#[tokio::test] -async fn test_pdf_recording() { - let (db, _dir) = setup_test_db().await; - let paper = 
create_test_paper(); - - // Save paper first to get an ID - let paper_id = db.save_paper(&paper).await.unwrap(); - - // Test recording successful PDF download - let path = PathBuf::from("/test/path/paper.pdf"); - let filename = "paper.pdf".to_string(); - - let file_id = - db.record_pdf(paper_id, path.clone(), filename.clone(), "success", None).await.unwrap(); - - assert!(file_id > 0); - - // Test retrieving PDF status - let status = db.get_pdf_status(paper_id).await.unwrap(); - assert!(status.is_some()); - - let (stored_path, stored_filename, stored_status, error) = status.unwrap(); - assert_eq!(stored_path, path); - assert_eq!(stored_filename, filename); - assert_eq!(stored_status, "success"); - assert_eq!(error, None); -} - -#[traced_test] -#[tokio::test] -async fn test_pdf_failure_recording() { - let (db, _dir) = setup_test_db().await; - let paper = create_test_paper(); - - // Save paper first to get an ID - let paper_id = db.save_paper(&paper).await.unwrap(); - - // Test recording failed PDF download - let path = PathBuf::from("/test/path/paper.pdf"); - let filename = "paper.pdf".to_string(); - let error_msg = "HTTP 403: Access Denied".to_string(); - - db.record_pdf(paper_id, path.clone(), filename.clone(), "failed", Some(error_msg.clone())) - .await - .unwrap(); - - // Test retrieving failed status - let status = db.get_pdf_status(paper_id).await.unwrap(); - assert!(status.is_some()); - - let (stored_path, stored_filename, stored_status, error) = status.unwrap(); - assert_eq!(stored_path, path); - assert_eq!(stored_filename, filename); - assert_eq!(stored_status, "failed"); - assert_eq!(error, Some(error_msg)); -} diff --git a/crates/learner/tests/lib.rs b/crates/learner/tests/lib.rs index f617589..319ca5b 100644 --- a/crates/learner/tests/lib.rs +++ b/crates/learner/tests/lib.rs @@ -1,11 +1,12 @@ -use std::{error::Error, path::PathBuf}; +use std::{error::Error, path::PathBuf, str::FromStr}; use learner::{ - database::Database, + error::LearnerError, format, 
llm::{LlamaRequest, Model}, - paper::{Author, Paper, Source}, + paper::{Author, Paper}, pdf::PDFContentBuilder, + prelude::*, }; use tempfile::{tempdir, TempDir}; use tracing_test::traced_test; @@ -13,32 +14,4 @@ use tracing_test::traced_test; mod database; mod llm; -/// Helper function to set up a test database -async fn setup_test_db() -> (Database, TempDir) { - let dir = tempdir().unwrap(); - let db_path = dir.path().join("test.db"); - let db = Database::open(&db_path).await.unwrap(); - (db, dir) -} - -/// Helper function to create a test paper -fn create_test_paper() -> Paper { - Paper { - title: "Test Paper".to_string(), - abstract_text: "This is a test abstract".to_string(), - publication_date: chrono::TimeZone::with_ymd_and_hms(&chrono::Utc, 2024, 1, 1, 0, 0, 0) - .unwrap(), - source: Source::Arxiv, - source_identifier: "2401.00000".to_string(), - pdf_url: Some("https://arxiv.org/pdf/2401.00000".to_string()), - doi: Some("10.1000/test.123".to_string()), - authors: vec![ - Author { - name: "John Doe".to_string(), - affiliation: Some("Test University".to_string()), - email: Some("john@test.edu".to_string()), - }, - Author { name: "Jane Smith".to_string(), affiliation: None, email: None }, - ], - } -} +pub type TestResult = Result>; diff --git a/crates/learner/tests/llm/mod.rs b/crates/learner/tests/llm/mod.rs index acf513a..c14b4a7 100644 --- a/crates/learner/tests/llm/mod.rs +++ b/crates/learner/tests/llm/mod.rs @@ -7,7 +7,7 @@ async fn test_download_then_send_pdf() -> Result<(), Box> { // Download a PDF let dir = tempdir().unwrap(); let paper = Paper::new("https://eprint.iacr.org/2016/260").await.unwrap(); - paper.download_pdf(dir.path().to_path_buf()).await.unwrap(); + paper.download_pdf(dir.path()).await.unwrap(); // Get the content of the PDF let formatted_title = format::format_title(&paper.title, None); // use default 50 diff --git a/crates/learnerd/src/commands/add.rs b/crates/learnerd/src/commands/add.rs index 207e72a..90bb213 100644 --- 
a/crates/learnerd/src/commands/add.rs +++ b/crates/learnerd/src/commands/add.rs @@ -14,7 +14,7 @@ pub async fn add(cli: Cli, identifier: String, no_pdf: bool) -> Result<()> { default_path }); trace!("Using database at: {}", path.display()); - let db = Database::open(&path).await?; + let mut db = Database::open(&path).await?; println!("{} Fetching paper: {}", style(LOOKING_GLASS).cyan(), style(&identifier).yellow()); @@ -29,9 +29,15 @@ pub async fn add(cli: Cli, identifier: String, no_pdf: bool) -> Result<()> { style(paper.authors.iter().map(|a| a.name.as_str()).collect::>().join(", ")).white() ); - match paper.save(&db).await { - Ok(id) => { - println!("\n{} Saved paper with ID: {}", style(SAVE).green(), style(id).yellow()); + // TODO (autoparallel): This flow could be refactored now with the `Add::complete` to make it + // easier. + match Add::paper(&paper).execute(&mut db).await { + Ok(papers) => { + println!( + "\n{} Saved paper with ID: {}", + style(SAVE).green(), + style(papers[0].source_identifier.clone()).yellow() + ); // Handle PDF download for newly added paper if paper.pdf_url.is_some() && !no_pdf { @@ -44,19 +50,10 @@ pub async fn add(cli: Cli, identifier: String, no_pdf: bool) -> Result<()> { if should_download { println!("{} Downloading PDF...", style(LOOKING_GLASS).cyan()); - let pdf_dir = match db.get_config("pdf_dir").await? { - Some(dir) => PathBuf::from(dir), - None => { - println!( - "{} PDF directory not configured. Run {} first", - style(WARNING).yellow(), - style("learnerd init").cyan() - ); - return Ok(()); - }, - }; + let _pdf_dir = db.get_storage_path().await?; - match paper.download_pdf(pdf_dir).await { + // TODO: Don't use this direct download. 
+ match Add::complete(&paper).execute(&mut db).await { Ok(_) => { println!("{} PDF downloaded successfully!", style(SUCCESS).green()); }, @@ -80,69 +77,68 @@ pub async fn add(cli: Cli, identifier: String, no_pdf: bool) -> Result<()> { println!("\n{} No PDF URL available for this paper", style(WARNING).yellow()); } }, - Err(e) if e.is_duplicate_error() => { - println!("\n{} This paper is already in your database", style("ℹ").blue()); - - // Check existing PDF status - if paper.pdf_url.is_some() && !no_pdf { - if let Ok(Some(dir)) = db.get_config("pdf_dir").await { - let pdf_dir = PathBuf::from(dir); - let formatted_title = learner::format::format_title(&paper.title, Some(50)); - let pdf_path = pdf_dir.join(format!("{}.pdf", formatted_title)); - - if pdf_path.exists() { - println!( - " {} PDF exists at: {}", - style("📄").cyan(), - style(pdf_path.display()).yellow() - ); - - let should_redownload = if cli.accept_defaults { - false // Default to not redownloading in automated mode - } else { - dialoguer::Confirm::new() - .with_prompt("Download fresh copy? (This will overwrite the existing file)") - .default(false) - .interact()? 
- }; + Err(e) => + if let LearnerError::DatabaseDuplicatePaper(_) = e { + println!("\n{} This paper is already in your database", style("ℹ").blue()); + + // Check existing PDF status + if paper.pdf_url.is_some() && !no_pdf { + if let Ok(pdf_dir) = db.get_storage_path().await { + let pdf_path = pdf_dir.join(paper.filename()); + if pdf_path.exists() { + println!( + " {} PDF exists at: {}", + style("📄").cyan(), + style(pdf_path.display()).yellow() + ); - if should_redownload { - println!("{} Downloading fresh copy of PDF...", style(LOOKING_GLASS).cyan()); - match paper.download_pdf(pdf_dir).await { - Ok(_) => println!("{} PDF downloaded successfully!", style(SUCCESS).green()), - Err(e) => println!( - "{} Failed to download PDF: {}", - style(WARNING).yellow(), - style(e.to_string()).red() - ), + let should_redownload = if cli.accept_defaults { + false // Default to not redownloading in automated mode + } else { + dialoguer::Confirm::new() + .with_prompt("Download fresh copy? (This will overwrite the existing file)") + .default(false) + .interact()? + }; + + if should_redownload { + println!("{} Downloading fresh copy of PDF...", style(LOOKING_GLASS).cyan()); + match paper.download_pdf(&pdf_dir).await { + Ok(_) => println!("{} PDF downloaded successfully!", style(SUCCESS).green()), + Err(e) => println!( + "{} Failed to download PDF: {}", + style(WARNING).yellow(), + style(e.to_string()).red() + ), + } } - } - } else { - let should_download = if cli.accept_defaults { - true // Default to downloading in automated mode } else { - dialoguer::Confirm::new() - .with_prompt("PDF not found. Download it now?") - .default(true) - .interact()? 
- }; - - if should_download { - println!("{} Downloading PDF...", style(LOOKING_GLASS).cyan()); - match paper.download_pdf(pdf_dir).await { - Ok(_) => println!("{} PDF downloaded successfully!", style(SUCCESS).green()), - Err(e) => println!( - "{} Failed to download PDF: {}", - style(WARNING).yellow(), - style(e.to_string()).red() - ), + let should_download = if cli.accept_defaults { + true // Default to downloading in automated mode + } else { + dialoguer::Confirm::new() + .with_prompt("PDF not found. Download it now?") + .default(true) + .interact()? + }; + + if should_download { + println!("{} Downloading PDF...", style(LOOKING_GLASS).cyan()); + match paper.download_pdf(&pdf_dir).await { + Ok(_) => println!("{} PDF downloaded successfully!", style(SUCCESS).green()), + Err(e) => println!( + "{} Failed to download PDF: {}", + style(WARNING).yellow(), + style(e.to_string()).red() + ), + } } } } + } else { + return Err(LearnerdError::from(e)); } - } - }, - Err(e) => return Err(LearnerdError::Learner(e)), + }, } Ok(()) diff --git a/crates/learnerd/src/commands/download.rs b/crates/learnerd/src/commands/download.rs index 068f05e..4daeb49 100644 --- a/crates/learnerd/src/commands/download.rs +++ b/crates/learnerd/src/commands/download.rs @@ -13,49 +13,25 @@ pub async fn download(cli: Cli, source: Source, identifier: String) -> Result<() ); default_path }); - let db = Database::open(&path).await?; + let mut db = Database::open(&path).await?; - let paper = match db.get_paper_by_source_id(&source, &identifier).await? { - Some(p) => p, - None => { - println!( - "{} Paper not found in database. Add it first with: {} {}", - style(WARNING).yellow(), - style("learnerd add").yellow(), - style(&identifier).cyan() - ); - return Ok(()); - }, - }; + let papers = Query::by_source(source, &identifier).execute(&mut db).await?; + if papers.is_empty() { + println!( + "{} Paper not found in database. 
Add it first with: {} {}", + style(WARNING).yellow(), + style("learnerd add").yellow(), + style(&identifier).cyan() + ); + return Ok(()); + } - if paper.pdf_url.is_none() { + if papers[0].pdf_url.is_none() { println!("{} No PDF URL available for this paper", style(WARNING).yellow()); return Ok(()); }; - let pdf_dir = match db.get_config("pdf_dir").await? { - Some(dir) => PathBuf::from(dir), - None => { - println!( - "{} PDF directory not configured. Run {} first", - style(WARNING).yellow(), - style("learnerd init").cyan() - ); - return Ok(()); - }, - }; - - if !pdf_dir.exists() { - println!( - "{} Creating PDF directory: {}", - style(LOOKING_GLASS).cyan(), - style(&pdf_dir.display()).yellow() - ); - std::fs::create_dir_all(&pdf_dir)?; - } - - let formatted_title = learner::format::format_title(&paper.title, Some(50)); - let pdf_path = pdf_dir.join(format!("{}.pdf", formatted_title)); + let pdf_path = db.get_storage_path().await?.join(papers[0].filename()); let should_download = if pdf_path.exists() && !cli.accept_defaults { println!( @@ -79,7 +55,7 @@ pub async fn download(cli: Cli, source: Source, identifier: String) -> Result<() println!("{} Downloading PDF...", style(LOOKING_GLASS).cyan()); } - match paper.download_pdf(pdf_dir.clone()).await { + match Add::complete(&papers[0]).execute(&mut db).await { Ok(_) => { println!("{} PDF downloaded successfully!", style(SUCCESS).green()); println!(" {} Saved to: {}", style("📄").cyan(), style(&pdf_path.display()).yellow()); @@ -104,9 +80,9 @@ pub async fn download(cli: Cli, source: Source, identifier: String) -> Result<() }, LearnerError::Path(_) => { println!( - " {} Check if you have write permissions for: {}", + " {} Check if you have write permissions for: {:?}", style("Tip:").blue(), - style(&pdf_dir.display()).yellow() + style(db.get_storage_path().await?).yellow() ); }, _ => { diff --git a/crates/learnerd/src/commands/get.rs b/crates/learnerd/src/commands/get.rs index 4ee5795..c16c539 100644 --- 
a/crates/learnerd/src/commands/get.rs +++ b/crates/learnerd/src/commands/get.rs @@ -14,7 +14,7 @@ pub async fn get(cli: Cli, source: Source, identifier: String) -> Result<()> { default_path }); trace!("Using database at: {}", path.display()); - let db = Database::open(&path).await?; + let mut db = Database::open(&path).await?; println!( "{} Fetching paper from {} with ID {}", @@ -23,7 +23,8 @@ pub async fn get(cli: Cli, source: Source, identifier: String) -> Result<()> { style(&identifier).yellow() ); - match db.get_paper_by_source_id(&source, &identifier).await? { + let papers = Query::by_source(source, &identifier).execute(&mut db).await?; + match papers.first() { Some(paper) => { debug!("Found paper: {:?}", paper); println!("\n{} Paper details:", style(PAPER).green()); diff --git a/crates/learnerd/src/commands/init.rs b/crates/learnerd/src/commands/init.rs index 84759ba..98f5dc4 100644 --- a/crates/learnerd/src/commands/init.rs +++ b/crates/learnerd/src/commands/init.rs @@ -81,7 +81,7 @@ pub async fn init(cli: Cli) -> Result<()> { let db = Database::open(&db_path).await?; // Set up PDF directory - let pdf_dir = Database::default_pdf_path(); + let pdf_dir = Database::default_storage_path(); println!( "\n{} PDF files will be stored in: {}", style(PAPER).cyan(), @@ -106,7 +106,7 @@ pub async fn init(cli: Cli) -> Result<()> { }; std::fs::create_dir_all(&pdf_dir)?; - db.set_config("pdf_dir", &pdf_dir.to_string_lossy()).await?; + db.set_storage_path(&pdf_dir.to_string_lossy().to_string()).await?; println!("{} Database initialized successfully!", style(SUCCESS).green()); Ok(()) diff --git a/crates/learnerd/src/commands/mod.rs b/crates/learnerd/src/commands/mod.rs index 7123aaa..639a723 100644 --- a/crates/learnerd/src/commands/mod.rs +++ b/crates/learnerd/src/commands/mod.rs @@ -62,14 +62,12 @@ pub mod init; pub mod remove; pub mod search; -pub use add::add; -pub use clean::clean; -pub use daemon::daemon; -pub use download::download; -pub use get::get; -pub use 
init::init; -pub use remove::remove; -pub use search::search; +use learner::database::{Add, Query}; + +pub use self::{ + add::add, clean::clean, daemon::daemon, download::download, get::get, init::init, remove::remove, + search::search, +}; /// Available commands for the CLI #[derive(Subcommand, Clone)] diff --git a/crates/learnerd/src/commands/search.rs b/crates/learnerd/src/commands/search.rs index faa6e3c..1c4918a 100644 --- a/crates/learnerd/src/commands/search.rs +++ b/crates/learnerd/src/commands/search.rs @@ -14,7 +14,7 @@ pub async fn search(cli: Cli, query: String) -> Result<()> { default_path }); trace!("Using database at: {}", path.display()); - let db = Database::open(&path).await?; + let mut db = Database::open(&path).await?; println!("{} Searching for: {}", style(LOOKING_GLASS).cyan(), style(&query).yellow()); @@ -22,7 +22,7 @@ pub async fn search(cli: Cli, query: String) -> Result<()> { let search_query = query.split_whitespace().collect::>().join(" OR "); debug!("Modified search query: {}", search_query); - let papers = db.search_papers(&search_query).await?; + let papers = Query::text(&search_query).execute(&mut db).await?; if papers.is_empty() { println!("{} No papers found matching: {}", style(WARNING).yellow(), style(&query).yellow()); } else { diff --git a/crates/learnerd/src/main.rs b/crates/learnerd/src/main.rs index bd9f8ad..84ddb35 100644 --- a/crates/learnerd/src/main.rs +++ b/crates/learnerd/src/main.rs @@ -41,6 +41,7 @@ use learner::{ database::Database, error::LearnerError, paper::{Paper, Source}, + prelude::*, }; use tracing::{debug, trace}; use tracing_subscriber::EnvFilter; @@ -109,9 +110,10 @@ pub struct Cli { /// - 3+: trace fn setup_logging(verbosity: u8) { let filter = match verbosity { - 0 => "warn", - 1 => "info", - 2 => "debug", + 0 => "error", + 1 => "warn", + 2 => "info", + 3 => "debug", _ => "trace", }; diff --git a/crates/learnerd/src/tui/mod.rs b/crates/learnerd/src/tui/mod.rs index 8a4b9e5..985d503 100644 --- 
a/crates/learnerd/src/tui/mod.rs +++ b/crates/learnerd/src/tui/mod.rs @@ -45,7 +45,10 @@ use crossterm::{ execute, terminal::{disable_raw_mode, enable_raw_mode, EnterAlternateScreen, LeaveAlternateScreen}, }; -use learner::{database::Database, format::format_title}; +use learner::{ + database::{OrderField, Query}, + format::format_title, +}; use ratatui::{backend::CrosstermBackend, Terminal}; use super::*; @@ -100,8 +103,8 @@ use ui::UIDrawer; /// - Disables mouse capture pub async fn run() -> Result<()> { // Initialize state - let db = Database::open(Database::default_path()).await?; - let papers = db.list_papers("title", false).await?; + let mut db = Database::open(Database::default_path()).await?; + let papers = Query::list_all().order_by(OrderField::Title).execute(&mut db).await?; let mut state = UIState::new(papers); // Setup terminal diff --git a/crates/learnerd/src/tui/state.rs b/crates/learnerd/src/tui/state.rs index 77c992d..b2739a7 100644 --- a/crates/learnerd/src/tui/state.rs +++ b/crates/learnerd/src/tui/state.rs @@ -225,7 +225,7 @@ impl UIState { if let Some(paper) = self.selected_paper() { let pdf_path = format!( "{}/{}.pdf", - Database::default_pdf_path().display(), + Database::default_storage_path().display(), format_title(&paper.title, Some(50)) ); diff --git a/crates/learnerd/src/tui/ui.rs b/crates/learnerd/src/tui/ui.rs index 87abeba..a052387 100644 --- a/crates/learnerd/src/tui/ui.rs +++ b/crates/learnerd/src/tui/ui.rs @@ -284,7 +284,7 @@ impl<'a, 'b> UIDrawer<'a, 'b> { fn draw_pdf_status(&mut self, paper: &learner::paper::Paper, area: Rect) { let pdf_path = format!( "{}/{}.pdf", - Database::default_pdf_path().display(), + Database::default_storage_path().display(), format_title(&paper.title, Some(50)) ); let pdf_exists = std::path::Path::new(&pdf_path).exists(); diff --git a/justfile b/justfile index cde0161..9aad72c 100644 --- a/justfile +++ b/justfile @@ -90,8 +90,10 @@ build-linux: # Run the tests on your local OS test: - @just header 
"Running tests" - cargo test --workspace --all-targets + @just header "Running main test suite" + cargo test --workspace --all-targets --all-features + @just header "Running doc tests" + cargo test --workspace --doc # Run clippy for the workspace on your local OS lint: