Skip to content

Commit

Permalink
Store: Implement local artifact store for wiki sp500
Browse files Browse the repository at this point in the history
  • Loading branch information
tsugumi-sys committed Jul 7, 2024
1 parent 9b2261d commit 1202673
Show file tree
Hide file tree
Showing 8 changed files with 72 additions and 26 deletions.
7 changes: 2 additions & 5 deletions stocklake/nasdaqapi/pipeline.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import logging
from typing import Optional

from sqlalchemy import orm

from stocklake.core.base_data_loader import BaseDataLoader
from stocklake.core.base_pipeline import BasePipeline
from stocklake.core.base_preprocessor import BasePreprocessor
Expand All @@ -18,7 +16,7 @@
)
from stocklake.nasdaqapi.stores import NASDAQDataStore
from stocklake.stores.constants import StoreType
from stocklake.stores.db.database import local_session
from stocklake.stores.db.database import DATABASE_SESSION_TYPE, local_session

logger = logging.getLogger(__name__)

Expand All @@ -29,9 +27,8 @@ def __init__(
skip_download: bool = False,
exchange: Optional[Exchange] = None,
store_type: StoreType = StoreType.LOCAL_ARTIFACT,
sqlalchemy_session: Optional[orm.sessionmaker[orm.session.Session]] = None,
sqlalchemy_session: Optional[DATABASE_SESSION_TYPE] = None,
):
print(exchange)
if exchange is not None and exchange not in Exchange.exchanges():
raise StockLakeException(
f"Specified exchange is invalid, but got {exchange}. The valid exchanges are {Exchange.exchanges()}"
Expand Down
13 changes: 4 additions & 9 deletions stocklake/nasdaqapi/stores.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
import tempfile
from typing import List, Optional

from sqlalchemy import orm

from stocklake.core.base_sqlalchemy_store import SQLAlchemyStore
from stocklake.core.base_store import BaseStore
from stocklake.core.constants import DATA_DIR
Expand All @@ -14,6 +12,7 @@
from stocklake.stores.constants import StoreType
from stocklake.stores.db import models, schemas
from stocklake.stores.db.database import (
DATABASE_SESSION_TYPE,
local_session,
safe_database_url_from_sessionmaker,
)
Expand All @@ -23,9 +22,7 @@


class NASDAQDataStore(BaseStore):
def __init__(
self, sqlalchemy_session: Optional[orm.sessionmaker[orm.session.Session]] = None
):
def __init__(self, sqlalchemy_session: Optional[DATABASE_SESSION_TYPE] = None):
if sqlalchemy_session is None:
sqlalchemy_session = local_session()
self.sqlalchemy_session = sqlalchemy_session
Expand Down Expand Up @@ -54,13 +51,11 @@ def save(
models.NasdaqApiData.__tablename__,
)
else:
raise NotImplementedError
raise NotImplementedError()


class NasdaqApiSQLAlchemyStore(SQLAlchemyStore):
def __init__(
self, exchange: Exchange, session: orm.sessionmaker[orm.session.Session]
):
def __init__(self, exchange: Exchange, session: DATABASE_SESSION_TYPE):
self.exchange = exchange
self.session = session

Expand Down
2 changes: 1 addition & 1 deletion stocklake/polygonapi/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,5 +46,5 @@ def polygonapi(
symbols.split(","), skip_download, store_type
)
else:
raise NotImplementedError
raise NotImplementedError()
pipeline.run()
6 changes: 2 additions & 4 deletions stocklake/polygonapi/pipeline.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import logging
from typing import List, Optional

from sqlalchemy import orm

from stocklake.core.base_data_loader import BaseDataLoader
from stocklake.core.base_pipeline import BasePipeline
from stocklake.core.base_preprocessor import BasePreprocessor
Expand All @@ -17,7 +15,7 @@
)
from stocklake.polygonapi.stores import PolygonFinancialsDataStore
from stocklake.stores.constants import StoreType
from stocklake.stores.db.database import local_session
from stocklake.stores.db.database import DATABASE_SESSION_TYPE, local_session

logger = logging.getLogger(__name__)

Expand All @@ -28,7 +26,7 @@ def __init__(
symbols: List[str],
skip_download: bool = False,
store_type: StoreType = StoreType.LOCAL_ARTIFACT,
sqlalchemy_session: Optional[orm.sessionmaker[orm.session.Session]] = None,
sqlalchemy_session: Optional[DATABASE_SESSION_TYPE] = None,
):
self.symbols = symbols
self.skip_download = skip_download
Expand Down
11 changes: 4 additions & 7 deletions stocklake/polygonapi/stores.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,14 @@
import tempfile
from typing import List, Optional

