Fixed many bugs (#94)
* Made a bunch of fixes

* Using Python 3.11 in CI
jamesbraza authored Jun 4, 2024
1 parent 161e27f commit 29c11f0
Showing 6 changed files with 426 additions and 115 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/lint-test.yaml
@@ -17,7 +17,7 @@ jobs:
- uses: actions/setup-python@v5
with:
cache: pip
python-version: 3.12
python-version: 3.11
- run: python -m pip install .[dev]
- uses: pre-commit/action@v3.0.1
- name: test
112 changes: 80 additions & 32 deletions paperscraper/lib.py
@@ -6,7 +6,7 @@
import os
import re
import sys
from collections.abc import Awaitable, Callable
from collections.abc import Iterable
from enum import Enum, IntEnum, auto
from functools import partial
from pathlib import Path
@@ -20,6 +20,7 @@
from .scraper import Scraper
from .utils import (
ThrottledClientSession,
crossref_headers,
encode_id,
find_doi,
get_scheme_hostname,
@@ -29,7 +30,7 @@
year_extract_pattern = re.compile(r"\b\d{4}\b")


def clean_upbibtex(bibtex):
def clean_upbibtex(bibtex: str) -> str:
# WTF Semantic Scholar?
mapping = {
"None": "article",
@@ -84,7 +85,7 @@ def format_bibtex(bibtex, key, clean: bool = True) -> str:
try:
entry = style.format_entry(label="1", entry=bd.entries[key])
return entry.text.render_as("text")
except FieldIsMissing:
except (FieldIsMissing, UnicodeDecodeError):
try:
return bd.entries[key].fields["title"]
except KeyError as exc:
@@ -306,10 +307,8 @@ async def local_scraper(paper, path) -> bool: # noqa: ARG001
return True


def default_scraper(
callback: Callable[[str, dict[str, str]], Awaitable] | None = None,
) -> Scraper:
scraper = Scraper(callback=callback)
def default_scraper(**scraper_kwargs) -> Scraper:
scraper = Scraper(**scraper_kwargs)
scraper.register_scraper(local_scraper, priority=12)
scraper_rate_limit_config: dict[str, Any] = {
"attach_session": True,
@@ -401,11 +400,37 @@ async def preprocess_google_scholar_metadata( # noqa: C901
return paper


async def parallel_preprocess_google_scholar_metadata(
papers: Iterable[dict[str, Any]],
session: ClientSession,
logger: logging.Logger | None = None,
) -> list[dict[str, Any]]:
"""
Preprocess papers in parallel, discarding ones with preprocessing failures.
NOTE: this function does not preserve the order of papers due to variable
preprocessing times.
"""
preprocessed_papers = []

async def index(paper: dict[str, Any]) -> None:
try:
preprocessed_papers.append(
await preprocess_google_scholar_metadata(paper, session)
)
except DOINotFoundError:
if logger:
logger.exception(f"Failed to find a DOI for paper {paper}.")

await asyncio.gather(*(index(p) for p in papers))
return preprocessed_papers
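
For illustration only (not part of this diff): a minimal sketch of driving the new helper with a plain aiohttp ClientSession. The import path and the shape of raw_papers (assumed to be the "organic_results" list of a SerpAPI Google Scholar response) are assumptions, not code from this commit.

import logging
from typing import Any

from aiohttp import ClientSession

from paperscraper.lib import parallel_preprocess_google_scholar_metadata


async def preprocess_all(raw_papers: list[dict[str, Any]]) -> list[dict[str, Any]]:
    # Papers that raise DOINotFoundError during preprocessing are logged and dropped,
    # so the returned list may be shorter than the input and is not in input order.
    async with ClientSession() as session:
        return await parallel_preprocess_google_scholar_metadata(
            raw_papers, session, logger=logging.getLogger(__name__)
        )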


async def parse_google_scholar_metadata(
paper: dict[str, Any], session: ClientSession
) -> dict[str, Any]:
"""Parse pre-processed paper metadata from Google Scholar into a richer format."""
doi: str | None = paper["externalIds"].get("DOI")
doi: str | None = (paper.get("externalIds") or {}).get("DOI")
citation: str | None = None
if doi:
try:
@@ -416,7 +441,7 @@ async def parse_google_scholar_metadata(
doi = None
except CitationConversionError:
citation = None
if not doi or not citation:
if (not doi or not citation) and "inline_links" in paper:
# get citation by following link
# SLOW SLOW Using SerpAPI for this
async with session.get(
@@ -442,13 +467,23 @@
f"{msg} bibtex link {bibtex_link} for paper {paper}."
) from exc
bibtex = await r.text()
if not bibtex.strip().startswith("@"):
raise RuntimeError(
f"Google scholar ip block bibtex link {bibtex_link} for paper"
f" {paper}."
)
key = bibtex.split("{")[1].split(",")[0]

if not citation:
raise RuntimeError(
f"Exhausted all options for citation retrieval for {paper!r}"
)
return {
"citation": citation,
"key": key,
"bibtex": bibtex,
"year": paper["year"],
"url": paper["link"],
"url": paper.get("link"),
"paperId": paper["paperId"],
"doi": paper["externalIds"].get("DOI"),
"citationCount": paper["citationCount"],
@@ -457,6 +492,14 @@


async def reconcile_doi(title: str, authors: list[str], session: ClientSession) -> str:
"""
Look up a DOI given a title and author list using Crossref.
Raises:
        DOINotFoundError: If the reconciliation fails because (1) the Crossref API call
        returned a non-'ok' status code, (2) the Crossref response status indicates
        failure, or (3) the returned Crossref entry had too low a score.
"""
# do not want initials
authors_query = " ".join([a for a in authors if len(a) > 1])
mailto = os.environ.get("CROSSREF_MAILTO", "paperscraper@example.org")
@@ -470,9 +513,11 @@ async def reconcile_doi(title: str, authors: list[str], session: ClientSession)
}
if authors_query:
params["query.author"] = authors_query
async with session.get(url, params=params) as r:
if not r.ok:
raise DOINotFoundError("Could not reconcile DOI " + title)
async with session.get(url, params=params, headers=crossref_headers()) as r:
try:
r.raise_for_status()
except ClientResponseError as exc:
raise DOINotFoundError("Could not reconcile DOI " + title) from exc
data = await r.json()
if data["status"] == "failed":
raise DOINotFoundError(f"Could not find DOI for {title}")
@@ -487,7 +532,7 @@ async def reconcile_doi(title: str, authors: list[str], session: ClientSession)
async def doi_to_bibtex(doi: str, session: ClientSession) -> str:
# get DOI via crossref
url = f"https://api.crossref.org/works/{doi}/transform/application/x-bibtex"
async with session.get(url) as r:
async with session.get(url, headers=crossref_headers()) as r:
if not r.ok:
raise DOINotFoundError(
f"Per HTTP status code {r.status}, could not resolve DOI {doi}."
@@ -812,7 +857,7 @@ async def google2s2(

responses = await asyncio.gather(*(
google2s2(t, y, p)
for t, y, p in zip(titles, years, google_pdf_links)
for t, y, p in zip(titles, years, google_pdf_links, strict=True)
))
data = {"data": [r for r in responses if r is not None]}
data["total"] = len(data["data"])
@@ -841,11 +886,11 @@ async def google2s2(
paths.update(
await scraper.batch_scrape(
papers,
pdir,
parse_semantic_scholar_metadata,
batch_size,
limit,
logger,
paper_file_dump_dir=pdir,
paper_parser=parse_semantic_scholar_metadata,
batch_size=batch_size,
limit=limit,
logger=logger,
)
)
if search_type in ["default", "google"] and len(paths) < limit and has_more_data:
@@ -896,6 +941,10 @@ async def a_gsearch_papers( # noqa: C901
logger.addHandler(ch)
# SEE: https://serpapi.com/google-scholar-api
endpoint = "https://serpapi.com/search.json"
# adjust _limit if limit is smaller (with margin for scraping errors)
# for example, if limit is 3 we would be fine only getting 8 results
# but if limit is 50, this will just return normal default _limit (20)
_limit = min(_limit, limit + 5)
params = {
"q": query,
"api_key": os.environ["SERPAPI_API_KEY"],
@@ -938,18 +987,14 @@ async def a_gsearch_papers( # noqa: C901
) as response:
if not response.ok:
raise RuntimeError(
f"Error searching papers: {response.status} {response.reason} {await response.text()}" # noqa: E501
"Error searching papers:"
f" {response.status} {response.reason} {await response.text()}"
)
data = await response.json()

if "organic_results" not in data:
return paths
papers = data["organic_results"]

# we only process papers that have a link
papers = await asyncio.gather(
*(preprocess_google_scholar_metadata(p, session) for p in papers)
)
total_papers = data["search_information"].get("total_results", 1)
logger.info(
f"Found {total_papers} papers, analyzing {_offset} to"
@@ -959,12 +1004,15 @@ async def a_gsearch_papers( # noqa: C901
# batch them, since we may reach desired limit before all done
paths.update(
await scraper.batch_scrape(
papers,
pdir,
partial(parse_google_scholar_metadata, session=session),
batch_size,
limit,
logger,
# we only process papers that have a link and a DOI
await parallel_preprocess_google_scholar_metadata(
papers, session, logger
),
paper_file_dump_dir=pdir,
paper_parser=partial(parse_google_scholar_metadata, session=session),
batch_size=batch_size,
limit=limit,
logger=logger,
)
)
if len(paths) < limit and _offset + _limit < total_papers:
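
Because default_scraper now forwards arbitrary keyword arguments to Scraper, callers can still supply a callback the way the removed signature did. A minimal sketch (the callback name and body are hypothetical; its parameter types follow the Callable[[str, dict[str, str]], Awaitable] hint deleted above):

from paperscraper.lib import default_scraper


async def on_scrape(path: str, paper: dict[str, str]) -> None:
    # Hypothetical callback; parameter meanings are assumed from the removed type hint.
    print(f"Scraped {paper.get('title')!r} -> {path}")


scraper = default_scraper(callback=on_scrape)  # kwargs pass straight through to Scraper(...)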
5 changes: 4 additions & 1 deletion paperscraper/log_formatter.py
@@ -1,4 +1,7 @@
from __future__ import annotations

import logging
from typing import ClassVar


class CustomFormatter(logging.Formatter):
@@ -12,7 +15,7 @@ class CustomFormatter(logging.Formatter):
"%(asctime)s - %(name)s - %(levelname)s - %(message)s (%(filename)s:%(lineno)d)"
)

FORMATS = { # noqa: RUF012
FORMATS: ClassVar[dict[int, str]] = {
logging.DEBUG: grey + format_str + reset,
logging.INFO: grey + format_str + reset,
logging.WARNING: yellow + format_str + reset,
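
The log_formatter change replaces a # noqa: RUF012 suppression with a proper ClassVar annotation: Ruff's RUF012 rule flags mutable class-level defaults unless they are marked as class variables. A generic illustration of the pattern (the class and names below are made up, not from this commit):

import logging
from typing import ClassVar


class Palette:
    # Annotated as ClassVar: this dict is shared class state, not a per-instance default.
    COLORS: ClassVar[dict[int, str]] = {
        logging.INFO: "grey",
        logging.ERROR: "red",
    }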
64 changes: 49 additions & 15 deletions paperscraper/scraper.py
@@ -5,10 +5,10 @@
import os
from collections.abc import Awaitable, Callable, Sequence
from dataclasses import dataclass
from typing import Any, Literal
from typing import Any, ClassVar, Literal

from .headers import get_header
from .utils import ThrottledClientSession, check_pdf
from .utils import ThrottledClientSession, aidentity_fn, check_pdf


@dataclass
@@ -51,14 +51,32 @@ def register_scraper(
# sort scrapers by priority
self.scrapers.sort(key=lambda x: x.priority, reverse=True)
# reshape into sorted scrapers
self._build_sorted_scrapers()

try:
SCRAPE_FUNCTION_TIMEOUT: ClassVar[float | None] = float( # sec
os.environ.get("PAPERSCRAPER_SCRAPE_FUNCTION_TIMEOUT", "60")
)
except ValueError: # Defeat by setting to "None"
SCRAPE_FUNCTION_TIMEOUT = None

def _build_sorted_scrapers(self) -> None:
self.sorted_scrapers = [
[s for s in self.scrapers if s.priority == priority]
for priority in sorted({s.priority for s in self.scrapers})
]

def deregister_scraper(self, name: str) -> None:
"""Remove a scraper by name."""
for i, scraper in enumerate(self.scrapers):
if scraper.name == name:
self.scrapers.pop(i)
break
self._build_sorted_scrapers()

async def scrape(
self,
paper,
paper: dict[str, Any],
path: str | os.PathLike,
i: int = 0,
logger: logging.Logger | None = None,
@@ -79,7 +97,10 @@ async def scrape(
for j in range(len(scrapers)):
scraper = scrapers[(i + j) % len(scrapers)]
try:
result = await scraper.function(paper, path, **scraper.kwargs)
result = await asyncio.wait_for(
scraper.function(paper, path, **scraper.kwargs),
timeout=self.SCRAPE_FUNCTION_TIMEOUT,
)
if result and (
not scraper.check_pdf or check_pdf(path, logger or False)
):
@@ -107,9 +128,10 @@ async def batch_scrape(
self,
papers: Sequence[dict[str, Any]],
paper_file_dump_dir: str | os.PathLike,
paper_parser: (
Callable[[dict[str, Any]], Awaitable[dict[str, Any]]] | None
) = None,
paper_preprocessor: Callable[[Any], Awaitable[dict[str, Any]]] = aidentity_fn,
paper_parser: Callable[
[dict[str, Any]], Awaitable[dict[str, Any]]
] = aidentity_fn,
batch_size: int = 10,
limit: int | None = None,
logger: logging.Logger | None = None,
@@ -120,7 +142,9 @@ async def batch_scrape(
Args:
papers: List of raw paper metadata.
paper_file_dump_dir: Directory where papers will be downloaded.
paper_parser: Optional function to process the raw paper metadata
paper_preprocessor: Optional async function to process the raw paper
metadata before scraping.
paper_parser: Optional async function to process the raw paper metadata
after scraping.
batch_size: Batch size to use when scraping, within a batch
scraping is parallelized.
@@ -130,19 +154,29 @@ async def batch_scrape(
Returns:
Dictionary mapping path to downloaded paper to parsed metadata.
"""
if paper_parser is not None:
parser = paper_parser
else:

async def parser(paper: dict[str, Any]) -> dict[str, Any]:
return paper

async def scrape_parse(
paper: dict[str, Any], i: int
) -> tuple[str, dict[str, Any]] | Literal[False]:
try:
paper = await paper_preprocessor(paper)
except RuntimeError: # Failed to hydrate the required paperId
if logger is not None:
logger.exception(f"Failed to preprocess paper {paper}.")
return False
path = os.path.join(paper_file_dump_dir, f'{paper["paperId"]}.pdf')
success = await self.scrape(paper, path, i=i, logger=logger)
return (path, await parser(paper)) if success else False
try:
return (path, await paper_parser(paper)) if success else False
except RuntimeError:
# RuntimeError: failed to traverse link inside paper details,
# or paper is missing field required for parsing like BibTeX links
if logger is not None:
logger.exception(
f"Failed to parse paper titled {paper.get('title')!r} with key"
f" {paper.get('paperId')!r}."
)
return False

aggregated: dict[str, dict[str, Any]] = {}
for i in range(0, len(papers), batch_size):
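
For reference, a sketch of the keyword-argument call style that the lib.py call sites above now use with batch_scrape. The directory name, paper list, and parser here are placeholders; note that each individual scraper call inside scrape() is capped by SCRAPE_FUNCTION_TIMEOUT, configurable through the PAPERSCRAPER_SCRAPE_FUNCTION_TIMEOUT environment variable (set it to "None" to disable the timeout).

import logging
from collections.abc import Awaitable, Callable
from typing import Any

from paperscraper.scraper import Scraper


async def scrape_batch(
    scraper: Scraper,
    papers: list[dict[str, Any]],
    parser: Callable[[dict[str, Any]], Awaitable[dict[str, Any]]],
) -> dict[str, dict[str, Any]]:
    # Returns a mapping of downloaded-PDF path -> parsed metadata, mirroring lib.py.
    return await scraper.batch_scrape(
        papers,
        paper_file_dump_dir="papers",  # placeholder output directory
        paper_parser=parser,
        batch_size=10,
        limit=20,
        logger=logging.getLogger(__name__),
    )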