Skip to content

Commit

Permalink
Test: Add unit tests for the NASDAQ pipelines
Browse files: browse the repository at this point in the history
  • Loading branch information
tsugumi-sys committed Apr 29, 2024
1 parent 641b2ff commit 7e03a83
Show file tree
Hide file tree
Showing 8 changed files with 85 additions and 63 deletions.
2 changes: 1 addition & 1 deletion stocklake/core/base_data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from stocklake.stores.artifact.base import ArtifactRepository


class DataLoader(ABC):
class BaseDataLoader(ABC):
def __init__(self, artifact_repo: ArtifactRepository):
self._artifact_repo = artifact_repo

Expand Down
2 changes: 1 addition & 1 deletion stocklake/core/base_pipeline.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from abc import ABC, abstractmethod


class Pipeline(ABC):
class BasePipeline(ABC):
@abstractmethod
def run(self, *args, **kwargs):
pass
2 changes: 1 addition & 1 deletion stocklake/nasdaqapi/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@ class Exchange(str, Enum):

@classmethod
def exchanges(self):
return sorted([e for e in Exchange.__members__])
return sorted([e.value for e in Exchange])
14 changes: 7 additions & 7 deletions stocklake/nasdaqapi/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import requests

from stocklake.core.base_data_loader import DataLoader
from stocklake.core.base_data_loader import BaseDataLoader
from stocklake.nasdaqapi.constants import Exchange
from stocklake.stores.artifact.base import ArtifactRepository

Expand Down Expand Up @@ -44,11 +44,11 @@ def symbols_api_endpoint(exchange_name: Exchange) -> str:
return f"https://api.nasdaq.com/api/screener/stocks?tableonly=true&limit=25&offset=0&exchange={exchange_name}&download=true"


class NASDAQSymbolsDataLoader(DataLoader):
class NASDAQSymbolsDataLoader(BaseDataLoader):
def __init__(
self,
artifact_repo: ArtifactRepository,
artifact_filename_json: str = "data.json",
artifact_filename_json: str = "raw_nasdaq_data.json",
):
super().__init__(artifact_repo)
self.artifact_filename_json = artifact_filename_json
Expand Down Expand Up @@ -83,11 +83,11 @@ def download(self):
self.artifact_repo.save_artifact(local_file)


class NYSESymbolsDataLoader(DataLoader):
class NYSESymbolsDataLoader(BaseDataLoader):
def __init__(
self,
artifact_repo: ArtifactRepository,
artifact_filename_json: str = "data.json",
artifact_filename_json: str = "raw_nyse_data.json",
):
super().__init__(artifact_repo)
self.artifact_filename_json = artifact_filename_json
Expand Down Expand Up @@ -121,11 +121,11 @@ def download(self):
self.artifact_repo.save_artifact(local_file)


class AMEXSymbolsDataLoader(DataLoader):
class AMEXSymbolsDataLoader(BaseDataLoader):
def __init__(
self,
artifact_repo: ArtifactRepository,
artifact_filename_json: str = "data.json",
artifact_filename_json: str = "raw_amex_data.json",
):
super().__init__(artifact_repo)
self.artifact_filename_json = artifact_filename_json
Expand Down
83 changes: 35 additions & 48 deletions stocklake/nasdaqapi/pipeline.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import logging
import os
from typing import Optional
from typing import Any, Optional

from stocklake.core.base_pipeline import Pipeline
# from stocklake.core.base_data_loader import BaseDataLoader
from stocklake.core.base_pipeline import BasePipeline

# from stocklake.core.base_preprocessor import BasePreprocessor
from stocklake.core.constants import DATA_DIR
from stocklake.core.stdout import PrettyStdoutPrint
from stocklake.exceptions import StockLoaderException
Expand All @@ -17,18 +20,20 @@
NASDAQSymbolsPreprocessor,
NYSESymbolsPreprocessor,
)
from stocklake.stores.artifact.base import ArtifactRepository
from stocklake.stores.artifact.local_artifact_repo import LocalArtifactRepository
from stocklake.stores.constants import StoreType

logger = logging.getLogger(__name__)


