-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #176 from tsugumi-sys/feature/imple-wiki-sp500-pip…
…eline Pipeline: Implementing wiki sp500
- Loading branch information
Showing
5 changed files
with
98 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,16 @@ | ||
import posixpath | ||
|
||
from stocklake.exceptions import StockLakeException | ||
from stocklake.stores.constants import StoreType | ||
|
||
|
||
def path_not_unique(name: str): | ||
norm = posixpath.normpath(name) | ||
return norm != name or norm == "." or norm.startswith("..") or norm.startswith("/") | ||
|
||
|
||
def validate_store_type(store_type: str | None): | ||
if store_type not in StoreType.types(): | ||
raise StockLakeException( | ||
f"Specified store type is invalid, {store_type}, valid types are {StoreType.types()}" | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
import logging | ||
from typing import Optional | ||
|
||
from stocklake.core.base_pipeline import BasePipeline | ||
from stocklake.core.stdout import PipelineStdOut | ||
from stocklake.stores.constants import StoreType | ||
from stocklake.stores.db.database import DATABASE_SESSION_TYPE, local_session | ||
from stocklake.utils.validation import validate_store_type | ||
from stocklake.wiki_sp500.data_loader import WikiSP500DataLoader | ||
from stocklake.wiki_sp500.preprocessor import WikiSP500Preprocessor | ||
from stocklake.wiki_sp500.stores import WikiSP500Store | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class WikiSP500Pipeline(BasePipeline): | ||
def __init__( | ||
self, | ||
skip_download: bool = False, | ||
store_type: StoreType = StoreType.LOCAL_ARTIFACT, | ||
sqlalchemy_session: Optional[DATABASE_SESSION_TYPE] = None, | ||
): | ||
self.skip_download = skip_download | ||
|
||
validate_store_type(store_type) | ||
self.store_type = store_type | ||
if sqlalchemy_session is None: | ||
sqlalchemy_session = local_session() | ||
|
||
self.data_loader = ( | ||
WikiSP500DataLoader(use_cache=True) | ||
if self.skip_download | ||
else WikiSP500DataLoader() | ||
) | ||
self.preprocessor = WikiSP500Preprocessor() | ||
self.store = WikiSP500Store(sqlalchemy_session) | ||
self.stdout = PipelineStdOut() | ||
|
||
def run(self): | ||
self.stdout.starting("Wikipedia S&P500") | ||
if self.skip_download: | ||
self.stdout.skip_downloading() | ||
else: | ||
self.stdout.downloading() | ||
raw_data = self.data_loader.download() | ||
data = self.preprocessor.process(raw_data) | ||
saved_location = self.store.save(self.store_type, data) | ||
self.stdout.completed(saved_location) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
import os | ||
|
||
import pytest | ||
|
||
from stocklake.exceptions import StockLakeException | ||
from stocklake.stores.constants import StoreType | ||
from stocklake.stores.db.models import WikiSP500Data | ||
from stocklake.wiki_sp500.pipeline import WikiSP500Pipeline | ||
from stocklake.wiki_sp500.stores import SAVE_ARTIFACTS_DIR | ||
|
||
|
||
def test_invalid_store_type_specified(): | ||
with pytest.raises(StockLakeException) as exc: | ||
_ = WikiSP500Pipeline(store_type="INVALID_STORE_TYPE") | ||
assert "Specified store type is invalid, INVALID_STORE_TYPE" in str(exc.value) | ||
|
||
|
||
def test_run_with_local_artifact(): | ||
pipeline = WikiSP500Pipeline(store_type=StoreType.LOCAL_ARTIFACT) | ||
pipeline.run() | ||
assert os.path.exists(os.path.join(SAVE_ARTIFACTS_DIR, "wiki_sp500.csv")) | ||
|
||
|
||
def test_run_with_postresql(SessionLocal): | ||
with SessionLocal() as session, session.begin(): | ||
res = session.query(WikiSP500Data).all() | ||
assert len(res) == 0 | ||
|
||
pipeline = WikiSP500Pipeline(store_type=StoreType.POSTGRESQL) | ||
pipeline.run() | ||
with SessionLocal() as session, session.begin(): | ||
res = session.query(WikiSP500Data).all() | ||
assert len(res) > 0 |