Skip to content

Commit

Permalink
Merge pull request #175 from tsugumi-sys/feature/imple-pipeline-wiki-…
Browse files Browse the repository at this point in the history
…spo500

stores: implemente postgresql store
  • Loading branch information
tsugumi-sys authored Jul 7, 2024
2 parents 3baf777 + 9f5284e commit 373b82c
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 2 deletions.
6 changes: 5 additions & 1 deletion stocklake/wiki_sp500/stores.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from stocklake.stores.db import models
from stocklake.stores.db.database import (
DATABASE_SESSION_TYPE,
database_url,
local_session,
)
from stocklake.utils.file_utils import save_data_to_csv
Expand All @@ -35,7 +36,10 @@ def save(
repository.save_artifact(csv_file_path)
return repository.list_artifacts()[0].path
elif store_type == StoreType.POSTGRESQL:
raise NotImplementedError()
store = WikiSP500DataSQLAlchemyStore(self.sqlalchemy_session)
store.delete()
store.create([entities.WikiSP500DataCreate(**d.model_dump()) for d in data])
return os.path.join(database_url(), models.WikiSP500Data.__tablename__)
else:
raise NotImplementedError()

Expand Down
15 changes: 14 additions & 1 deletion tests/wiki_sp500/test_stores.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import os
import os # noqa: I001

import pytest

from stocklake.stores.db.database import (
database_url,
)
from stocklake.stores.constants import StoreType
from stocklake.stores.db import models
from stocklake.wiki_sp500 import entities
Expand All @@ -23,6 +26,16 @@ def test_save_local_artifact_repo(wiki_sp500_data):
assert os.path.exists(saved_path)


def test_save_postgresql(wiki_sp500_data, SessionLocal):
store = WikiSP500Store(SessionLocal)
saved_path = store.save(StoreType.POSTGRESQL, wiki_sp500_data)
assert saved_path == os.path.join(
database_url(), models.WikiSP500Data.__tablename__
)
with SessionLocal() as session, session.begin():
assert len(session.query(models.WikiSP500Data).all()) == len(wiki_sp500_data)


def test_WikiSP500DataSQLAlchemyStore_delete(wiki_sp500_data, SessionLocal):
store = WikiSP500DataSQLAlchemyStore(SessionLocal)
store.create(
Expand Down

0 comments on commit 373b82c

Please sign in to comment.