class NASDAQSymbolsPipeline(Pipeline):
class NASDAQSymbolsPipeline(BasePipeline):
def __init__(
self,
skip_download: bool = False,
exchange: Optional[Exchange] = None,
store_type: Optional[StoreType] = StoreType.LOCAL_ARTIFACT,
data_dir: str = DATA_DIR,
):
self.exchange = exchange
self.skip_download = skip_download
Expand All @@ -38,67 +43,49 @@ def __init__(
f"Specified store type is invalid, {store_type}, valid types are {StoreType.types()}"
)
self.store_type = store_type
self.save_dir = os.path.join(DATA_DIR, "nasdaqapi")
self._save_dir = os.path.join(data_dir, "nasdaqapi")
self.stdout = PrettyStdoutPrint()

@property
def save_dir_path(self) -> str:
    """Return the pipeline's save directory (``<data_dir>/nasdaqapi``).

    The value is computed once in ``__init__`` and exposed read-only here;
    ``run`` creates its artifact repositories underneath this directory.
    """
    return self._save_dir

def run(self):
logger.info("{} NASDAQ pipeline starts {}".format("=" * 30, "=" * 30))
if self.exchange == Exchange.NASDAQ or self.exchange is None:
self.stdout.step_start(f"{Exchange.NASDAQ} symbols with nasdapapi")
exchange_repo = LocalArtifactRepository(
os.path.join(self.save_dir, Exchange.NASDAQ)
)
downloader = NASDAQSymbolsDataLoader(exchange_repo, "raw_data.json")
if not self.skip_download:
self.stdout.normal_message("- Downloading ...")
downloader.download()
else:
self.stdout.warning_message("- Skip Downloading")

preprocessor = NASDAQSymbolsPreprocessor(
exchange_repo, downloader.artifact_path, "processed.csv"
)
preprocessor.process()
self.stdout.success_message(
f"- Completed🐳. The artifact is saved to {preprocessor.artifact_path}"
os.path.join(self._save_dir, Exchange.NASDAQ)
)
self._run(exchange_repo, NASDAQSymbolsDataLoader, NASDAQSymbolsPreprocessor)

if self.exchange == Exchange.NYSE or self.exchange is None:
self.stdout.step_start(f"{Exchange.NYSE} symbols with nasdapapi")
exchange_repo = LocalArtifactRepository(
os.path.join(self.save_dir, Exchange.NASDAQ)
)
downloader = NYSESymbolsDataLoader(exchange_repo, "raw_data.json")
if not self.skip_download:
self.stdout.normal_message("- Downloading ...")
downloader.download()
else:
self.stdout.warning_message("- Skip Downloading")

preprocessor = NYSESymbolsPreprocessor(
exchange_repo, downloader.artifact_path, "processed.csv"
)
preprocessor.process()
self.stdout.success_message(
f"- Completed🐳. The artifact is saved to {preprocessor.artifact_path}"
os.path.join(self._save_dir, Exchange.NASDAQ)
)
self._run(exchange_repo, NYSESymbolsDataLoader, NYSESymbolsPreprocessor)

if self.exchange == Exchange.AMEX or self.exchange is None:
self.stdout.step_start(f"{Exchange.AMEX} symbols with nasdapapi")
exchange_repo = LocalArtifactRepository(
os.path.join(self.save_dir, Exchange.NASDAQ)
os.path.join(self._save_dir, Exchange.NASDAQ)
)
downloader = AMEXSymbolsDataLoader(exchange_repo, "raw_data.json")
if not self.skip_download:
self.stdout.normal_message("- Downloading ...")
downloader.download()
else:
self.stdout.warning_message("- Skip Downloading")
self._run(exchange_repo, AMEXSymbolsDataLoader, AMEXSymbolsPreprocessor)

