Skip to content

Commit

Permalink
Merge pull request #197 from tsugumi-sys/feature/imple-local-artifact…
Browse files Browse the repository at this point in the history
…-store

Stores: Implement local artifact store for polygonapi aggregates bars
  • Loading branch information
tsugumi-sys authored Jul 14, 2024
2 parents 4f38da7 + 94b6d34 commit f331a49
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 2 deletions.
5 changes: 5 additions & 0 deletions stocklake/polygonapi/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
import os

from stocklake.core.constants import DATA_DIR

BASE_SAVE_ARTIFACTS_DIR = os.path.join(DATA_DIR, "polygonapi")
39 changes: 39 additions & 0 deletions stocklake/polygonapi/aggregates_bars/stores.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import os
import tempfile
from typing import List, Optional

from stocklake.core.base_store import BaseStore
from stocklake.polygonapi import BASE_SAVE_ARTIFACTS_DIR
from stocklake.polygonapi.aggregates_bars import entities
from stocklake.stores.artifact.local_artifact_repo import LocalArtifactRepository
from stocklake.stores.constants import StoreType
from stocklake.stores.db.database import (
DATABASE_SESSION_TYPE,
local_session,
)
from stocklake.utils.file_utils import save_data_to_csv

SAVE_ARTIFACTS_DIR = os.path.join(BASE_SAVE_ARTIFACTS_DIR, "aggregates_bars")


class PolygonAggregatesBarsDataStore(BaseStore):
def __init__(self, sqlalchemy_session: Optional[DATABASE_SESSION_TYPE] = None):
if sqlalchemy_session is None:
sqlalchemy_session = local_session()
self.sqlalchemy_session = sqlalchemy_session

def save(
self,
store_type: StoreType,
data: List[entities.PreprocessedPolygonAggregatesBarsData],
) -> str:
if store_type == StoreType.LOCAL_ARTIFACT:
repository = LocalArtifactRepository(SAVE_ARTIFACTS_DIR)
with tempfile.TemporaryDirectory() as tempdir:
csv_file_path = os.path.join(tempdir, "aggregates_bars.csv")
save_data_to_csv([d.model_dump() for d in data], csv_file_path)
repository.save_artifact(csv_file_path)
return repository.list_artifacts()[0].path
# elif store_type == StoreType.POSTGRESQL:
else:
raise NotImplementedError()
4 changes: 2 additions & 2 deletions stocklake/polygonapi/stock_financials_vx/stores.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from stocklake.core.base_sqlalchemy_store import SQLAlchemyStore
from stocklake.core.base_store import BaseStore
from stocklake.core.constants import DATA_DIR
from stocklake.polygonapi import BASE_SAVE_ARTIFACTS_DIR
from stocklake.polygonapi.stock_financials_vx import entities
from stocklake.stores.artifact.local_artifact_repo import LocalArtifactRepository
from stocklake.stores.constants import StoreType
Expand All @@ -16,7 +16,7 @@
)
from stocklake.utils.file_utils import save_data_to_csv

SAVE_ARTIFACTS_DIR = os.path.join(DATA_DIR, "polygonapi")
SAVE_ARTIFACTS_DIR = os.path.join(BASE_SAVE_ARTIFACTS_DIR, "stock_financials_vx")


class PolygonFinancialsDataStore(BaseStore):
Expand Down
37 changes: 37 additions & 0 deletions tests/polygonapi/aggregates_bars/test_stores.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import os

import pytest

from stocklake.environment_variables import STOCKLAKE_POLYGON_API_KEY
from stocklake.polygonapi.aggregates_bars.data_loader import (
PolygonAggregatesBarsDataLoader,
)
from stocklake.polygonapi.aggregates_bars.preprocessor import (
PolygonAggregatesBarsPreprocessor,
)
from stocklake.polygonapi.aggregates_bars.stores import (
SAVE_ARTIFACTS_DIR,
PolygonAggregatesBarsDataStore,
)
from stocklake.stores.constants import StoreType
from tests.polygonapi.aggregates_bars.test_data_loader import (
MockPolygonAggregatesBarsAPIServer, # noqa: F401
)


@pytest.fixture
def polygon_aggregates_bars_data(
MockPolygonAggregatesBarsAPIServer, # noqa: F811
monkeypatch,
):
monkeypatch.setenv(STOCKLAKE_POLYGON_API_KEY.env_name, "dummy_key")
dataloader = PolygonAggregatesBarsDataLoader()
preprocessor = PolygonAggregatesBarsPreprocessor()
data = preprocessor.process(dataloader.download(["AAPL"]))
yield data


def test_polygon_aggregates_bars_store_local_artifact(polygon_aggregates_bars_data):
store = PolygonAggregatesBarsDataStore()
store.save(StoreType.LOCAL_ARTIFACT, polygon_aggregates_bars_data)
assert os.path.exists(os.path.join(SAVE_ARTIFACTS_DIR, "aggregates_bars.csv"))

0 comments on commit f331a49

Please sign in to comment.