diff --git a/.gitignore b/.gitignore
index d7de04c..a01064e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -9,3 +9,4 @@ workspace.code-workspace
 tests
 __pycache__
 .env
+build
diff --git a/README.md b/README.md
index fecfcf2..2be5181 100644
--- a/README.md
+++ b/README.md
@@ -35,7 +35,7 @@ Run anywhere:
 ### Modules
 
 Papercast is designed around 3 types of modules:
 
-- **Collectors** convert documents to a usable format (plaintext for now).
+- **Processors** convert documents to a usable format (plaintext for now).
 - **Processors** process the document
 - **Publishers** publish the audio to your desired endpoint (e.g. a podcast feed).
diff --git a/docs/source/api_reference/collectors.rst b/docs/source/api_reference/collectors.rst
index 93238cb..3a53505 100644
--- a/docs/source/api_reference/collectors.rst
+++ b/docs/source/api_reference/collectors.rst
@@ -1,5 +1,5 @@
-Collectors
+Processors
 ==========
 
 .. toctree::
diff --git a/docs/source/api_reference/collectors/arxivcollector.rst b/docs/source/api_reference/collectors/arxivcollector.rst
index 5032d5b..2eabb15 100644
--- a/docs/source/api_reference/collectors/arxivcollector.rst
+++ b/docs/source/api_reference/collectors/arxivcollector.rst
@@ -1,8 +1,8 @@
-ArxivCollector
+ArxivProcessor
 ==============
 
-.. autoclass:: papercast.collectors.ArxivCollector
+.. autoclass:: papercast.collectors.ArxivProcessor
    :members:
    :undoc-members:
diff --git a/docs/source/cli/cli.md b/docs/source/cli/cli.md
new file mode 100644
index 0000000..ebfb80e
--- /dev/null
+++ b/docs/source/cli/cli.md
@@ -0,0 +1,5 @@
+# CLI
+```{toctree}
+:hidden:
+:maxdepth: 1
+```
\ No newline at end of file
diff --git a/docs/source/conf.py b/docs/source/conf.py
index f47e5cd..8b78402 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -135,7 +135,7 @@ def get_plugin_classes(plugin_folder) -> Dict[str, List[str]]:
     rst_path.parent.mkdir(parents=True, exist_ok=True)
 
     rst_content = textwrap.dedent("""
-    Collectors
+    Processors
     ==========
 
     .. toctree::
diff --git a/docs/source/examples/config.py b/docs/source/examples/config.py
deleted file mode 100644
index cc71bcd..0000000
--- a/docs/source/examples/config.py
+++ /dev/null
@@ -1,19 +0,0 @@
-from papercast.pipelines import Pipeline
-from papercast.collectors import SemanticScholarCollector, ArxivCollector, PDFCollector
-from papercast.narrators import Narrator, SayNarrator, PollyNarrator
-from papercast.extractors import Extractor, GROBIDExtractor
-from papercast.publishers import Publisher, GithubPagesPodcastPublisher
-
-
-class MyPipeline(Pipeline):
-    def __init__(self):
-        self.name = "default"
-        self.collectors = [
-            SemanticScholarCollector(input_names={"id": "ss_id"}),
-            ArxivCollector(),
-            PDFCollector(),
-        ]
-        self.extractors = [GROBIDExtractor(input_names={"pdf": "pdf"})]
-        self.narrators = [PollyNarrator()]
-        self.filters = []
-        self.publishers = [GithubPagesPodcastPublisher()]
diff --git a/docs/source/getting_started.md b/docs/source/getting_started.md
index 1cdaf05..4aab2e6 100644
--- a/docs/source/getting_started.md
+++ b/docs/source/getting_started.md
@@ -30,7 +30,7 @@ Here's an example:
 
 [^1]: (possible support for DAG structure in the future)
 
-[^2]: Collectors are slightly different in that they start with document identifiers and produce Productions
+[^2]: Processors are slightly different in that they start with document identifiers and produce Productions
 
 ## Use the CLI to start the server
 ```bash
diff --git a/docs/source/index.md b/docs/source/index.md
index c2f0657..bed3ee9 100644
--- a/docs/source/index.md
+++ b/docs/source/index.md
@@ -28,7 +28,7 @@ An extensible framework to turn technical documents into multimedia. Written in
 
 Papercast is designed around 3 types of modules:
 
-- [Collectors](modules/collectors.md) accept document identifiers and return [Productions](modules/productions.md).
+- [Processors](modules/collectors.md) accept document identifiers and return [Productions](modules/productions.md).
 - [Processors](modules/processors.md) modify Productions.
 - [Publishers](modules/publishers.md) publish Productions to your desired endpoint (e.g. a podcast feed, Twitter (coming soon), etc).
diff --git a/docs/source/modules/collectors.md b/docs/source/modules/collectors.md
index 5a8ed88..6c5eb54 100644
--- a/docs/source/modules/collectors.md
+++ b/docs/source/modules/collectors.md
@@ -1,4 +1,4 @@
-# Collectors
+# Processors
 ```{toctree}
 :hidden:
 :maxdepth: 1
diff --git a/docs/source/modules/pipelines.md b/docs/source/modules/pipelines.md
index 1143707..0aeadce 100644
--- a/docs/source/modules/pipelines.md
+++ b/docs/source/modules/pipelines.md
@@ -6,13 +6,13 @@
 
 Papercast pipelines are a way to chain together a series of collectors, processors, and publishers.
 
-Collectors accept document identifiers and return Productions.
+Processors accept document identifiers and return Productions.
 
 Processors accept Productions and return Productions.
 
 Publishers accept Productions and return nothing in the Python environment, but may publish the Production to a remote location.
 
-Pipelines are constructed by connecting Collectors, Processors, and Publishers together. The output of one component is the input of the next component.
+Pipelines are constructed by connecting Processors, Processors, and Publishers together. The output of one component is the input of the next component.
 
 To make pipeline components interoperable, they operate on a set of types. Papercast provides a set of common types. More exotic use cases may require custom types.
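The `pipelines.md` page above leans on type-based interoperability between components. As a sketch of what a custom type might look like (not part of this patch; the `TranscriptFile` name and fields are hypothetical, and the assumption that papercast types are plain dataclasses is based on the `dataclass` import visible in `papercast/types.py` later in this diff):

```python
from dataclasses import dataclass
from pathlib import Path
from typing import Union

PathLike = Union[str, Path]  # mirrors the alias in papercast.types


@dataclass
class TranscriptFile:
    """Hypothetical custom type for a transcript artifact.

    A component producing transcripts could declare
    output_types = {"transcript": TranscriptFile}, and any downstream
    component declaring the same entry in its input_types could consume it.
    """

    path: PathLike
    language: str = "en"
```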
diff --git a/docs/source/modules/subscribers.md b/docs/source/modules/subscribers.md
new file mode 100644
index 0000000..49335fb
--- /dev/null
+++ b/docs/source/modules/subscribers.md
@@ -0,0 +1,8 @@
+# Subscribers
+
+Subscribers can be used to wait for events and trigger pipeline runs when they occur.
+
+Each subscriber implements a .subscribe() method that returns an async iterator.
+The iterator yields events as they occur.
+The subscriber is responsible for filtering out events that are not relevant to the pipeline.
+
diff --git a/examples/arxiv-grobid-say-github-pages/server.py b/examples/arxiv-grobid-say-github-pages/server.py
index 7a24ac7..0e9c388 100644
--- a/examples/arxiv-grobid-say-github-pages/server.py
+++ b/examples/arxiv-grobid-say-github-pages/server.py
@@ -1,9 +1,9 @@
 from papercast.pipelines import Pipeline
-from papercast.collectors.arxiv import ArxivCollector
-from papercast.collectors.pdf import PDFCollector
-from papercast.processors.say import SayProcessor
+from papercast.processors import ArxivProcessor
+from papercast.processors import PDFProcessor
+from papercast.processors import SayProcessor
 from papercast.processors import GROBIDProcessor
-from papercast.publishers.github_pages import GithubPagesPodcastPublisher
+from papercast.publishers import GithubPagesPodcastPublisher
 from papercast.server import Server
 
 # Create a pipeline
@@ -11,10 +11,10 @@
 
 # Add processors to the pipeline
 pipeline.add_processor(
-    "arxiv", ArxivCollector(pdf_dir="data/pdfs", json_dir="data/json")
+    "arxiv", ArxivProcessor(pdf_dir="data/pdfs", json_dir="data/json")
 )
 
-pipeline.add_processor("pdf", PDFCollector(pdf_dir="data/pdfs"))
+pipeline.add_processor("pdf", PDFProcessor(pdf_dir="data/pdfs"))
 
 pipeline.add_processor(
     "grobid",
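The new `subscribers.md` page above states the contract only informally: `.subscribe()` returns an async iterator that yields events. A minimal sketch of a conforming subscriber (not part of this patch) could look like the following, assuming the `BaseSubscriber` class introduced in `papercast/base.py` below; the timer-driven `IntervalSubscriber` and the no-argument `Production()` construction are assumptions:

```python
import asyncio
from typing import AsyncIterable

from papercast.base import BaseSubscriber
from papercast.production import Production


class IntervalSubscriber(BaseSubscriber):
    """Hypothetical subscriber that emits an empty Production every `seconds`."""

    def __init__(self, seconds: float = 60.0) -> None:
        super().__init__()
        self.seconds = seconds

    async def subscribe(self) -> AsyncIterable[Production]:
        while True:
            await asyncio.sleep(self.seconds)
            # Per subscribers.md, irrelevant events should be filtered
            # out here, before yielding.
            yield Production()
```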
diff --git a/papercast/base.py b/papercast/base.py
index 25334cd..970be31 100644
--- a/papercast/base.py
+++ b/papercast/base.py
@@ -2,6 +2,9 @@
 from abc import ABC, abstractmethod
 from functools import wraps
 from typing import Any, Dict
+import websockets
+from websockets.client import connect
+from typing import AsyncGenerator, AsyncIterable
 
 from papercast.production import Production
 
@@ -36,64 +39,58 @@ def __hash__(self):
     def __eq__(self, other):
         return id(self) == id(other)
 
-    @abstractmethod
-    def process(self, input: Production, *args, **kwargs) -> Production:
-        raise NotImplementedError
-
+    def init_logger(self, log_level: int = logging.INFO):
+        self.logger = logging.getLogger(__name__)
+        c_handler = logging.StreamHandler()
+        c_format = logging.Formatter("%(name)s - %(levelname)s - %(message)s")
+        c_handler.setLevel(log_level)
+        c_handler.setFormatter(c_format)
+        self.logger.addHandler(c_handler)
 
 class BaseProcessor(BasePipelineComponent, ABC):
+    input_types: Dict[str, Any] = {}
+    output_types: Dict[str, Any] = {}
+
     def __init__(
         self,
     ) -> None:
         self.init_logger()
         self.name = None
 
-    def init_logger(self, log_level: int = logging.INFO):
-        self.logger = logging.getLogger(__name__)
-        c_handler = logging.StreamHandler()
-        c_format = logging.Formatter("%(name)s - %(levelname)s - %(message)s")
-        c_handler.setLevel(log_level)
-        c_handler.setFormatter(c_format)
-        self.logger.addHandler(c_handler)
+    @abstractmethod
+    @validate_inputs
+    def process(self, input: Production, *args, **kwargs) -> Production:
+        raise NotImplementedError
+
+    def from_kwargs(self, **kwargs):
+        production = Production(**kwargs)
+        return self.process(production)
 
-class BaseCollector(BasePipelineComponent, ABC):
+
+class BaseSubscriber(BasePipelineComponent, ABC):
     def __init__(
         self,
     ) -> None:
-        pass
-
-    def init_logger(self, log_level: int = logging.INFO):
-        self.logger = logging.getLogger(__name__)
-        c_handler = logging.StreamHandler()
-        c_format = logging.Formatter("%(name)s - %(levelname)s - %(message)s")
-        c_handler.setLevel(log_level)
-        c_handler.setFormatter(c_format)
-        self.logger.addHandler(c_handler)
+        self.init_logger()
 
     @abstractmethod
-    def process(self, *args, **kwargs) -> Production:
+    async def subscribe(self) -> Production:
         raise NotImplementedError
 
 
 class BasePublisher(BasePipelineComponent, ABC):
+    input_types: Dict[str, Any] = {}
+
     def __init__(
         self,
     ) -> None:
         pass
 
-    def init_logger(self, log_level: int = logging.INFO):
-        self.logger = logging.getLogger(__name__)
-        c_handler = logging.StreamHandler()
-        c_format = logging.Formatter("%(name)s - %(levelname)s - %(message)s")
-        c_handler.setLevel(log_level)
-        c_handler.setFormatter(c_format)
-        self.logger.addHandler(c_handler)
-
     @abstractmethod
     def process(self, input: Production, *args, **kwargs) -> None:
         raise NotImplementedError
 
-    @abstractmethod
-    def input_types(self) -> Dict[str, Any]:
-        pass
+    def from_kwargs(self, **kwargs):
+        production = Production(**kwargs)
+        return self.process(production)
diff --git a/papercast/collectors/__init__.py b/papercast/collectors/__init__.py
index a1a4b98..fd544e8 100644
--- a/papercast/collectors/__init__.py
+++ b/papercast/collectors/__init__.py
@@ -1,6 +1,6 @@
-from papercast.plugin_utils import load_plugins
+# from papercast.plugin_utils import load_plugins
 
-_installed_plugins = load_plugins("collectors")
+# _installed_plugins = load_plugins("collectors")
 
-for name, plugin in _installed_plugins.items():
-    globals()[name] = plugin
\ No newline at end of file
+# for name, plugin in _installed_plugins.items():
+#     globals()[name] = plugin
\ No newline at end of file
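With `BaseProcessor` reshaped above (class-level `input_types`/`output_types`, an abstract `process`, and a `from_kwargs` convenience), a minimal concrete processor might look like this sketch (not part of this patch; the `UppercaseProcessor`, its `text` field, and attribute-style access on `Production` are assumptions):

```python
from papercast.base import BaseProcessor
from papercast.production import Production


class UppercaseProcessor(BaseProcessor):
    """Hypothetical processor that upper-cases a `text` attribute."""

    input_types = {"text": str}
    output_types = {"text": str}

    def process(self, input: Production, *args, **kwargs) -> Production:
        input.text = input.text.upper()
        return input


# from_kwargs wraps keyword arguments in a Production and calls process():
result = UppercaseProcessor().from_kwargs(text="hello")
```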
diff --git a/papercast/pipelines.py b/papercast/pipelines.py
index 8d5b2d9..efb4751 100644
--- a/papercast/pipelines.py
+++ b/papercast/pipelines.py
@@ -1,7 +1,9 @@
-from papercast.base import BaseCollector
-from papercast.base import BasePipelineComponent
+from papercast.base import BaseProcessor, BaseSubscriber, BasePipelineComponent
+from papercast.production import Production
 from typing import Iterable, Dict, Any
 from collections import defaultdict
+import asyncio
+from concurrent.futures import ThreadPoolExecutor
 
 
 class Pipeline:
@@ -10,7 +12,9 @@ def __init__(self, name: str):
         self.connections = defaultdict(list)
         self.processors = {}
         self.collectors = {}
+        self.subscribers = {}
         self.downstream_processors = {}
+        self.executor = ThreadPoolExecutor()
 
     def _validate_name(self, name: str):
         if name in [p.name for p in self.processors.values()]:
@@ -19,9 +23,15 @@ def add_processor(self, name: str, processor: BasePipelineComponent):
         self._validate_name(name)
         setattr(processor, "name", name)
+
         self.processors[name] = processor
-        if isinstance(processor, BaseCollector):
+
+        if isinstance(processor, BaseProcessor):
             self.collectors[name] = processor
+
+        elif isinstance(processor, BaseSubscriber):
+            self.subscribers[name] = processor
+
         else:
             self.downstream_processors[name] = processor
 
@@ -61,24 +71,9 @@ def input_types(self) -> Dict[str, Any]:
             input_types.update(processor.input_types)
         return input_types
 
-    # def _check_one_connected_component(self):
-    #     not_connected = set()
-    #     for processor in self.processors:
-    #         if processor not in self.downstream_processors:
-    #             not_connected.add(processor)
-
-    #     raise ValueError(
-    #         f"Found {len(not_connected)} processors that are not connected to the pipeline: {not_connected}"
-    #     )
-
-    # def validate(self):
-    #     self._check_one_connected_component()
-
     def _validate_run_kwargs(self, kwargs):
         input_kwargs = {k: v for k, v in kwargs.items() if k in self.input_types}
-        options_kwargs = {
-            k: v for k, v in kwargs.items() if k not in self.input_types
-        }
+        options_kwargs = {k: v for k, v in kwargs.items() if k not in self.input_types}
 
         if len(input_kwargs) != 1:
             raise ValueError(
@@ -94,15 +89,17 @@
 
         return collector[0], collector[1], input_key, input_value, options_kwargs
 
-    def get_downstream_processors(self, collector_name: str) -> Iterable[str]:
+    def get_downstream_processors(
+        self, collector_subscriber_name: str
+    ) -> Iterable[str]:
         """Get all processors downstream of the collector with name `collector_name`
         by recursively traversing the graph of connections.
         """
         downstream_processors = set()
 
-        if not self.connections[collector_name]:
+        if not self.connections[collector_subscriber_name]:
             raise ValueError(
-                f"Collector {collector_name} is not connected to any downstream processors"
+                f"Processor {collector_subscriber_name} is not connected to any downstream processors"
             )
 
         def visit(processor_name: str):
@@ -111,15 +108,40 @@ def visit(processor_name: str):
             for _, next_processor_name, _ in self.connections[processor_name]:
                 visit(next_processor_name)
 
-        for _, downstream_processor, _ in self.connections[collector_name]:
+        for _, downstream_processor, _ in self.connections[collector_subscriber_name]:
             visit(downstream_processor)
 
         return downstream_processors
 
-    def run(self, **kwargs):
-        # self.validate()
-        # check that only one of the kwargs corresponds to an input of one of the input processors
+    async def _run_subscriber(self, subscriber_name: str):
+        subscriber = self.subscribers[subscriber_name]
+        loop = asyncio.get_event_loop()
+        async for production in subscriber.subscribe():
+            # process() re-derives the downstream graph from the component name
+            await loop.run_in_executor(
+                None, self.process, production, subscriber_name
+            )
+
+    async def _run_in_server(self):
+        await asyncio.gather(*[self._run_subscriber(name) for name in self.subscribers])
+
+    def process(
+        self, production: Production, collector_subscriber_name: str, **options
+    ) -> None:
+        """
+        Run the pipeline synchronously on a production, from a collector or subscriber.
+        """
+        print(f"Processing production {production}...")
+        processing_graph = self.get_downstream_processors(collector_subscriber_name)
+        sorted_processors = self._topological_sort(processing_graph)
+        for name in sorted_processors:
+            production = self.processors[name].process(production, **options)
+
+    def run(self, **kwargs):
+        """
+        Run the pipeline synchronously, from kwargs.
+        """
         (
             collector_name,
             collector,
@@ -128,11 +150,5 @@
             options,
         ) = self._validate_run_kwargs(kwargs)
 
-        processing_graph = self.get_downstream_processors(collector_name)
-
-        sorted_processors = self._topological_sort(processing_graph)
-
         production = collector.process(**{param: value}, **options)
-
-        for name in sorted_processors:
-            production = self.processors[name].process(production, **options)
+        self.process(production, collector_subscriber_name=collector_name, **options)
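The reworked `Pipeline` above now has two entry points: `run()` validates kwargs, runs the matching collector once, and hands the resulting production to `process()`, which walks the topologically sorted downstream graph; `_run_in_server()` instead spawns one async task per subscriber, each feeding productions into `process()` as events arrive. A usage sketch (the `arxiv_id` input name and the already-connected pipeline are assumptions):

```python
import asyncio

from papercast.pipelines import Pipeline

pipeline = Pipeline(name="default")
# ... add_processor(...) calls and connections as elsewhere in this diff ...

# Synchronous, one-shot: exactly one kwarg may match a registered input
# type, and it selects the collector that starts the run.
pipeline.run(arxiv_id="2101.00001")  # hypothetical input name

# Event-driven, as the server uses it: one task per registered subscriber.
asyncio.run(pipeline._run_in_server())
```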
diff --git a/papercast/plugin_utils.py b/papercast/plugin_utils.py
index c12b7e0..ff46b15 100644
--- a/papercast/plugin_utils.py
+++ b/papercast/plugin_utils.py
@@ -23,11 +23,11 @@ def load_plugins(plugin_type: str):
     for entry_point in importlib.metadata.entry_points().get(f'papercast.{plugin_type}', []):
         plugin_module = entry_point.load()
 
-        if plugin_type == "collectors":
-            validate_process_method(plugin_module)
-            validate_output_types(plugin_module)
+        # if plugin_type == "collectors":
+        #     validate_process_method(plugin_module)
+        #     validate_output_types(plugin_module)
 
-        elif plugin_type == "subscribers":
+        if plugin_type == "subscribers":
             validate_base_pipeline_component(plugin_module)
             validate_output_types(plugin_module)
diff --git a/papercast/scripts/papercast.py b/papercast/scripts/papercast.py
index ef8af90..2c13664 100644
--- a/papercast/scripts/papercast.py
+++ b/papercast/scripts/papercast.py
@@ -23,8 +23,8 @@ def parse_arguments():
             key = arg[2:]
             if key:
                 params[key] = []
-        elif key in params:
-            params[key].append(arg)
+        elif key in params:  # type: ignore
+            params[key].append(arg)  # type: ignore
         else:
             print(f"Unexpected parameter {arg}.")
             sys.exit(1)
diff --git a/papercast/server.py b/papercast/server.py
index e7f9e8c..e7c79ea 100644
--- a/papercast/server.py
+++ b/papercast/server.py
@@ -1,64 +1,76 @@
 from fastapi import FastAPI, Body
-from typing import Optional, Dict, Any
+from typing import Dict, Any
 from papercast.pipelines import Pipeline
 from fastapi import HTTPException, APIRouter
 import uvicorn
-
+import asyncio
 
 class Server:
     def __init__(self, pipelines: Dict[str, Pipeline]):
-        self._pipelines: Dict[str, Pipeline] = pipelines
-        self._init_app()
+        self._pipelines = pipelines
+        self._pipeline_tasks = []
 
-    def _init_app(self):
         self.router = APIRouter()
         self.router.add_api_route("/", self._root)
         self.router.add_api_route("/add", self._add, methods=["POST"])
-        self.router.add_api_route("/pipelines", self.__pipelines)
+        self.router.add_api_route("/pipelines", self.serialize_pipelines)
+
         self.app = FastAPI()
         self.app.include_router(self.router)
+        self.app.add_event_handler("startup", self.run_pipelines)
+        self.app.add_event_handler("shutdown", self._cancel_pipeline_tasks)
 
     def _root(self):
        return {"message": "Papercast Server"}
 
+    def _get_pipeline(self, pipeline: str):
+        if pipeline not in self._pipelines.keys():
+            raise HTTPException(status_code=404, detail="Pipeline not found")
+        return self._pipelines[pipeline]
+
     def _add(
         self,
         data: Dict[Any, Any] = Body(...),
     ):
-        pipeline = self.get_pipeline(data["pipeline"])  # type: Pipeline
+        pipeline = self._get_pipeline(data["pipeline"])  # type: Pipeline
         pipeline.run(**data)
-        return {"message": f"Documents added to pipeline {pipeline.name}"}
+        return {"message": f"Document(s) added to pipeline {pipeline.name}"}
 
-    def get_pipeline(self, pipeline: str):
-        if pipeline not in self._pipelines.keys():
-            raise HTTPException(status_code=404, detail="Pipeline not found")
-        return self._pipelines[pipeline]
+    def serialize_pipelines(self):
+        def serialize_pipeline(pipeline: Pipeline):
+            return {
+                # "collectors": [collector.asdict() for collector in pipeline.collectors],
+                "subscribers": [s.asdict() for s in pipeline.subscribers.values()],
+                "processors": [p.asdict() for p in pipeline.processors.values()],
+            }
 
-    def __pipelines(self):
         return {
             "pipelines": {
-                k: self.serialize_pipeline(p) for k, p in self._pipelines.items()
+                k: serialize_pipeline(p) for k, p in self._pipelines.items()
             }
         }
 
-    def serialize_pipeline(self, pipeline: Pipeline):
-        return {
-            "collectors": [collector.asdict() for collector in pipeline.collectors],
-            "narrators": [narrator.asdict() for narrator in pipeline.narrators],
-            "extractors": [extractor.asdict() for extractor in pipeline.extractors],
-            "filters": [filter.asdict() for filter in pipeline.filters],
-            "publishers": [publisher.asdict() for publisher in pipeline.publishers],
-        }
+    async def _cancel_pipeline_tasks(self):
+        for task in self._pipeline_tasks:
+            task.cancel()
+            try:
+                await task
+            except asyncio.CancelledError:
+                pass
 
-    def run(
-        self,
-        host: str = "",
-        port: int = 8000,
-    ):
+    async def run_pipelines(self):
+        print("Running pipelines")
+        for pipeline in self._pipelines.values():
+            task = asyncio.create_task(pipeline._run_in_server())
+            self._pipeline_tasks.append(task)
+
+    def run(self, host: str = "", port: int = 8000):
         uvicorn.run(
             self.app,
             host=host,
             port=port,
             log_level="debug",
+            lifespan="on",
         )
\ No newline at end of file
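A sketch of exercising the reworked endpoints (assumptions: the server is reachable on localhost:8000, a pipeline named "default" exists, `arxiv_id` is one of its input types, and the third-party `requests` package is installed):

```python
import requests

BASE = "http://localhost:8000"

# GET /pipelines returns the serialized subscribers/processors per pipeline.
print(requests.get(f"{BASE}/pipelines").json())

# POST /add triggers a synchronous run; besides "pipeline", exactly one
# key should match a registered input type.
resp = requests.post(
    f"{BASE}/add",
    json={"pipeline": "default", "arxiv_id": "2101.00001"},
)
print(resp.json())  # e.g. {"message": "Document(s) added to pipeline default"}
```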
diff --git a/papercast/subscribers/websocket_subscriber.py b/papercast/subscribers/websocket_subscriber.py
new file mode 100644
index 0000000..5c6301e
--- /dev/null
+++ b/papercast/subscribers/websocket_subscriber.py
@@ -0,0 +1,27 @@
+from papercast.base import BaseSubscriber
+from abc import ABC, abstractmethod
+from papercast.production import Production
+from typing import AsyncIterable
+from websockets.client import connect
+import json
+
+class WebSocketSubscriber(BaseSubscriber):
+    def __init__(self, url) -> None:
+        super().__init__()
+        self.url = url
+
+    def process(self, input: Production, *args, **kwargs) -> Production:  # TODO: might not need this
+        return input
+
+    # @abstractmethod
+    def process_message(self, message) -> Production:
+        # process the message and return a Production object
+        message = json.loads(message)
+        print(message)
+        return self.process(Production(**message))
+
+    async def subscribe(self) -> AsyncIterable[Production]:
+        async with connect(self.url) as socket:
+            print("connected")
+            async for message in socket:
+                yield self.process_message(message)
diff --git a/papercast/types.py b/papercast/types.py
index 45107d8..738deae 100644
--- a/papercast/types.py
+++ b/papercast/types.py
@@ -5,6 +5,12 @@
 from dataclasses import dataclass
 import logging
 
+from papercast.plugin_utils import load_plugins
+
+_installed_plugins = load_plugins("types")
+
+for name, plugin in _installed_plugins.items():
+    globals()[name] = plugin
 
 PathLike = Union[str, Path]
diff --git a/setup.py b/setup.py
index a62d206..d81198d 100644
--- a/setup.py
+++ b/setup.py
@@ -1,15 +1,15 @@
-from setuptools import setup
+from setuptools import setup, find_packages
 
 setup(
     name="papercast",
     version="0.1",
-    py_modules=["papercast"],
     install_requires=[
+        "fastapi",
+        "uvicorn",
     ],
+    packages=find_packages(),
     entry_points="""
         [console_scripts]
-        papercast-legacy=papercast.scripts.papercast_legacy:papercast_legacy
         papercast=papercast.scripts.papercast:main
-        ss=papercast.scripts.ss:ss
     """,
)
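Finally, a sketch of wiring the new `WebSocketSubscriber` into a served pipeline (not part of this patch; the URL and pipeline name are placeholders, downstream connections are omitted because the connection API does not appear in this diff, and `process_message` assumes each websocket message is a JSON object whose keys are valid `Production` fields):

```python
from papercast.pipelines import Pipeline
from papercast.server import Server
from papercast.subscribers.websocket_subscriber import WebSocketSubscriber

pipeline = Pipeline(name="default")
pipeline.add_processor(
    "events", WebSocketSubscriber(url="ws://localhost:9000/events")
)
# ... connect downstream processors and publishers here ...

# On startup the server schedules pipeline._run_in_server(), which runs
# _run_subscriber() for every registered subscriber until shutdown.
server = Server(pipelines={"default": pipeline})
server.run(port=8000)
```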