diff --git a/.python-version b/.python-version
index d4b278f..455808f 100644
--- a/.python-version
+++ b/.python-version
@@ -1 +1 @@
-3.11.7
+3.12.4
diff --git a/stocklake/polygonapi/pipeline.py b/stocklake/polygonapi/pipeline.py
new file mode 100644
index 0000000..a27cac0
--- /dev/null
+++ b/stocklake/polygonapi/pipeline.py
@@ -0,0 +1,78 @@
+import logging
+from typing import List
+
+from sqlalchemy import orm
+
+from stocklake.core.base_data_loader import BaseDataLoader
+from stocklake.core.base_pipeline import BasePipeline
+from stocklake.core.base_preprocessor import BasePreprocessor
+from stocklake.core.base_store import BaseStore
+from stocklake.core.stdout import PipelineStdOut
+from stocklake.exceptions import StockLoaderException
+from stocklake.polygonapi.data_loader import (
+    PolygonFinancialsDataLoader,
+)
+from stocklake.polygonapi.preprocessor import (
+    PolygonFinancialsDataPreprocessor,
+)
+from stocklake.polygonapi.stores import PolygonFinancialsDataStore
+from stocklake.stores.constants import StoreType
+from stocklake.stores.db.database import LocalSession  # noqa: E402
+
+logger = logging.getLogger(__name__)
+
+
+class PolygonFinancialsDataPipeline(BasePipeline):
+    """Download, preprocess and store Polygon financials data per symbol."""
+
+    def __init__(
+        self,
+        symbols: List[str],
+        skip_download: bool = False,
+        store_type: StoreType = StoreType.LOCAL_ARTIFACT,
+        sqlalchemy_session: orm.sessionmaker[orm.session.Session] = LocalSession,
+    ):
+        """Validate ``store_type`` and build the loader, preprocessor and store.
+
+        Raises:
+            StockLoaderException: If ``store_type`` is not in ``StoreType.types()``.
+        """
+        self.symbols = symbols
+        self.skip_download = skip_download
+
+        if store_type not in StoreType.types():
+            raise StockLoaderException(
+                f"Specified store type is invalid, {store_type}, valid types are {StoreType.types()}"
+            )
+        self.store_type = store_type
+        self.data_loader = PolygonFinancialsDataLoader()
+        self.preprocessor = PolygonFinancialsDataPreprocessor()
+        self.store = PolygonFinancialsDataStore(sqlalchemy_session)
+        self.stdout = PipelineStdOut()
+
+    def run(self):
+        """Run the pipeline once for every symbol in ``self.symbols``."""
+        for symbol in self.symbols:
+            self._run(symbol, self.data_loader, self.preprocessor, self.store)
+
+    def _run(
+        self,
+        symbol: str,
+        data_loader: BaseDataLoader,
+        preprocessor: BasePreprocessor,
+        store: BaseStore,
+    ):
+        """Download, preprocess and save data, reporting progress for ``symbol``."""
+        self.stdout.starting(f"Polygon Financials API of {symbol}")
+        if not self.skip_download:
+            self.stdout.downloading()
+            # NOTE(review): this requests every symbol on each per-symbol call;
+            # confirm whether ``[symbol]`` was intended here.
+            raw_data = data_loader.download(self.symbols)
+        else:
+            self.stdout.skip_downloading()
+            # TODO: fetch from cached file
+            return
+        data = preprocessor.process(raw_data)
+        store.save(self.store_type, data)
+        self.stdout.completed()
diff --git a/tests/polygonapi/test_pipeline.py b/tests/polygonapi/test_pipeline.py
new file mode 100644
index 0000000..07389fc
--- /dev/null
+++ b/tests/polygonapi/test_pipeline.py
@@ -0,0 +1,54 @@
+import os
+
+import pytest
+
+from stocklake.exceptions import StockLoaderException
+from stocklake.polygonapi.pipeline import PolygonFinancialsDataPipeline
+from stocklake.polygonapi.stores import SAVE_ARTIFACTS_DIR
+from stocklake.stores.constants import StoreType
+from stocklake.stores.db.models import PolygonFinancialsData
+from tests.polygonapi.test_data_loader import MockPolygonAPIServer  # noqa: F401
+from tests.stores.db.utils import SessionLocal  # noqa: F401
+
+
+def test_invalid_store_type_specified():
+    """An unknown store type raises StockLoaderException with the expected message."""
+    with pytest.raises(StockLoaderException) as exc:
+        _ = PolygonFinancialsDataPipeline(
+            symbols=["MSFT"], store_type="INVALID_STORE_TYPE"
+        )
+    assert (
+        str(exc.value)
+        == "Specified store type is invalid, INVALID_STORE_TYPE, valid types are ['local_artifact', 'postgresql']"
+    )
+
+
+def test_run_with_local_artifact(MockPolygonAPIServer, monkeypatch):  # noqa: F811
+    """A local-artifact run leaves financials_data.csv in SAVE_ARTIFACTS_DIR."""
+    monkeypatch.setenv("STOCKLAKE_POLYGON_API_KEY", "dummy_key")
+    pipeline = PolygonFinancialsDataPipeline(
+        symbols=["MSFT"],
+        skip_download=False,
+        store_type=StoreType.LOCAL_ARTIFACT,
+    )
+    pipeline.run()
+    assert os.path.exists(os.path.join(SAVE_ARTIFACTS_DIR, "financials_data.csv"))
+
+
+def test_run_with_postgresql(
+    MockPolygonAPIServer,  # noqa: F811
+    monkeypatch,
+    SessionLocal,  # noqa: F811
+):
+    """A PostgreSQL run leaves at least one PolygonFinancialsData row queryable."""
+    monkeypatch.setenv("STOCKLAKE_POLYGON_API_KEY", "dummy_key")
+    pipeline = PolygonFinancialsDataPipeline(
+        symbols=["MSFT"],
+        skip_download=False,
+        store_type=StoreType.POSTGRESQL,
+        sqlalchemy_session=SessionLocal,
+    )
+    pipeline.run()
+    with SessionLocal() as session, session.begin():
+        res = session.query(PolygonFinancialsData).all()
+        assert len(res) > 0