from sqlalchemy import orm

from stocklake.core.base_sqlalchemy_store import SQLAlchemyStore
from stocklake.core.base_store import BaseStore
from stocklake.core.constants import DATA_DIR
from stocklake.stores.artifact.local_artifact_repo import LocalArtifactRepository
from stocklake.stores.constants import StoreType
from stocklake.stores.db import models, schemas
from stocklake.stores.db.database import (
DATABASE_SESSION_TYPE,
local_session,
safe_database_url_from_sessionmaker,
)
Expand All @@ -20,9 +19,7 @@


class PolygonFinancialsDataStore(BaseStore):
def __init__(
self, sqlalchemy_session: Optional[orm.sessionmaker[orm.session.Session]] = None
):
def __init__(self, sqlalchemy_session: Optional[DATABASE_SESSION_TYPE] = None):
if sqlalchemy_session is None:
sqlalchemy_session = local_session()
self.sqlalchemy_session = sqlalchemy_session
Expand Down Expand Up @@ -50,11 +47,11 @@ def save(
models.PolygonFinancialsData.__tablename__,
)
else:
raise NotImplementedError
raise NotImplementedError()


class PolygonFinancialsDataSQLAlchemyStore(SQLAlchemyStore):
def __init__(self, session: orm.sessionmaker[orm.session.Session]):
def __init__(self, session: DATABASE_SESSION_TYPE):
self.session = session

def create(self, data: List[schemas.PolygonFinancialsDataCreate]):
Expand Down
2 changes: 2 additions & 0 deletions stocklake/stores/db/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

Base = orm.declarative_base()

DATABASE_SESSION_TYPE = orm.sessionmaker[orm.session.Session]


def database_url():
"""Dynamically change database url based on environment variable `__STOCKLAKE_ENVIRONMENT`"""
Expand Down
36 changes: 36 additions & 0 deletions stocklake/wiki_sp500/stores.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import os
import tempfile
from typing import List, Optional

from stocklake.core.base_store import BaseStore
from stocklake.core.constants import DATA_DIR
from stocklake.stores.artifact.local_artifact_repo import LocalArtifactRepository
from stocklake.stores.constants import StoreType
from stocklake.stores.db.database import (
DATABASE_SESSION_TYPE,
local_session,
)
from stocklake.utils.file_utils import save_data_to_csv
from stocklake.wiki_sp500.entities import PreprocessedWikiSp500Data

SAVE_ARTIFACTS_DIR = os.path.join(DATA_DIR, "wiki_sp500")


class WikiSP500Stores(BaseStore):
def __init__(self, sqlalchemy_session: Optional[DATABASE_SESSION_TYPE] = None):
if sqlalchemy_session is None:
sqlalchemy_session = local_session()
self.sqlalchemy_session = sqlalchemy_session

def save(self, store_type: StoreType, data: List[PreprocessedWikiSp500Data]) -> str:
if store_type == StoreType.LOCAL_ARTIFACT:
repository = LocalArtifactRepository(SAVE_ARTIFACTS_DIR)
with tempfile.TemporaryDirectory() as tmpdir:
csv_file_path = os.path.join(tmpdir, "wiki_sp500.csv")
save_data_to_csv([d.model_dump() for d in data], csv_file_path)
repository.save_artifact(csv_file_path)
return repository.list_artifacts()[0].path
elif store_type == StoreType.POSTGRESQL:
raise NotImplementedError()
else:
raise NotImplementedError()
21 changes: 21 additions & 0 deletions tests/wiki_sp500/test_stores.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import os

import pytest

from stocklake.stores.constants import StoreType
from stocklake.wiki_sp500.data_loader import WikiSP500DataLoader
from stocklake.wiki_sp500.preprocessor import WikiSP500Preprocessor
from stocklake.wiki_sp500.stores import WikiSP500Stores


@pytest.fixture
def wiki_sp500_data():
data_loader = WikiSP500DataLoader()
preprocessor = WikiSP500Preprocessor()
yield preprocessor.process(data_loader.download())


def test_save_local_artifact_repo(wiki_sp500_data):
store = WikiSP500Stores()
saved_path = store.save(StoreType.LOCAL_ARTIFACT, wiki_sp500_data)
assert os.path.exists(saved_path)

0 comments on commit 1202673

Please sign in to comment.