diff --git a/stocklake/nasdaqapi/pipeline.py b/stocklake/nasdaqapi/pipeline.py index 77eb6fe..2bd8f3b 100644 --- a/stocklake/nasdaqapi/pipeline.py +++ b/stocklake/nasdaqapi/pipeline.py @@ -1,8 +1,6 @@ import logging from typing import Optional -from sqlalchemy import orm - from stocklake.core.base_data_loader import BaseDataLoader from stocklake.core.base_pipeline import BasePipeline from stocklake.core.base_preprocessor import BasePreprocessor @@ -18,7 +16,7 @@ ) from stocklake.nasdaqapi.stores import NASDAQDataStore from stocklake.stores.constants import StoreType -from stocklake.stores.db.database import local_session +from stocklake.stores.db.database import DATABASE_SESSION_TYPE, local_session logger = logging.getLogger(__name__) @@ -29,9 +27,8 @@ def __init__( skip_download: bool = False, exchange: Optional[Exchange] = None, store_type: StoreType = StoreType.LOCAL_ARTIFACT, - sqlalchemy_session: Optional[orm.sessionmaker[orm.session.Session]] = None, + sqlalchemy_session: Optional[DATABASE_SESSION_TYPE] = None, ): - print(exchange) if exchange is not None and exchange not in Exchange.exchanges(): raise StockLakeException( f"Specified exchange is invalid, but got {exchange}. The valid exchanges are {Exchange.exchanges()}" diff --git a/stocklake/nasdaqapi/stores.py b/stocklake/nasdaqapi/stores.py index 5b3240f..0b91e70 100644 --- a/stocklake/nasdaqapi/stores.py +++ b/stocklake/nasdaqapi/stores.py @@ -2,8 +2,6 @@ import tempfile from typing import List, Optional -from sqlalchemy import orm - from stocklake.core.base_sqlalchemy_store import SQLAlchemyStore from stocklake.core.base_store import BaseStore from stocklake.core.constants import DATA_DIR @@ -14,6 +12,7 @@ from stocklake.stores.constants import StoreType from stocklake.stores.db import models, schemas from stocklake.stores.db.database import ( + DATABASE_SESSION_TYPE, local_session, safe_database_url_from_sessionmaker, ) @@ -23,9 +22,7 @@ class NASDAQDataStore(BaseStore): - def __init__( - self, sqlalchemy_session: Optional[orm.sessionmaker[orm.session.Session]] = None - ): + def __init__(self, sqlalchemy_session: Optional[DATABASE_SESSION_TYPE] = None): if sqlalchemy_session is None: sqlalchemy_session = local_session() self.sqlalchemy_session = sqlalchemy_session @@ -54,13 +51,11 @@ def save( models.NasdaqApiData.__tablename__, ) else: - raise NotImplementedError + raise NotImplementedError() class NasdaqApiSQLAlchemyStore(SQLAlchemyStore): - def __init__( - self, exchange: Exchange, session: orm.sessionmaker[orm.session.Session] - ): + def __init__(self, exchange: Exchange, session: DATABASE_SESSION_TYPE): self.exchange = exchange self.session = session diff --git a/stocklake/polygonapi/cli.py b/stocklake/polygonapi/cli.py index e9f6961..df9ca4f 100644 --- a/stocklake/polygonapi/cli.py +++ b/stocklake/polygonapi/cli.py @@ -46,5 +46,5 @@ def polygonapi( symbols.split(","), skip_download, store_type ) else: - raise NotImplementedError + raise NotImplementedError() pipeline.run() diff --git a/stocklake/polygonapi/pipeline.py b/stocklake/polygonapi/pipeline.py index 7557af4..cc657fe 100644 --- a/stocklake/polygonapi/pipeline.py +++ b/stocklake/polygonapi/pipeline.py @@ -1,8 +1,6 @@ import logging from typing import List, Optional -from sqlalchemy import orm - from stocklake.core.base_data_loader import BaseDataLoader from stocklake.core.base_pipeline import BasePipeline from stocklake.core.base_preprocessor import BasePreprocessor @@ -17,7 +15,7 @@ ) from stocklake.polygonapi.stores import PolygonFinancialsDataStore from stocklake.stores.constants import StoreType -from stocklake.stores.db.database import local_session +from stocklake.stores.db.database import DATABASE_SESSION_TYPE, local_session logger = logging.getLogger(__name__) @@ -28,7 +26,7 @@ def __init__( symbols: List[str], skip_download: bool = False, store_type: StoreType = StoreType.LOCAL_ARTIFACT, - sqlalchemy_session: Optional[orm.sessionmaker[orm.session.Session]] = None, + sqlalchemy_session: Optional[DATABASE_SESSION_TYPE] = None, ): self.symbols = symbols self.skip_download = skip_download diff --git a/stocklake/polygonapi/stores.py b/stocklake/polygonapi/stores.py index 6753220..8902b40 100644 --- a/stocklake/polygonapi/stores.py +++ b/stocklake/polygonapi/stores.py @@ -2,8 +2,6 @@ import tempfile from typing import List, Optional -from sqlalchemy import orm - from stocklake.core.base_sqlalchemy_store import SQLAlchemyStore from stocklake.core.base_store import BaseStore from stocklake.core.constants import DATA_DIR @@ -11,6 +9,7 @@ from stocklake.stores.constants import StoreType from stocklake.stores.db import models, schemas from stocklake.stores.db.database import ( + DATABASE_SESSION_TYPE, local_session, safe_database_url_from_sessionmaker, ) @@ -20,9 +19,7 @@ class PolygonFinancialsDataStore(BaseStore): - def __init__( - self, sqlalchemy_session: Optional[orm.sessionmaker[orm.session.Session]] = None - ): + def __init__(self, sqlalchemy_session: Optional[DATABASE_SESSION_TYPE] = None): if sqlalchemy_session is None: sqlalchemy_session = local_session() self.sqlalchemy_session = sqlalchemy_session @@ -50,11 +47,11 @@ def save( models.PolygonFinancialsData.__tablename__, ) else: - raise NotImplementedError + raise NotImplementedError() class PolygonFinancialsDataSQLAlchemyStore(SQLAlchemyStore): - def __init__(self, session: orm.sessionmaker[orm.session.Session]): + def __init__(self, session: DATABASE_SESSION_TYPE): self.session = session def create(self, data: List[schemas.PolygonFinancialsDataCreate]): diff --git a/stocklake/stores/db/database.py b/stocklake/stores/db/database.py index cfe48c3..5c8f302 100644 --- a/stocklake/stores/db/database.py +++ b/stocklake/stores/db/database.py @@ -10,6 +10,8 @@ Base = orm.declarative_base() +DATABASE_SESSION_TYPE = orm.sessionmaker[orm.session.Session] + def database_url(): """Dynamically change database url based on environment variable `__STOCKLAKE_ENVIRONMENT`""" diff --git a/stocklake/wiki_sp500/stores.py b/stocklake/wiki_sp500/stores.py new file mode 100644 index 0000000..dc4a251 --- /dev/null +++ b/stocklake/wiki_sp500/stores.py @@ -0,0 +1,36 @@ +import os +import tempfile +from typing import List, Optional + +from stocklake.core.base_store import BaseStore +from stocklake.core.constants import DATA_DIR +from stocklake.stores.artifact.local_artifact_repo import LocalArtifactRepository +from stocklake.stores.constants import StoreType +from stocklake.stores.db.database import ( + DATABASE_SESSION_TYPE, + local_session, +) +from stocklake.utils.file_utils import save_data_to_csv +from stocklake.wiki_sp500.entities import PreprocessedWikiSp500Data + +SAVE_ARTIFACTS_DIR = os.path.join(DATA_DIR, "wiki_sp500") + + +class WikiSP500Stores(BaseStore): + def __init__(self, sqlalchemy_session: Optional[DATABASE_SESSION_TYPE] = None): + if sqlalchemy_session is None: + sqlalchemy_session = local_session() + self.sqlalchemy_session = sqlalchemy_session + + def save(self, store_type: StoreType, data: List[PreprocessedWikiSp500Data]) -> str: + if store_type == StoreType.LOCAL_ARTIFACT: + repository = LocalArtifactRepository(SAVE_ARTIFACTS_DIR) + with tempfile.TemporaryDirectory() as tmpdir: + csv_file_path = os.path.join(tmpdir, "wiki_sp500.csv") + save_data_to_csv([d.model_dump() for d in data], csv_file_path) + repository.save_artifact(csv_file_path) + return repository.list_artifacts()[0].path + elif store_type == StoreType.POSTGRESQL: + raise NotImplementedError() + else: + raise NotImplementedError() diff --git a/tests/wiki_sp500/test_stores.py b/tests/wiki_sp500/test_stores.py new file mode 100644 index 0000000..43800fc --- /dev/null +++ b/tests/wiki_sp500/test_stores.py @@ -0,0 +1,21 @@ +import os + +import pytest + +from stocklake.stores.constants import StoreType +from stocklake.wiki_sp500.data_loader import WikiSP500DataLoader +from stocklake.wiki_sp500.preprocessor import WikiSP500Preprocessor +from stocklake.wiki_sp500.stores import WikiSP500Stores + + +@pytest.fixture +def wiki_sp500_data(): + data_loader = WikiSP500DataLoader() + preprocessor = WikiSP500Preprocessor() + yield preprocessor.process(data_loader.download()) + + +def test_save_local_artifact_repo(wiki_sp500_data): + store = WikiSP500Stores() + saved_path = store.save(StoreType.LOCAL_ARTIFACT, wiki_sp500_data) + assert os.path.exists(saved_path)