preprocessor = AMEXSymbolsPreprocessor(
exchange_repo, downloader.artifact_path, "processed.csv"
)
preprocessor.process()
self.stdout.success_message(
f"- Completed🐳. The artifact is saved to {preprocessor.artifact_path}"
)
def _run(
self,
repository: ArtifactRepository,
data_loader: Any,
preprocessor: Any,
):
_data_loader = data_loader(repository)
if not self.skip_download:
self.stdout.normal_message("- Downloading ...")
_data_loader.download()
else:
self.stdout.warning_message("- Skip Downloading")
preprocessor(repository, _data_loader.artifact_path).process()
self.stdout.success_message(
f"- Completed🐳. The artifact is saved to {preprocessor.artifact_path}"
)
6 changes: 3 additions & 3 deletions stocklake/nasdaqapi/preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def __init__(
self,
artifact_repo: ArtifactRepository,
source_artifact_path_json: str,
artifact_filename_csv: str = "data.csv",
artifact_filename_csv: str = "nasdaq_data.csv",
):
super().__init__(artifact_repo)
self.source_artifact_path_json = source_artifact_path_json
Expand Down Expand Up @@ -45,7 +45,7 @@ def __init__(
self,
artifact_repo: ArtifactRepository,
source_artifact_path_json: str,
artifact_filename_csv: str = "data.csv",
artifact_filename_csv: str = "nyse_data.csv",
):
super().__init__(artifact_repo)
self.source_artifact_path_json = source_artifact_path_json
Expand Down Expand Up @@ -77,7 +77,7 @@ def __init__(
self,
artifact_repo: ArtifactRepository,
source_artifact_path_json: str,
artifact_filename_csv: str = "data.csv",
artifact_filename_csv: str = "amex_data.csv",
):
super().__init__(artifact_repo)
self.source_artifact_path_json = source_artifact_path_json
Expand Down
4 changes: 2 additions & 2 deletions stocklake/polygonapi/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@
from polygon import RESTClient
from polygon.rest.models.financials import Financials

from stocklake.data_loaders.base import DataLoader
from stocklake.data_loaders.base import BaseDataLoader
from stocklake.stores.artifact.base import ArtifactRepository

logger = logging.getLogger(__name__)


class PolygonFinancialsDataLoader(DataLoader):
class PolygonFinancialsDataLoader(BaseDataLoader):
def __init__(
self,
artifact_repo: ArtifactRepository,
Expand Down
35 changes: 35 additions & 0 deletions tests/nasdaqapi/test_pipeline.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
import os
from unittest import mock

import pytest

from stocklake.exceptions import StockLoaderException
from stocklake.nasdaqapi.constants import Exchange
from stocklake.nasdaqapi.pipeline import NASDAQSymbolsPipeline
from stocklake.stores.constants import StoreType
from tests.nasdaqapi.test_data_loader import mock_requests_get


def test_invalid_store_type_specified():
Expand All @@ -11,3 +17,32 @@ def test_invalid_store_type_specified():
str(exc.value)
== "Specified store type is invalid, INVALID_STORE_TYPE, valid types are ['local_artifact', 'postgresql']"
)


@mock.patch("requests.get", side_effect=mock_requests_get)
@pytest.mark.parametrize("exchange_name", Exchange.exchanges())
def test_run_each_symbols_with_local_artifact(mock_get, exchange_name, tmpdir):
    """Run the pipeline for a single exchange and check its artifacts.

    ``requests.get`` is patched with ``mock_requests_get`` as its side effect.
    Both the raw JSON and the processed CSV are expected under the ``nasdaq``
    sub-directory of the pipeline's save directory.
    """
    options = {
        "skip_download": False,
        "exchange": exchange_name,
        "store_type": StoreType.LOCAL_ARTIFACT,
        "data_dir": tmpdir,
    }
    symbols_pipeline = NASDAQSymbolsPipeline(**options)
    symbols_pipeline.run()

    artifact_dir = os.path.join(symbols_pipeline.save_dir_path, "nasdaq")
    expected_files = (
        f"{exchange_name}_data.csv",
        f"raw_{exchange_name}_data.json",
    )
    for filename in expected_files:
        assert os.path.exists(os.path.join(artifact_dir, filename))


@mock.patch("requests.get", side_effect=mock_requests_get)
def test_run_with_local_artifact(mock_get, tmpdir):
    """Run the pipeline with no exchange filter and check every artifact.

    With ``exchange`` left at its default, the raw JSON and processed CSV of
    every exchange are expected under the ``nasdaq`` sub-directory of the
    pipeline's save directory.
    """
    symbols_pipeline = NASDAQSymbolsPipeline(
        skip_download=False,
        store_type=StoreType.LOCAL_ARTIFACT,
        data_dir=tmpdir,
    )
    symbols_pipeline.run()

    artifact_dir = os.path.join(symbols_pipeline.save_dir_path, "nasdaq")
    for exchange_name in Exchange.exchanges():
        processed_csv = os.path.join(artifact_dir, f"{exchange_name}_data.csv")
        raw_json = os.path.join(artifact_dir, f"raw_{exchange_name}_data.json")
        assert os.path.exists(processed_csv)
        assert os.path.exists(raw_json)

0 comments on commit 7e03a83

Please sign in to comment.