diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index fd009ba..9218247 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -25,7 +25,6 @@ jobs: pip install -U pip wheel coverage coveralls pip install . python -c "import nltk; nltk.download('punkt')" - pip install https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/v0.2.5/en_core_sci_md-0.2.5.tar.gz python --version make data coverage env: diff --git a/src/python/paperetl/cord19/__main__.py b/src/python/paperetl/cord19/__main__.py index 0378dbd..59cde5b 100644 --- a/src/python/paperetl/cord19/__main__.py +++ b/src/python/paperetl/cord19/__main__.py @@ -7,11 +7,10 @@ from .execute import Execute if __name__ == "__main__": - if len(sys.argv) > 1: + if len(sys.argv) > 2: Execute.run( sys.argv[1], - sys.argv[2] if len(sys.argv) > 2 else None, + sys.argv[2], sys.argv[3] if len(sys.argv) > 3 else None, - sys.argv[4] == "True" if len(sys.argv) > 4 else True, - sys.argv[5] if len(sys.argv) > 5 else None, + sys.argv[4] == "True" if len(sys.argv) > 4 else False, ) diff --git a/src/python/paperetl/cord19/execute.py b/src/python/paperetl/cord19/execute.py index 7831301..eec8529 100644 --- a/src/python/paperetl/cord19/execute.py +++ b/src/python/paperetl/cord19/execute.py @@ -128,14 +128,13 @@ def getTags(sections): return tags @staticmethod - def stream(indir, dates, merge): + def stream(indir, dates): """ Generator that yields rows from a metadata.csv file. The directory is also included. Args: indir: input directory dates: list of uid - entry dates for current metadata file - merge: only merges/processes this list of uids, if enabled """ # Filter out duplicate ids @@ -152,15 +151,9 @@ def stream(indir, dates, merge): sha = Execute.getHash(row) # Only process if all conditions below met: - # - Merge set to None (must check for None as merge can be an empty set) or uid in list of ids to merge # - cord uid in entry date mapping # - cord uid and sha hash not already processed - if ( - (merge is None or uid in merge) - and uid in dates - and uid not in ids - and sha not in hashes - ): + if uid in dates and uid not in ids and sha not in hashes: yield (row, indir) # Add uid and sha as processed @@ -210,7 +203,7 @@ def process(params): Execute.getUrl(row), ) - return Article(metadata, sections, None) + return Article(metadata, sections) @staticmethod def entryDates(indir, entryfile): @@ -251,12 +244,12 @@ def entryDates(indir, entryfile): # Store date if cord uid maps to value in entries if row["cord_uid"] == uid: - dates[uid] = date + dates[uid] = parser.parse(date) return dates @staticmethod - def run(indir, url, entryfile, full, merge): + def run(indir, url, entryfile=None, replace=False): """ Main execution method. 
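Execute.entryDates now stores parsed datetime objects (via dateutil) rather than raw date strings, which lets the duplicate handling added to sqlite.py later in this diff compare dates instead of text. A minimal standalone sketch of that parse-and-compare pattern, assuming dateutil is installed; the uid and dates below are illustrative, not CORD-19 values:

    from dateutil import parser

    # uid -> entry date, parsed up front as entryDates now does
    dates = {"abcd1234": parser.parse("2020-05-26")}

    # A later record for the same uid with a newer entry date
    incoming = parser.parse("2021-01-15")

    # datetime comparison rather than string comparison
    if incoming > dates["abcd1234"]:
        dates["abcd1234"] = incoming

    print(dates)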
@@ -264,43 +257,36 @@ def run(indir, url, entryfile, full, merge): indir: input directory url: database url entryfile: path to entry dates file - full: full database load if True, only loads tagged articles if False - merge: database url to use for merging prior results + replace: if true, a new database will be created, overwriting any existing database """ print(f"Building articles database from {indir}") - # Set database url - if not url: - url = os.path.join(os.path.expanduser("~"), ".cord19", "models") - # Create database - db = Factory.create(url) + db = Factory.create(url, replace) # Load entry dates dates = Execute.entryDates(indir, entryfile) - # Merge existing db, if present - if merge: - merge = db.merge(merge, dates) - print("Merged results from existing articles database") - # Create process pool with Pool(os.cpu_count()) as pool: for article in pool.imap( - Execute.process, Execute.stream(indir, dates, merge), 100 + Execute.process, Execute.stream(indir, dates), 100 ): # Get unique id uid = article.uid() # Only load untagged rows if this is a full database load - if full or article.tags(): + if article.tags(): # Append entry date article.metadata = article.metadata + (dates[uid],) # Save article db.save(article) + pool.close() + pool.join() + # Complete processing db.complete() diff --git a/src/python/paperetl/database.py b/src/python/paperetl/database.py index 2c2e67a..c1222b0 100644 --- a/src/python/paperetl/database.py +++ b/src/python/paperetl/database.py @@ -8,22 +8,6 @@ class Database: Defines data structures and methods to store article content. """ - # pylint: disable=W0613 - def merge(self, url, ids): - """ - Merges the results of an existing database into the current database. This method returns - a list of ids not merged, which means there is a newer version available in the source data. - - Args: - url: database connection - ids: dict of id - entry date - - Returns: - list of eligible ids NOT merged - """ - - return [] - def save(self, article): """ Saves an article. diff --git a/src/python/paperetl/elastic.py b/src/python/paperetl/elastic.py index 5b805fa..5c6d23f 100644 --- a/src/python/paperetl/elastic.py +++ b/src/python/paperetl/elastic.py @@ -22,12 +22,13 @@ class Elastic(Database): "mappings": {"properties": {"sections": {"type": "nested"}}}, } - def __init__(self, url): + def __init__(self, url, replace): """ Connects and initializes an elasticsearch instance. Args: url: elasticsearch url + replace: If database should be recreated """ # Connect to ES instance @@ -39,8 +40,16 @@ def __init__(self, url): # Buffered actions self.buffer = [] - # Create index if it doesn't exist - if not self.connection.indices.exists("articles"): + # Check if index exists + exists = self.connection.indices.exists("articles") + + # Delete if replace enabled + if exists and replace: + self.connection.indices.delete("articles") + exists = False + + # Create if necessary + if not exists: self.connection.indices.create("articles", Elastic.ARTICLES) def save(self, article): diff --git a/src/python/paperetl/factory.py b/src/python/paperetl/factory.py index 6febb05..9018903 100644 --- a/src/python/paperetl/factory.py +++ b/src/python/paperetl/factory.py @@ -13,25 +13,26 @@ class Factory: """ @staticmethod - def create(url): + def create(url, replace): """ Creates a new database connection. 
Args: url: connection url + replace: if true, a new database will be created, overwriting any existing database Returns: Database """ if url.startswith("http://"): - return Elastic(url) + return Elastic(url, replace) if url.startswith("json://"): return JSON(url.replace("json://", "")) if url.startswith("yaml://"): return YAML(url.replace("yaml://", "")) if url: # If URL is present, assume it's SQLite - return SQLite(url.replace("sqlite://", "")) + return SQLite(url.replace("sqlite://", ""), replace) return None diff --git a/src/python/paperetl/file/__main__.py b/src/python/paperetl/file/__main__.py index d04b65e..59cde5b 100644 --- a/src/python/paperetl/file/__main__.py +++ b/src/python/paperetl/file/__main__.py @@ -9,5 +9,8 @@ if __name__ == "__main__": if len(sys.argv) > 2: Execute.run( - sys.argv[1], sys.argv[2], sys.argv[3] if len(sys.argv) > 3 else None + sys.argv[1], + sys.argv[2], + sys.argv[3] if len(sys.argv) > 3 else None, + sys.argv[4] == "True" if len(sys.argv) > 4 else False, ) diff --git a/src/python/paperetl/file/arx.py b/src/python/paperetl/file/arx.py new file mode 100644 index 0000000..01abd5a --- /dev/null +++ b/src/python/paperetl/file/arx.py @@ -0,0 +1,166 @@ +""" +arXiv XML processing module +""" + +import hashlib +import re + +from bs4 import BeautifulSoup +from dateutil import parser +from nltk.tokenize import sent_tokenize + +from ..schema.article import Article +from ..text import Text + + +class ARX: + """ + Methods to transform arXiv XML into article objects. + """ + + @staticmethod + def parse(stream, source): + """ + Parses a XML datastream and yields processed articles. + + Args: + stream: handle to input data stream + source: text string describing stream source, can be None + config: path to config directory + """ + + # Parse XML + soup = BeautifulSoup(stream, "lxml") + + # Process each entry + for entry in soup.find_all("entry"): + reference = ARX.get(entry, "id") + title = ARX.get(entry, "title") + published = parser.parse(ARX.get(entry, "published").split("T")[0]) + updated = parser.parse(ARX.get(entry, "updated").split("T")[0]) + + # Derive uid + uid = hashlib.sha1(reference.encode("utf-8")).hexdigest() + + # Get journal reference + journal = ARX.get(entry, "arxiv:journal_ref") + + # Get authors + authors, affiliations, affiliation = ARX.authors(entry.find_all("author")) + + # Get tags + tags = "; ".join( + ["ARX"] + + [category.get("term") for category in entry.find_all("category")] + ) + + # Transform section text + sections = ARX.sections(title, ARX.get(entry, "summary")) + + # Article metadata - id, source, published, publication, authors, affiliations, affiliation, title, + # tags, reference, entry date + metadata = ( + uid, + source, + published, + journal, + authors, + affiliations, + affiliation, + title, + tags, + reference, + updated, + ) + + yield Article(metadata, sections) + + @staticmethod + def get(element, path): + """ + Finds the first matching path in element and returns the element text. + + Args: + element: XML element + path: path expression + + Returns: + string + """ + + element = element.find(path) + return ARX.clean(element.text) if element else None + + @staticmethod + def clean(text): + """ + Removes newlines and extra spacing from text. + + Args: + text: text to clean + + Returns: + clean text + """ + + # Remove newlines and cleanup spacing + text = text.replace("\n", " ") + return re.sub(r"\s+", " ", text).strip() + + @staticmethod + def authors(elements): + """ + Parses authors and associated affiliations from the article. 
+ + Args: + elements: authors elements + + Returns: + (semicolon separated list of authors, semicolon separated list of affiliations, primary affiliation) + """ + + authors = [] + affiliations = [] + + for author in elements: + # Create authors as lastname, firstname + name = ARX.get(author, "name") + authors.append(", ".join(name.rsplit(maxsplit=1)[::-1])) + + # Add affiliations + affiliations.extend( + [ + ARX.clean(affiliation.text) + for affiliation in author.find_all("arxiv:affiliation") + ] + ) + + return ( + "; ".join(authors), + "; ".join(dict.fromkeys(affiliations)), + affiliations[-1] if affiliations else None, + ) + + @staticmethod + def sections(title, text): + """ + Gets a list of sections for this article. + + Args: + title: title string + text: summary text + + Returns: + list of sections + """ + + # Add title + sections = [("TITLE", title)] + + # Transform and clean text + text = Text.transform(text) + + # Split text into sentences, transform text and add to sections + sections.extend([("ABSTRACT", x) for x in sent_tokenize(text)]) + + return sections diff --git a/src/python/paperetl/file/csvf.py b/src/python/paperetl/file/csvf.py index 3d1e4a2..7704896 100644 --- a/src/python/paperetl/file/csvf.py +++ b/src/python/paperetl/file/csvf.py @@ -5,6 +5,8 @@ import csv import datetime +from dateutil import parser + from ..schema.article import Article @@ -30,7 +32,7 @@ def parse(stream, source): # Parse sections sections = CSV.sections(row) - yield Article(metadata, sections, source) + yield Article(metadata, sections) @staticmethod def metadata(row, source): @@ -67,7 +69,11 @@ def metadata(row, source): if field == "source": value = row.get(field, source) elif field == "entry": - value = row.get(field, datetime.datetime.now().strftime("%Y-%m-%d")) + # Parse date field if found, otherwise use current date + value = row.get(field) + value = parser.parse( + value if value else datetime.datetime.now().strftime("%Y-%m-%d") + ) else: value = row.get(field) diff --git a/src/python/paperetl/file/execute.py b/src/python/paperetl/file/execute.py index ef4d897..0629b44 100644 --- a/src/python/paperetl/file/execute.py +++ b/src/python/paperetl/file/execute.py @@ -2,10 +2,14 @@ Transforms and loads medical/scientific files into an articles database. """ +import gzip import os +from multiprocessing import Process, Queue + from ..factory import Factory +from .arx import ARX from .csvf import CSV from .pdf import PDF from .pmb import PMB @@ -17,13 +21,16 @@ class Execute: Transforms and loads medical/scientific files into an articles database. """ + # Completion process signal + COMPLETE = 1 + @staticmethod def mode(source, extension): """ Determines file open mode for source file. Args: - source: text string describing stream source, can be None + source: text string describing stream source extension: data format Returns: @@ -37,67 +44,148 @@ def mode(source, extension): ) @staticmethod - def process(stream, source, extension, config): + def parse(path, source, extension, compress, config): """ - Processes a data input stream and yields articles + Parses articles from file at path. 
Args: - stream: handle to input data stream - source: text string describing stream source, can be None + path: path to input file + source: text string describing stream source extension: data format config: path to config directory """ - if extension == "pdf": - yield PDF.parse(stream, source) - elif extension == "xml": - if source and source.lower().startswith("pubmed"): - yield from PMB.parse(stream, source, config) - else: - yield TEI.parse(stream, source) - elif extension == "csv": - yield from CSV.parse(stream, source) + print(f"Processing: {path}") + + # Determine if file needs to be open in binary or text mode + mode = Execute.mode(source, extension) + + with gzip.open(path, mode) if compress else open( + path, mode, encoding="utf-8" if mode == "r" else None + ) as stream: + if extension == "pdf": + yield PDF.parse(stream, source) + elif extension == "xml": + if source and source.lower().startswith("arxiv"): + yield from ARX.parse(stream, source) + elif source and source.lower().startswith("pubmed"): + yield from PMB.parse(stream, source, config) + else: + yield TEI.parse(stream, source) + elif extension == "csv": + yield from CSV.parse(stream, source) @staticmethod - def run(indir, url, config=None): + def process(inputs, outputs): """ - Main execution method. + Main worker process loop. Processes file paths stored in inputs and writes articles + to outputs. Writes a final message upon completion. Args: - indir: input directory - url: database url - config: path to config directory, if any + inputs: inputs queue + outputs: outputs queue """ - # Build database connection - db = Factory.create(url) + try: + # Process until inputs queue is exhausted + while not inputs.empty(): + params = inputs.get() + + for result in Execute.parse(*params): + outputs.put(result) + finally: + # Write message that process is complete + outputs.put(Execute.COMPLETE) + + @staticmethod + def scan(indir, config, inputs): + """ + Scans for files in indir and writes to inputs queue. - # Processed ids - ids = set() + Args: + indir: input directory + config: path to config directory, if any + inputs: inputs queue + """ # Recursively walk directory looking for files for root, _, files in sorted(os.walk(indir)): for f in sorted(files): # Extract file extension - extension = f.split(".")[-1].lower() + parts = f.lower().split(".") + extension, compress = ( + (parts[-2], True) if parts[-1] == "gz" else (parts[-1], False) + ) # Check if file ends with accepted extension - if any(extension for ext in ["csv", "pdf", "xml"]): + if any(extension for ext in ["csv", "pdf", "xml"] if ext == extension): # Build full path to file path = os.path.join(root, f) - # Determine if file needs to be open in binary or text mode - mode = Execute.mode(f, extension) + # Write parameters to inputs queue + inputs.put((path, f, extension, compress, config)) + + @staticmethod + def save(processes, outputs, db): + """ + Main consumer loop that saves articles created by worker processes. 
+ + Args: + processes: list of worker processes + outputs: outputs queue + db: output database + """ + + # Read output from worker processes + empty, complete = False, 0 + while not empty: + # Get next result + result = outputs.get() + + # Mark process as complete if all workers are complete and output queue is empty + if result == Execute.COMPLETE: + complete += 1 + empty = len(processes) == complete and outputs.empty() + + # Save article, this method will skip duplicates based on entry date + elif result: + db.save(result) + + @staticmethod + def run(indir, url, config=None, replace=False): + """ + Main execution method. + + Args: + indir: input directory + url: database url + config: path to config directory, if any + replace: if true, a new database will be created, overwriting any existing database + """ + + # Build database connection + db = Factory.create(url, replace) + + # Create queues, limit size of output queue + inputs, outputs = Queue(), Queue(100000) + + # Scan input directory and add files to inputs queue + Execute.scan(indir, config, inputs) - print(f"Processing: {path}") - with open(path, mode) as data: - # Yield articles from input stream - for article in Execute.process(data, f, extension, config): - # Save article if unique - if article and article.uid() not in ids: - db.save(article) - ids.add(article.uid()) + # Start worker processes + processes = [] + for _ in range(min(inputs.qsize(), os.cpu_count())): + process = Process(target=Execute.process, args=(inputs, outputs)) + process.start() + processes.append(process) + + # Read results from worker processes and save to database + Execute.save(processes, outputs, db) # Complete and close database db.complete() db.close() + + # Wait for processes to terminate + for process in processes: + process.join() diff --git a/src/python/paperetl/file/pmb.py b/src/python/paperetl/file/pmb.py index 2a94416..bff3cb3 100644 --- a/src/python/paperetl/file/pmb.py +++ b/src/python/paperetl/file/pmb.py @@ -1,8 +1,7 @@ """ -Transforms and loads PubMed archive XML files into an articles database. +PubMed archive XML processing module """ -import datetime import os import re @@ -70,9 +69,10 @@ def process(element, source, codes): uid = int(citation.find("PMID").text) source = source if source else "PMB" reference = f"https://pubmed.ncbi.nlm.nih.gov/{uid}" + entry = PMB.date(citation.find("DateRevised")) # Journal fields - published = PMB.date(journal) + published = PMB.published(journal) publication = PMB.get(journal, "Title") # Article fields @@ -103,10 +103,10 @@ def process(element, source, codes): title, tags, reference, - datetime.datetime.now().strftime("%Y-%m-%d"), + entry, ) - return Article(metadata, sections, source) + return Article(metadata, sections) return None @@ -141,7 +141,27 @@ def text(element): return "".join(element.itertext()) if element is not None else None @staticmethod - def date(journal): + def date(element): + """ + Attempts to parse a date from an element. + + Args: + element: input element + + Return: + Date if parsed + """ + + date = "" + for field in ["Year", "Month", "Day"]: + value = PMB.get(element, field) + if value: + date += "-" + value if date else value + + return parser.parse(date) if date else None + + @staticmethod + def published(journal): """ Parses the published date. Multiple date formats are handled via the dateparser library. 
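The new PMB.date helper above assembles a partial date string from whatever Year/Month/Day children are present and defers the actual parsing to dateutil. A self-contained sketch of that pattern, using xml.etree.ElementTree in place of the module's own element handling; the sample XML is made up:

    from xml.etree import ElementTree

    from dateutil import parser

    element = ElementTree.fromstring(
        "<DateRevised><Year>2021</Year><Month>03</Month></DateRevised>"
    )

    date = ""
    for field in ["Year", "Month", "Day"]:
        value = element.findtext(field)
        if value:
            # First component stands alone, later ones are dash-separated
            date += "-" + value if date else value

    # "2021-03" parses to a March 2021 datetime; the missing day is filled in by dateutil
    print(parser.parse(date) if date else None)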
@@ -155,19 +175,14 @@ def date(journal): element = journal.find("JournalIssue/PubDate") - date = "" - for field in ["Year", "Month", "Day"]: - value = PMB.get(element, field) - if value: - date += "-" + value if date else value - + date = PMB.date(element) if not date: - # Attempt to parse out date + # Fallback to MedlineDate date = PMB.get(element, "MedlineDate") date = re.search(r"\d{4}", date) - date = date.group() if date else None + date = parser.parse(date.group()) if date else None - return parser.parse(date) if date else None + return date if date else None @staticmethod def authors(journal): diff --git a/src/python/paperetl/file/tei.py b/src/python/paperetl/file/tei.py index 96edfec..dd9a808 100644 --- a/src/python/paperetl/file/tei.py +++ b/src/python/paperetl/file/tei.py @@ -37,7 +37,14 @@ def parse(stream, source): title = soup.title.text # Extract article metadata - published, publication, authors, reference = TEI.metadata(soup) + ( + published, + publication, + authors, + affiliations, + affiliation, + reference, + ) = TEI.metadata(soup) # Validate parsed data if not title and not reference: @@ -63,15 +70,15 @@ def parse(stream, source): published, publication, authors, - None, - None, + affiliations, + affiliation, title, "PDF", reference, - datetime.datetime.now().strftime("%Y-%m-%d"), + parser.parse(datetime.datetime.now().strftime("%Y-%m-%d")), ) - return Article(metadata, sections, source) + return Article(metadata, sections) @staticmethod def date(published): @@ -101,16 +108,18 @@ def date(published): @staticmethod def authors(source): """ - Builds an authors string from a TEI sourceDesc tag. + Parses authors and associated affiliations from the article. Args: - source: sourceDesc tag handle + elements: authors elements Returns: - semicolon separated list of authors + (semicolon separated list of authors, semicolon separated list of affiliations, primary affiliation) """ authors = [] + affiliations = [] + for name in source.find_all("persname"): surname = name.find("surname") forename = name.find("forename") @@ -118,7 +127,15 @@ def authors(source): if surname and forename: authors.append(f"{surname.text}, {forename.text}") - return "; ".join(authors) + for affiliation in source.find_all("affiliation"): + names = [name.text for name in affiliation.find_all("orgname")] + affiliations.append((", ".join(names))) + + return ( + "; ".join(authors), + "; ".join(dict.fromkeys(affiliations)), + affiliations[-1] if affiliations else None, + ) @staticmethod def metadata(soup): @@ -141,7 +158,7 @@ def metadata(soup): # Parse publication information published = TEI.date(published) publication = publication.text if publication else None - authors = TEI.authors(source) + authors, affiliations, affiliation = TEI.authors(source) struct = soup.find("biblstruct") reference = ( @@ -150,9 +167,16 @@ def metadata(soup): else None ) else: - published, publication, authors, reference = None, None, None, None + published, publication, authors, affiliations, affiliation, reference = ( + None, + None, + None, + None, + None, + None, + ) - return (published, publication, authors, reference) + return (published, publication, authors, affiliations, affiliation, reference) @staticmethod def abstract(soup, title): diff --git a/src/python/paperetl/filesystem.py b/src/python/paperetl/filesystem.py index 458159a..4e7599a 100644 --- a/src/python/paperetl/filesystem.py +++ b/src/python/paperetl/filesystem.py @@ -31,8 +31,8 @@ def __init__(self, outdir): def save(self, article): output = article.uid() + 
f".{self.extension()}" output = ( - f"{os.path.splitext(article.source)[0]}-{output}" - if article.source + f"{os.path.splitext(article.source())[0]}-{output}" + if article.source() else output ) diff --git a/src/python/paperetl/schema/article.py b/src/python/paperetl/schema/article.py index 16ee2ea..beafec4 100644 --- a/src/python/paperetl/schema/article.py +++ b/src/python/paperetl/schema/article.py @@ -26,19 +26,17 @@ class Article: # Sections schema SECTION = ("name", "text") - def __init__(self, metadata, sections, source): + def __init__(self, metadata, sections): """ Stores article metadata and section content as an object. Args: metadata: article metadata sections: text sections - source: article source """ self.metadata = metadata self.sections = sections - self.source = source def uid(self): """ @@ -50,6 +48,16 @@ def uid(self): return self.metadata[0] + def source(self): + """ + Returns the article source. + + Returns: + article source + """ + + return self.metadata[1] + def tags(self): """ Returns the article tags. @@ -60,6 +68,16 @@ def tags(self): return self.metadata[8] + def entry(self): + """ + Returns the article entry date. + + Returns: + article entry date + """ + + return self.metadata[10] + def build(self): """ Builds an article with all metadata and section content. diff --git a/src/python/paperetl/sqlite.py b/src/python/paperetl/sqlite.py index 2c207d1..5f803ac 100644 --- a/src/python/paperetl/sqlite.py +++ b/src/python/paperetl/sqlite.py @@ -5,7 +5,7 @@ import os import sqlite3 -from datetime import datetime, timedelta +from dateutil import parser from .database import Database @@ -58,25 +58,23 @@ class SQLite(Database): INSERT_ROW = "INSERT INTO {table} ({columns}) VALUES ({values})" CREATE_INDEX = "CREATE INDEX section_article ON sections(article)" - # Merge SQL statements - ATTACH_DB = "ATTACH DATABASE '{path}' as {name}" - DETACH_DB = "DETACH DATABASE '{name}'" - MAX_ENTRY = "SELECT MAX(entry) from {name}.articles" - LOOKUP_ARTICLE = "SELECT Id FROM {name}.articles WHERE Id=? AND Entry = ?" - MERGE_ARTICLE = "INSERT INTO articles SELECT * FROM {name}.articles WHERE Id = ?" - MERGE_SECTIONS = ( - "INSERT INTO sections SELECT * FROM {name}.sections WHERE Article=?" - ) - UPDATE_ENTRY = "UPDATE articles SET entry = ? WHERE Id = ?" - ARTICLE_COUNT = "SELECT COUNT(1) FROM articles" - SECTION_COUNT = "SELECT MAX(id) FROM sections" + # Restore index when updating an existing database + SECTION_COUNT = "SELECT MAX(Id) FROM sections" + + # Lookup entry date for an article + LOOKUP_ENTRY = "SELECT Entry FROM articles WHERE id = ?" - def __init__(self, outdir): + # Delete article + DELETE_ARTICLE = "DELETE FROM articles WHERE id = ?" + DELETE_SECTIONS = "DELETE FROM sections WHERE article = ?" + + def __init__(self, outdir, replace): """ Creates and initializes a new output SQLite database. 
Args: outdir: output directory + replace: If database should be recreated """ # Create if output path doesn't exist @@ -85,104 +83,93 @@ def __init__(self, outdir): # Output database file dbfile = os.path.join(outdir, "articles.sqlite") - # Delete existing file - if os.path.exists(dbfile): + # Create flag + create = replace or not os.path.exists(dbfile) + + # Delete existing file if replace set + if replace and os.path.exists(dbfile): os.remove(dbfile) # Index fields self.aindex, self.sindex = 0, 0 - # Create output database + # Connect to output database self.db = sqlite3.connect(dbfile) # Create database cursor self.cur = self.db.cursor() - # Create articles table - self.create(SQLite.ARTICLES, "articles") + if create: + # Create articles table + self.create(SQLite.ARTICLES, "articles") - # Create sections table - self.create(SQLite.SECTIONS, "sections") + # Create sections table + self.create(SQLite.SECTIONS, "sections") + + # Create articles index for sections table + self.execute(SQLite.CREATE_INDEX) + else: + # Restore section index id + self.sindex = int(self.cur.execute(SQLite.SECTION_COUNT).fetchone()[0]) + 1 # Start transaction self.cur.execute("BEGIN") - def merge(self, url, ids): - # List of IDs to set for processing - queue = set() - - # Attached database alias - alias = "merge" - - # Attach database - self.db.execute(SQLite.ATTACH_DB.format(path=url, name=alias)) - - # Only process records newer than 5 days before the last run - lastrun = self.cur.execute(SQLite.MAX_ENTRY.format(name=alias)).fetchone()[0] - lastrun = datetime.strptime(lastrun, "%Y-%m-%d") - timedelta(days=5) - lastrun = lastrun.strftime("%Y-%m-%d") - - # Search for existing articles - for uid, date in ids.items(): - self.cur.execute(SQLite.LOOKUP_ARTICLE.format(name=alias), [uid, date]) - if not self.cur.fetchone() and date > lastrun: - # Add uid to process - queue.add(uid) - else: - # Copy existing record - self.cur.execute(SQLite.MERGE_ARTICLE.format(name=alias), [uid]) - self.cur.execute(SQLite.MERGE_SECTIONS.format(name=alias), [uid]) - - # Sync entry date with ids list - self.cur.execute(SQLite.UPDATE_ENTRY, [date, uid]) - - # Set current index positions - self.aindex = ( - int(self.cur.execute(SQLite.ARTICLE_COUNT.format(name=alias)).fetchone()[0]) - + 1 - ) - self.sindex = ( - int(self.cur.execute(SQLite.SECTION_COUNT.format(name=alias)).fetchone()[0]) - + 1 - ) - - # Commit transaction - self.db.commit() - - # Detach database - self.db.execute(SQLite.DETACH_DB.format(name=alias)) + def save(self, article): + # Save article if not a duplicate + if self.savearticle(article): + # Increment number of articles processed + self.aindex += 1 + if self.aindex % 1000 == 0: + print(f"Inserted {self.aindex} articles", end="\r") + + # Commit current transaction and start a new one + self.transaction() + + for name, text in article.sections: + # Section row - id, article, name, text + self.insert( + SQLite.SECTIONS, + "sections", + (self.sindex, article.uid(), name, text), + ) + self.sindex += 1 + + def savearticle(self, article): + """ + Saves an article to SQLite. If a duplicate entry is found, this method compares the entry + date and keeps the article with the latest entry date. 
- # Start new transaction - self.cur.execute("BEGIN") + Args: + article: article metadata and text content - # Return list of new/updated ids to process - return queue + Returns + True if article saved, False otherwise + """ - def save(self, article): - # Article row - self.insert(SQLite.ARTICLES, "articles", article.metadata) + try: + # Article row + self.insert(SQLite.ARTICLES, "articles", article.metadata) + except sqlite3.IntegrityError: + # Duplicate detected get entry date to determine action + entry = parser.parse( + self.cur.execute(SQLite.LOOKUP_ENTRY, [article.uid()]).fetchone()[0] + ) - # Increment number of articles processed - self.aindex += 1 - if self.aindex % 1000 == 0: - print(f"Inserted {self.aindex} articles", end="\r") + # Keep existing article if existing entry date is same or newer + if article.entry() <= entry: + return False - # Commit current transaction and start a new one - self.transaction() + # Delete and re-insert article + self.cur.execute(SQLite.DELETE_ARTICLE, [article.uid()]) + self.cur.execute(SQLite.DELETE_SECTIONS, [article.uid()]) + self.insert(SQLite.ARTICLES, "articles", article.metadata) - for name, text in article.sections: - # Section row - id, article, name, text - self.insert( - SQLite.SECTIONS, "sections", (self.sindex, article.uid(), name, text) - ) - self.sindex += 1 + return True def complete(self): print(f"Total articles inserted: {self.aindex}") - # Create articles index for sections table - self.execute(SQLite.CREATE_INDEX) - def close(self): self.db.commit() self.db.close() @@ -208,11 +195,7 @@ def create(self, table, name): create = SQLite.CREATE_TABLE.format(table=name, fields=", ".join(columns)) # pylint: disable=W0703 - try: - self.cur.execute(create) - except Exception as e: - print(create) - print("Failed to create table: " + e) + self.cur.execute(create) def execute(self, sql): """ @@ -240,12 +223,8 @@ def insert(self, table, name, row): table=name, columns=", ".join(columns), values=("?, " * len(columns))[:-2] ) - try: - # Execute insert statement - self.cur.execute(insert, self.values(table, row, columns)) - # pylint: disable=W0703 - except Exception as ex: - print(f"Error inserting row: {row}", ex) + # Execute insert statement + self.cur.execute(insert, self.values(table, row, columns)) def values(self, table, row, columns): """ diff --git a/test/python/testcord19.py b/test/python/testcord19.py index 1c47835..0fafc48 100644 --- a/test/python/testcord19.py +++ b/test/python/testcord19.py @@ -2,15 +2,12 @@ CORD-19 tests """ -import os -import shutil import sqlite3 from datetime import datetime from paperetl.cord19.entry import Entry from paperetl.cord19.execute import Execute -from paperetl.factory import Factory # pylint: disable = C0411 from testprocess import TestProcess @@ -34,7 +31,6 @@ def setUpClass(cls): Utils.CORD19 + "/models", Utils.CORD19 + "/data/entry-dates.csv", True, - None, ) def setUp(self): @@ -123,59 +119,6 @@ def testHash(self): "47ed55bfa014cd59f58896c132c36bb0a218d11d", ) - def testMergeEmpty(self): - """ - Test merge run with no updates - """ - - os.makedirs(Utils.CORD19 + "/merge", exist_ok=True) - - # Copy existing articles.sqlite file - shutil.copyfile( - Utils.CORD19 + "/models/articles.sqlite", - Utils.CORD19 + "/merge/articles.v1.sqlite", - ) - - db = Factory.create(Utils.CORD19 + "/merge") - - # Load entry dates - dates = Execute.entryDates( - Utils.CORD19 + "/data", Utils.CORD19 + "/data/entry-dates.csv" - ) - - # Run merge process - merge = db.merge(Utils.CORD19 + "/merge/articles.v1.sqlite", 
dates) - db.close() - - # Assert no records to merge - self.assertFalse(merge) - - def testMergeUpdate(self): - """ - Test merge run with updates - """ - - # Run merge again settings entry date to older date to ensure id is set to merge - db = sqlite3.connect(Utils.CORD19 + "/merge/articles.v1.sqlite") - db.execute("UPDATE articles SET entry='2020-01-01' WHERE id='mb0qcd0b'") - db.commit() - db.close() - - # Run merge, should merge a single record - db = Factory.create(Utils.CORD19 + "/merge") - - # Load entry dates - dates = Execute.entryDates( - Utils.CORD19 + "/data", Utils.CORD19 + "/data/entry-dates.csv" - ) - - # Run merge process - merge = db.merge(Utils.CORD19 + "/merge/articles.v1.sqlite", dates) - db.close() - - # Assert record to merge - self.assertEqual(merge, {"mb0qcd0b"}) - def testSectionCount(self): """ Test number of sections diff --git a/test/python/testelastic.py b/test/python/testelastic.py new file mode 100644 index 0000000..7d89e52 --- /dev/null +++ b/test/python/testelastic.py @@ -0,0 +1,54 @@ +""" +Elastic tests +""" + +import unittest + +from unittest import mock + +from paperetl.elastic import Elastic +from paperetl.schema.article import Article + + +class Indices: + """ + Mock elasticsearch class for testing + """ + + exists = lambda *args: True + delete = lambda *args: True + refresh = lambda *args, **kwargs: True + + +class ElasticStub: + """ + Mock elasticsearch class for testing + """ + + indices = Indices() + bulk = lambda *args: True + + +class TestElastic(unittest.TestCase): + """ + Elastic tests + """ + + @mock.patch( + "paperetl.elastic.Elasticsearch", mock.MagicMock(return_value=ElasticStub()) + ) + @mock.patch("paperetl.elastic.helpers", mock.MagicMock(return_value=ElasticStub())) + def testSave(self): + """ + Tests saving an article to elasticsearch + """ + + # Create connection + elastic = Elastic("http://localhost:9200", False) + + # Save mock results + for _ in range(1000): + elastic.save(Article(Article.ARTICLE, [("name", "text")])) + + # Mark as complete + elastic.complete() diff --git a/test/python/testfiledatabase.py b/test/python/testfiledatabase.py index 7ebc91e..68aa250 100644 --- a/test/python/testfiledatabase.py +++ b/test/python/testfiledatabase.py @@ -4,13 +4,27 @@ import sqlite3 +from unittest import mock + from paperetl.file.execute import Execute +from paperetl.file.pdf import PDF # pylint: disable = C0411 from testprocess import TestProcess from utils import Utils +class RequestsStub: + """ + Mock requests class for testing. + """ + + def __init__(self): + self.ok = True + with open(Utils.FILE + "/data/0.xml", "r", encoding="utf-8") as xml: + self.text = xml.read() + + class TestFileDatabase(TestProcess): """ File ETL to database tests @@ -19,10 +33,13 @@ class TestFileDatabase(TestProcess): @classmethod def setUpClass(cls): """ - One-time initialization. Run File ETL process for the test dataset. + One-time initialization. Run file ETL process for the test dataset. 
""" # Build articles database + Execute.run(Utils.FILE + "/data", Utils.FILE + "/models", None, True) + + # Run again with replace=False Execute.run(Utils.FILE + "/data", Utils.FILE + "/models") def setUp(self): @@ -41,7 +58,7 @@ def testArticleCount(self): Test number of articles """ - self.articleCount(17) + self.articleCount(20) def testArticles(self): """ @@ -49,33 +66,52 @@ def testArticles(self): """ hashes = { - "00398e4c637f5e5447e35e63669187f0239c0357": "e4bab29931654fbf2569fe88c9947138", - "00c4c8c42473d25ebb38c4a8a14200c6900be2e9": "8a9deb80f42173aa0681d7607925c63b", + "00398e4c637f5e5447e35e63669187f0239c0357": "769aabf322421b2e34a32d7afce4d046", + "00c4c8c42473d25ebb38c4a8a14200c6900be2e9": "c1b8cebfb55231215a865eaa8e16d338", "1000": "babc1842c2dd9bf298bf6376a1b58318", "1001": "d8579348f06c6428565cab60ef797d0d", - "17a845a8681cca77a4497462e797172148448d7d": "8b5e4696e66934afe75e1d9d14aeb445", - "1d6a755d67e76049551898de66c95f77b9420b0c": "b1cac46801d2dd58ec2df9ce14986af2", + "17a845a8681cca77a4497462e797172148448d7d": "f1768e230244ab984530c482f275f439", + "1d6a755d67e76049551898de66c95f77b9420b0c": "a239023786db03a0bc7ae00780c396df", + "3a1e7ec128ae12937badcd33f2a273b284714550": "37f1dd1e15347ae325aa09901d1767f3", "33024096": "c6d63a5a2761519f31cff6f690b9f639", "33046957": "62634d987e1d5077f9892a059fec8302", "33100476": "ba7c2509e242b2132d32baa48a1dc2ed", "33126180": "d8fe2cb1d95ddf74d79e6a5e2c58bd07", "33268238": "3f0f851d08c0138d9946db1641a5278e", - "3d2fb136bbd9bd95f86fc49bdcf5ad08ada6913b": "3facd9a50cf7a0e038ed0e6b3903a8b0", - "5ea7c57e339a078196ec69223c4681fd7a5aab8b": "3251a20e067e1bd5291e56e9cb20218e", - "6cb7a79749913fa0c2c3748cbfee2f654d5cea36": "d14bebf4a6f4103498b02b5604649619", - "a09f0fcf41e01f2cdb5685b5000964797f679132": "ef24259c459d09612d0e6d4430b3e8bd", - "b9f6e3d2dd7d18902ac3a538789d836793dd48b2": "60bb318db90bfb305b2343c54242d689", - "dff0088d65a56e2673d11ad2f7a180687cab6f70": "4fc1512bbceb439fd4ff4a7132b08735", + "3d2fb136bbd9bd95f86fc49bdcf5ad08ada6913b": "6920b2807d6955a8b2fefd5b0a59219e", + "5ea7c57e339a078196ec69223c4681fd7a5aab8b": "26a92c9d288412e73476c3c1f3107e20", + "68d85e38ab25365d8a382b146305aa8560bfa6fa": "e45f785edc439dfea64b6ff929fa9b44", + "6cb7a79749913fa0c2c3748cbfee2f654d5cea36": "53ff57038e6257a5fa8032cdc40b0750", + "a09f0fcf41e01f2cdb5685b5000964797f679132": "d2a98edeb923ec781a44bf7c166bfc77", + "b9f6e3d2dd7d18902ac3a538789d836793dd48b2": "a7b77a4e68f9134e134c306b00d489d8", + "da60cb944e50dbecbfd217581cb5f55bda332d7a": "a39676892934b50d7cd2e299f3a68d21", + "dff0088d65a56e2673d11ad2f7a180687cab6f70": "f22696dbdce156d5a81b7231d1ef983b", } self.articles(hashes) + @mock.patch( + "paperetl.file.pdf.requests.post", mock.MagicMock(return_value=RequestsStub()) + ) + def testPDF(self): + """ + Tests parsing PDFs + """ + + article = PDF.parse("stream", "source") + + # Calculate metadata hash + md5 = Utils.hashtext(" ".join([str(x) for x in article.metadata[:-1]])) + + self.assertEqual(md5, "0e04572de9c87fdbdc5339b03aee1df9") + self.assertEqual(len(article.sections), 254) + def testSectionCount(self): """ Test number of sections """ - self.sectionCount(3646) + self.sectionCount(3668) def testSections(self): """ @@ -89,6 +125,7 @@ def testSections(self): "1001": "147aca4eade4737a7a1d438a5a1d3ed1", "17a845a8681cca77a4497462e797172148448d7d": "8c50748e74883ac316d2473cf491d4e0", "1d6a755d67e76049551898de66c95f77b9420b0c": "799d21a7c67b5b4777effb2c83a42ff4", + "3a1e7ec128ae12937badcd33f2a273b284714550": "2a0e6610a4edce73005ddaf740117aee", "33024096": 
"485b6035e1cd62e5ded8c356acc5689f", "33046957": "1e7833214fc60b89316f3680e9f93ec1", "33100476": "a5db58cd71ba75e76fa8df24e3336db6", @@ -96,9 +133,11 @@ def testSections(self): "33268238": "1077e7a114d54bdf80de5d6834fdeb63", "3d2fb136bbd9bd95f86fc49bdcf5ad08ada6913b": "1fc2ccc509b2bd7eca33858c740e54c2", "5ea7c57e339a078196ec69223c4681fd7a5aab8b": "e2e28cb740520bae95acd373a6c767a9", + "68d85e38ab25365d8a382b146305aa8560bfa6fa": "9772fa38bdafd56dc1599ce02ff10c1e", "6cb7a79749913fa0c2c3748cbfee2f654d5cea36": "2387f11ea786cc4265c54f096080ad00", "a09f0fcf41e01f2cdb5685b5000964797f679132": "78596af571f3250057c1df23eabfc498", "b9f6e3d2dd7d18902ac3a538789d836793dd48b2": "51d5b0cf2273a687a348502c95c6dbec", + "da60cb944e50dbecbfd217581cb5f55bda332d7a": "73f7744629e4feb5988acbd31117c7f1", "dff0088d65a56e2673d11ad2f7a180687cab6f70": "c61df238a8ecb9a63422f19b2218949d", } diff --git a/test/python/testfileexport.py b/test/python/testfileexport.py index fee3dd2..97ea1f4 100644 --- a/test/python/testfileexport.py +++ b/test/python/testfileexport.py @@ -35,7 +35,7 @@ def testYAML(self): def export(self, method): """ - Test a file export. + Test a file export Args: method: export method (json or yaml) @@ -64,5 +64,5 @@ def export(self, method): sections += len(data["sections"]) # Validate counts - self.assertEqual(articles, 17) - self.assertEqual(sections, 3646) + self.assertEqual(articles, 20) + self.assertEqual(sections, 3668)