From 7e03a83513f24057a1b90a60974ff5cdbc53902a Mon Sep 17 00:00:00 2001 From: tsugumi-sys Date: Mon, 29 Apr 2024 13:05:33 +0900 Subject: [PATCH] Test: Adding unittest nasdaq pipelines --- stocklake/core/base_data_loader.py | 2 +- stocklake/core/base_pipeline.py | 2 +- stocklake/nasdaqapi/constants.py | 2 +- stocklake/nasdaqapi/data_loader.py | 14 ++--- stocklake/nasdaqapi/pipeline.py | 83 ++++++++++++----------------- stocklake/nasdaqapi/preprocessor.py | 6 +-- stocklake/polygonapi/data_loader.py | 4 +- tests/nasdaqapi/test_pipeline.py | 35 ++++++++++++ 8 files changed, 85 insertions(+), 63 deletions(-) diff --git a/stocklake/core/base_data_loader.py b/stocklake/core/base_data_loader.py index e5cbb05..7c2f061 100644 --- a/stocklake/core/base_data_loader.py +++ b/stocklake/core/base_data_loader.py @@ -3,7 +3,7 @@ from stocklake.stores.artifact.base import ArtifactRepository -class DataLoader(ABC): +class BaseDataLoader(ABC): def __init__(self, artifact_repo: ArtifactRepository): self._artifact_repo = artifact_repo diff --git a/stocklake/core/base_pipeline.py b/stocklake/core/base_pipeline.py index 5447bb0..d2e234d 100644 --- a/stocklake/core/base_pipeline.py +++ b/stocklake/core/base_pipeline.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod -class Pipeline(ABC): +class BasePipeline(ABC): @abstractmethod def run(self, *args, **kwargs): pass diff --git a/stocklake/nasdaqapi/constants.py b/stocklake/nasdaqapi/constants.py index ca56c39..bfe7696 100644 --- a/stocklake/nasdaqapi/constants.py +++ b/stocklake/nasdaqapi/constants.py @@ -8,4 +8,4 @@ class Exchange(str, Enum): @classmethod def exchanges(self): - return sorted([e for e in Exchange.__members__]) + return sorted([e.value for e in Exchange]) diff --git a/stocklake/nasdaqapi/data_loader.py b/stocklake/nasdaqapi/data_loader.py index f80a7e5..015f947 100644 --- a/stocklake/nasdaqapi/data_loader.py +++ b/stocklake/nasdaqapi/data_loader.py @@ -7,7 +7,7 @@ import requests -from stocklake.core.base_data_loader import DataLoader +from stocklake.core.base_data_loader import BaseDataLoader from stocklake.nasdaqapi.constants import Exchange from stocklake.stores.artifact.base import ArtifactRepository @@ -44,11 +44,11 @@ def symbols_api_endpoint(exchange_name: Exchange) -> str: return f"https://api.nasdaq.com/api/screener/stocks?tableonly=true&limit=25&offset=0&exchange={exchange_name}&download=true" -class NASDAQSymbolsDataLoader(DataLoader): +class NASDAQSymbolsDataLoader(BaseDataLoader): def __init__( self, artifact_repo: ArtifactRepository, - artifact_filename_json: str = "data.json", + artifact_filename_json: str = "raw_nasdaq_data.json", ): super().__init__(artifact_repo) self.artifact_filename_json = artifact_filename_json @@ -83,11 +83,11 @@ def download(self): self.artifact_repo.save_artifact(local_file) -class NYSESymbolsDataLoader(DataLoader): +class NYSESymbolsDataLoader(BaseDataLoader): def __init__( self, artifact_repo: ArtifactRepository, - artifact_filename_json: str = "data.json", + artifact_filename_json: str = "raw_nyse_data.json", ): super().__init__(artifact_repo) self.artifact_filename_json = artifact_filename_json @@ -121,11 +121,11 @@ def download(self): self.artifact_repo.save_artifact(local_file) -class AMEXSymbolsDataLoader(DataLoader): +class AMEXSymbolsDataLoader(BaseDataLoader): def __init__( self, artifact_repo: ArtifactRepository, - artifact_filename_json: str = "data.json", + artifact_filename_json: str = "raw_amex_data.json", ): super().__init__(artifact_repo) self.artifact_filename_json = artifact_filename_json diff --git a/stocklake/nasdaqapi/pipeline.py b/stocklake/nasdaqapi/pipeline.py index 5c2b424..f96aedf 100644 --- a/stocklake/nasdaqapi/pipeline.py +++ b/stocklake/nasdaqapi/pipeline.py @@ -1,8 +1,11 @@ import logging import os -from typing import Optional +from typing import Any, Optional -from stocklake.core.base_pipeline import Pipeline +# from stocklake.core.base_data_loader import BaseDataLoader +from stocklake.core.base_pipeline import BasePipeline + +# from stocklake.core.base_preprocessor import BasePreprocessor from stocklake.core.constants import DATA_DIR from stocklake.core.stdout import PrettyStdoutPrint from stocklake.exceptions import StockLoaderException @@ -17,18 +20,20 @@ NASDAQSymbolsPreprocessor, NYSESymbolsPreprocessor, ) +from stocklake.stores.artifact.base import ArtifactRepository from stocklake.stores.artifact.local_artifact_repo import LocalArtifactRepository from stocklake.stores.constants import StoreType logger = logging.getLogger(__name__) -class NASDAQSymbolsPipeline(Pipeline): +class NASDAQSymbolsPipeline(BasePipeline): def __init__( self, skip_download: bool = False, exchange: Optional[Exchange] = None, store_type: Optional[StoreType] = StoreType.LOCAL_ARTIFACT, + data_dir: str = DATA_DIR, ): self.exchange = exchange self.skip_download = skip_download @@ -38,67 +43,49 @@ def __init__( f"Specified store type is invalid, {store_type}, valid types are {StoreType.types()}" ) self.store_type = store_type - self.save_dir = os.path.join(DATA_DIR, "nasdaqapi") + self._save_dir = os.path.join(data_dir, "nasdaqapi") self.stdout = PrettyStdoutPrint() + @property + def save_dir_path(self) -> str: + return self._save_dir + def run(self): logger.info("{} NASDAQ pipeline starts {}".format("=" * 30, "=" * 30)) if self.exchange == Exchange.NASDAQ or self.exchange is None: self.stdout.step_start(f"{Exchange.NASDAQ} symbols with nasdapapi") exchange_repo = LocalArtifactRepository( - os.path.join(self.save_dir, Exchange.NASDAQ) - ) - downloader = NASDAQSymbolsDataLoader(exchange_repo, "raw_data.json") - if not self.skip_download: - self.stdout.normal_message("- Downloading ...") - downloader.download() - else: - self.stdout.warning_message("- Skip Downloading") - - preprocessor = NASDAQSymbolsPreprocessor( - exchange_repo, downloader.artifact_path, "processed.csv" - ) - preprocessor.process() - self.stdout.success_message( - f"- Completed🐳. The artifact is saved to {preprocessor.artifact_path}" + os.path.join(self._save_dir, Exchange.NASDAQ) ) + self._run(exchange_repo, NASDAQSymbolsDataLoader, NASDAQSymbolsPreprocessor) if self.exchange == Exchange.NYSE or self.exchange is None: self.stdout.step_start(f"{Exchange.NYSE} symbols with nasdapapi") exchange_repo = LocalArtifactRepository( - os.path.join(self.save_dir, Exchange.NASDAQ) - ) - downloader = NYSESymbolsDataLoader(exchange_repo, "raw_data.json") - if not self.skip_download: - self.stdout.normal_message("- Downloading ...") - downloader.download() - else: - self.stdout.warning_message("- Skip Downloading") - - preprocessor = NYSESymbolsPreprocessor( - exchange_repo, downloader.artifact_path, "processed.csv" - ) - preprocessor.process() - self.stdout.success_message( - f"- Completed🐳. The artifact is saved to {preprocessor.artifact_path}" + os.path.join(self._save_dir, Exchange.NASDAQ) ) + self._run(exchange_repo, NYSESymbolsDataLoader, NYSESymbolsPreprocessor) if self.exchange == Exchange.AMEX or self.exchange is None: self.stdout.step_start(f"{Exchange.AMEX} symbols with nasdapapi") exchange_repo = LocalArtifactRepository( - os.path.join(self.save_dir, Exchange.NASDAQ) + os.path.join(self._save_dir, Exchange.NASDAQ) ) - downloader = AMEXSymbolsDataLoader(exchange_repo, "raw_data.json") - if not self.skip_download: - self.stdout.normal_message("- Downloading ...") - downloader.download() - else: - self.stdout.warning_message("- Skip Downloading") + self._run(exchange_repo, AMEXSymbolsDataLoader, AMEXSymbolsPreprocessor) - preprocessor = AMEXSymbolsPreprocessor( - exchange_repo, downloader.artifact_path, "processed.csv" - ) - preprocessor.process() - self.stdout.success_message( - f"- Completed🐳. The artifact is saved to {preprocessor.artifact_path}" - ) + def _run( + self, + repository: ArtifactRepository, + data_loader: Any, + preprocessor: Any, + ): + _data_loader = data_loader(repository) + if not self.skip_download: + self.stdout.normal_message("- Downloading ...") + _data_loader.download() + else: + self.stdout.warning_message("- Skip Downloading") + preprocessor(repository, _data_loader.artifact_path).process() + self.stdout.success_message( + f"- Completed🐳. The artifact is saved to {preprocessor.artifact_path}" + ) diff --git a/stocklake/nasdaqapi/preprocessor.py b/stocklake/nasdaqapi/preprocessor.py index 394aa88..29c7a92 100644 --- a/stocklake/nasdaqapi/preprocessor.py +++ b/stocklake/nasdaqapi/preprocessor.py @@ -13,7 +13,7 @@ def __init__( self, artifact_repo: ArtifactRepository, source_artifact_path_json: str, - artifact_filename_csv: str = "data.csv", + artifact_filename_csv: str = "nasdaq_data.csv", ): super().__init__(artifact_repo) self.source_artifact_path_json = source_artifact_path_json @@ -45,7 +45,7 @@ def __init__( self, artifact_repo: ArtifactRepository, source_artifact_path_json: str, - artifact_filename_csv: str = "data.csv", + artifact_filename_csv: str = "nyse_data.csv", ): super().__init__(artifact_repo) self.source_artifact_path_json = source_artifact_path_json @@ -77,7 +77,7 @@ def __init__( self, artifact_repo: ArtifactRepository, source_artifact_path_json: str, - artifact_filename_csv: str = "data.csv", + artifact_filename_csv: str = "amex_data.csv", ): super().__init__(artifact_repo) self.source_artifact_path_json = source_artifact_path_json diff --git a/stocklake/polygonapi/data_loader.py b/stocklake/polygonapi/data_loader.py index c5aa6da..9f50f07 100644 --- a/stocklake/polygonapi/data_loader.py +++ b/stocklake/polygonapi/data_loader.py @@ -9,13 +9,13 @@ from polygon import RESTClient from polygon.rest.models.financials import Financials -from stocklake.data_loaders.base import DataLoader +from stocklake.data_loaders.base import BaseDataLoader from stocklake.stores.artifact.base import ArtifactRepository logger = logging.getLogger(__name__) -class PolygonFinancialsDataLoader(DataLoader): +class PolygonFinancialsDataLoader(BaseDataLoader): def __init__( self, artifact_repo: ArtifactRepository, diff --git a/tests/nasdaqapi/test_pipeline.py b/tests/nasdaqapi/test_pipeline.py index 6303357..73c3379 100644 --- a/tests/nasdaqapi/test_pipeline.py +++ b/tests/nasdaqapi/test_pipeline.py @@ -1,7 +1,13 @@ +import os +from unittest import mock + import pytest from stocklake.exceptions import StockLoaderException +from stocklake.nasdaqapi.constants import Exchange from stocklake.nasdaqapi.pipeline import NASDAQSymbolsPipeline +from stocklake.stores.constants import StoreType +from tests.nasdaqapi.test_data_loader import mock_requests_get def test_invalid_store_type_specified(): @@ -11,3 +17,32 @@ def test_invalid_store_type_specified(): str(exc.value) == "Specified store type is invalid, INVALID_STORE_TYPE, valid types are ['local_artifact', 'postgresql']" ) + + +@mock.patch("requests.get", side_effect=mock_requests_get) +@pytest.mark.parametrize("exchange_name", Exchange.exchanges()) +def test_run_each_symbols_with_local_artifact(mock_get, exchange_name, tmpdir): + pipeline = NASDAQSymbolsPipeline( + skip_download=False, + exchange=exchange_name, + store_type=StoreType.LOCAL_ARTIFACT, + data_dir=tmpdir, + ) + pipeline.run() + data_dir = os.path.join(pipeline.save_dir_path, "nasdaq") + assert os.path.exists(os.path.join(data_dir, f"{exchange_name}_data.csv")) + assert os.path.exists(os.path.join(data_dir, f"raw_{exchange_name}_data.json")) + + +@mock.patch("requests.get", side_effect=mock_requests_get) +def test_run_with_local_artifact(mock_get, tmpdir): + pipeline = NASDAQSymbolsPipeline( + skip_download=False, + store_type=StoreType.LOCAL_ARTIFACT, + data_dir=tmpdir, + ) + pipeline.run() + data_dir = os.path.join(pipeline.save_dir_path, "nasdaq") + for exchange_name in Exchange.exchanges(): + assert os.path.exists(os.path.join(data_dir, f"{exchange_name}_data.csv")) + assert os.path.exists(os.path.join(data_dir, f"raw_{exchange_name}_data.json"))