diff --git a/stocklake/nasdaqapi/data_loader.py b/stocklake/nasdaqapi/data_loader.py index 6fd97af..f2c3487 100644 --- a/stocklake/nasdaqapi/data_loader.py +++ b/stocklake/nasdaqapi/data_loader.py @@ -8,7 +8,7 @@ from stocklake.core.base_data_loader import BaseDataLoader from stocklake.core.constants import CACHE_DIR from stocklake.nasdaqapi.constants import Exchange -from stocklake.nasdaqapi.entities import RawNasdaqApiSymbolData +from stocklake.nasdaqapi.entities import RawNasdaqApiData from stocklake.nasdaqapi.utils import nasdaq_api_get_request from stocklake.stores.artifact.local_artifact_repo import LocalArtifactRepository @@ -29,7 +29,7 @@ def cache_artifact_path(self) -> str: self._cache_artifact_repo.artifact_dir, self._cache_artifact_filename ) - def download(self) -> List[RawNasdaqApiSymbolData]: + def download(self) -> List[RawNasdaqApiData]: logger.info( f"Loading {self.exchange_name.upper()} symbols data from `https://www.nasdaq.com/`" ) @@ -40,4 +40,4 @@ def download(self) -> List[RawNasdaqApiSymbolData]: with open(local_file, "w") as f: json.dump(data, f) self._cache_artifact_repo.save_artifact(local_file) - return data + return [RawNasdaqApiData(**d) for d in data] diff --git a/stocklake/nasdaqapi/entities.py b/stocklake/nasdaqapi/entities.py index 94bbcd6..f770f72 100644 --- a/stocklake/nasdaqapi/entities.py +++ b/stocklake/nasdaqapi/entities.py @@ -1,7 +1,25 @@ from typing import List, Optional, TypedDict +from pydantic import BaseModel, ConfigDict -class RawNasdaqApiSymbolData(TypedDict): + +class TD_RawNasdaqApiData(TypedDict): + symbol: str + name: str + lastsale: str + netchange: str + pctchange: str + volume: str + marketCap: str + country: str + ipoyear: str + industry: str + sector: str + url: str + + +class RawNasdaqApiData(BaseModel): + model_config = ConfigDict(extra="forbid") symbol: str name: str lastsale: str @@ -16,7 +34,7 @@ class RawNasdaqApiSymbolData(TypedDict): url: str -class NasdaqApiSymbolData(TypedDict): +class NasdaqApiDataBase(BaseModel): symbol: str exchange: str name: str @@ -32,10 +50,26 @@ class NasdaqApiSymbolData(TypedDict): url: str +class PreprocessedNasdaqApiData(NasdaqApiDataBase): + pass + + +class NasdaqApiDataCreate(NasdaqApiDataBase): + pass + + +class NasdaqApiData(NasdaqApiDataBase): + model_config = ConfigDict(from_attributes=True) + + id: int + created_at: int + updated_at: int + + class _ResponseData(TypedDict): asOf: str - headers: RawNasdaqApiSymbolData - rows: List[RawNasdaqApiSymbolData] + headers: TD_RawNasdaqApiData + rows: List[TD_RawNasdaqApiData] class NasdaqAPIResponse(TypedDict): diff --git a/stocklake/nasdaqapi/preprocessor.py b/stocklake/nasdaqapi/preprocessor.py index 9227c80..f914e04 100644 --- a/stocklake/nasdaqapi/preprocessor.py +++ b/stocklake/nasdaqapi/preprocessor.py @@ -2,47 +2,44 @@ from stocklake.core.base_preprocessor import BasePreprocessor from stocklake.nasdaqapi.constants import Exchange -from stocklake.nasdaqapi.entities import NasdaqApiSymbolData, RawNasdaqApiSymbolData +from stocklake.nasdaqapi.entities import PreprocessedNasdaqApiData, RawNasdaqApiData class NASDAQSymbolsPreprocessor(BasePreprocessor): def process( - self, exchange: Exchange, data: List[RawNasdaqApiSymbolData] - ) -> List[NasdaqApiSymbolData]: - processed_data: List[NasdaqApiSymbolData] = [] - for data_dic in data: - _data: NasdaqApiSymbolData = { - "symbol": data_dic["symbol"], + self, exchange: Exchange, data: List[RawNasdaqApiData] + ) -> List[PreprocessedNasdaqApiData]: + processed_data = [] + for d in data: + _data = { + "symbol": d.symbol, "exchange": exchange, - "name": data_dic["name"], - "last_sale": float( - data_dic["lastsale"].replace("$", "").replace(",", "") - ), + "name": d.name, + "last_sale": float(d.lastsale.replace("$", "").replace(",", "")), "pct_change": ( - None - if data_dic["pctchange"] == "" - else float(data_dic["pctchange"].replace("%", "")) + None if d.pctchange == "" else float(d.pctchange.replace("%", "")) ), - "net_change": float(data_dic["netchange"]), - "volume": float(data_dic["volume"]), - "marketcap": self._market_cap(data_dic), - "country": data_dic["country"], - "ipo_year": self._ipo_year(data_dic), - "industry": data_dic["industry"], - "sector": data_dic["sector"], - "url": data_dic["url"], + "net_change": float(d.netchange), + "volume": float(d.volume), + "marketcap": self._market_cap(d), + "country": d.country, + "ipo_year": self._ipo_year(d), + "industry": d.industry, + "sector": d.sector, + "url": d.url, } - processed_data.append(_data) + # NOTE: We ignore arg-type mypy error here, because of this bug https://github.com/python/mypy/issues/5382. + processed_data.append(PreprocessedNasdaqApiData(**_data)) # type: ignore return processed_data - def _ipo_year(self, data_dic: RawNasdaqApiSymbolData) -> int: - ipo_year = data_dic["ipoyear"] + def _ipo_year(self, data: RawNasdaqApiData) -> int: + ipo_year = data.ipoyear if ipo_year == "": return 0 return int(ipo_year) - def _market_cap(self, data_dic: RawNasdaqApiSymbolData) -> float: - market_cap = data_dic["marketCap"].replace(",", "") + def _market_cap(self, data: RawNasdaqApiData) -> float: + market_cap = data.marketCap.replace(",", "") if market_cap == "": return 0.0 return float(market_cap) diff --git a/stocklake/nasdaqapi/stores.py b/stocklake/nasdaqapi/stores.py index 0b91e70..6041846 100644 --- a/stocklake/nasdaqapi/stores.py +++ b/stocklake/nasdaqapi/stores.py @@ -6,11 +6,11 @@ from stocklake.core.base_store import BaseStore from stocklake.core.constants import DATA_DIR from stocklake.exceptions import StockLakeException +from stocklake.nasdaqapi import entities from stocklake.nasdaqapi.constants import Exchange -from stocklake.nasdaqapi.entities import NasdaqApiSymbolData from stocklake.stores.artifact.local_artifact_repo import LocalArtifactRepository from stocklake.stores.constants import StoreType -from stocklake.stores.db import models, schemas +from stocklake.stores.db import models from stocklake.stores.db.database import ( DATABASE_SESSION_TYPE, local_session, @@ -31,13 +31,13 @@ def save( self, store_type: StoreType, exchange: Exchange, - data: List[NasdaqApiSymbolData], + data: List[entities.PreprocessedNasdaqApiData], ) -> str: if store_type == StoreType.LOCAL_ARTIFACT: repository = LocalArtifactRepository(SAVE_ARTIFACTS_DIR) with tempfile.TemporaryDirectory() as tmpdir: csv_file_path = os.path.join(tmpdir, f"{exchange}_data.csv") - save_data_to_csv(data, csv_file_path) + save_data_to_csv([d.model_dump() for d in data], csv_file_path) repository.save_artifact(csv_file_path) return repository.list_artifacts()[0].path elif store_type == StoreType.POSTGRESQL: @@ -45,7 +45,9 @@ def save( raise StockLakeException("`sqlalchemy_session` is None.") sqlstore = NasdaqApiSQLAlchemyStore(exchange, self.sqlalchemy_session) sqlstore.delete() - sqlstore.create([schemas.NasdaqStockCreate(**d) for d in data]) + sqlstore.create( + [entities.NasdaqApiDataCreate(**d.model_dump()) for d in data] + ) return os.path.join( safe_database_url_from_sessionmaker(self.sqlalchemy_session), models.NasdaqApiData.__tablename__, @@ -59,7 +61,9 @@ def __init__(self, exchange: Exchange, session: DATABASE_SESSION_TYPE): self.exchange = exchange self.session = session - def create(self, data: schemas.NasdaqStockCreate | List[schemas.NasdaqStockCreate]): + def create( + self, data: entities.NasdaqApiDataCreate | List[entities.NasdaqApiDataCreate] + ): with self.session() as session, session.begin(): if isinstance(data, list): session.add_all([models.NasdaqApiData(**d.model_dump()) for d in data]) diff --git a/stocklake/stores/db/schemas.py b/stocklake/stores/db/schemas.py index 8f1ba7a..670ff6f 100644 --- a/stocklake/stores/db/schemas.py +++ b/stocklake/stores/db/schemas.py @@ -3,34 +3,6 @@ from pydantic import BaseModel, ConfigDict -class NasdaqStockBase(BaseModel): - symbol: str - exchange: str - name: str - last_sale: float - pct_change: Optional[float] - net_change: float - volume: float - marketcap: float - country: str - ipo_year: int - industry: str - sector: str - url: str - - -class NasdaqStockCreate(NasdaqStockBase): - pass - - -class NasdaqStock(NasdaqStockBase): - model_config = ConfigDict(from_attributes=True) - - id: int - created_at: int - updated_at: int - - class PolygonFinancialsDataBase(BaseModel): # financials data # - balance sheet diff --git a/tests/nasdaqapi/test_data_loader.py b/tests/nasdaqapi/test_data_loader.py index cdd5e49..17685e1 100644 --- a/tests/nasdaqapi/test_data_loader.py +++ b/tests/nasdaqapi/test_data_loader.py @@ -31,10 +31,10 @@ def MockNasdaqAPIServer(): @pytest.mark.parametrize("exchange_name", Exchange.exchanges()) -def test_data_loader(exchange_name, tmpdir, MockNasdaqAPIServer): +def test_download(exchange_name, tmpdir, MockNasdaqAPIServer): data_loader = NASDAQSymbolsDataLoader(exchange_name=exchange_name, cache_dir=tmpdir) data = data_loader.download() assert os.path.exists(data_loader.cache_artifact_path) - for row in data: + for d in data: for col in expected_cols: - assert col in row + assert col in d.model_dump() diff --git a/tests/nasdaqapi/test_preprocessor.py b/tests/nasdaqapi/test_preprocessor.py index d49d947..6d32e4f 100644 --- a/tests/nasdaqapi/test_preprocessor.py +++ b/tests/nasdaqapi/test_preprocessor.py @@ -8,8 +8,8 @@ def test_process(tmpdir, MockNasdaqAPIServer): # noqa: F811 data_loader = NASDAQSymbolsDataLoader(exchange_name=Exchange.AMEX, cache_dir=tmpdir) preprocessor = NASDAQSymbolsPreprocessor() data = preprocessor.process(exchange=Exchange.NASDAQ, data=data_loader.download()) - for data_dic in data: - for key, val in data_dic.items(): + for d in data: + for key, val in d.model_dump().items(): if key in ["last_sale", "net_change", "pct_change", "marketcap", "volume"]: assert isinstance(val, float) elif key in ["ipo_year"]: diff --git a/tests/nasdaqapi/test_stores.py b/tests/nasdaqapi/test_stores.py index e545a2b..f9d6881 100644 --- a/tests/nasdaqapi/test_stores.py +++ b/tests/nasdaqapi/test_stores.py @@ -5,6 +5,7 @@ from conftest import SessionLocal # noqa: F401 from stocklake.nasdaqapi.constants import Exchange +from stocklake.nasdaqapi.entities import NasdaqApiDataCreate from stocklake.nasdaqapi.stores import ( SAVE_ARTIFACTS_DIR, NasdaqApiSQLAlchemyStore, @@ -12,7 +13,6 @@ ) from stocklake.stores.constants import StoreType from stocklake.stores.db.models import NasdaqApiData -from stocklake.stores.db.schemas import NasdaqStockCreate def test_nasdaqdatastore_local_artifact(): @@ -22,21 +22,23 @@ def test_nasdaqdatastore_local_artifact(): StoreType.LOCAL_ARTIFACT, exchange_name, [ - { - "symbol": "TEST", - "exchange": Exchange.NASDAQ, - "name": "Test Company", - "last_sale": 0.88, - "pct_change": 0.5, - "net_change": 0.35, - "volume": 100.5, - "marketcap": 0.75, - "country": "US", - "ipo_year": 1999, - "industry": "Tech", - "sector": "Health", - "url": "https://example.com", - } + NasdaqApiDataCreate( + **{ + "symbol": "TEST", + "exchange": Exchange.NASDAQ, + "name": "Test Company", + "last_sale": 0.88, + "pct_change": 0.5, + "net_change": 0.35, + "volume": 100.5, + "marketcap": 0.75, + "country": "US", + "ipo_year": 1999, + "industry": "Tech", + "sector": "Health", + "url": "https://example.com", + } + ) ], ) assert os.path.exists(os.path.join(SAVE_ARTIFACTS_DIR, f"{exchange_name}_data.csv")) @@ -49,21 +51,23 @@ def test_nasdaqdatastore_postgresql(SessionLocal): # noqa: F811 StoreType.POSTGRESQL, exchange_name, [ - { - "symbol": "TEST", - "exchange": Exchange.NASDAQ, - "name": "Test Company", - "last_sale": 0.88, - "pct_change": 0.5, - "net_change": 0.35, - "volume": 100.5, - "marketcap": 0.75, - "country": "US", - "ipo_year": 1999, - "industry": "Tech", - "sector": "Health", - "url": "https://example.com", - } + NasdaqApiDataCreate( + **{ + "symbol": "TEST", + "exchange": Exchange.NASDAQ, + "name": "Test Company", + "last_sale": 0.88, + "pct_change": 0.5, + "net_change": 0.35, + "volume": 100.5, + "marketcap": 0.75, + "country": "US", + "ipo_year": 1999, + "industry": "Tech", + "sector": "Health", + "url": "https://example.com", + } + ) ], ) with SessionLocal() as session, session.begin(): @@ -91,7 +95,7 @@ def test_NasdaqAPISQLAlchemyStore_create(SessionLocal): # noqa: F811 } # Add item - store.create(NasdaqStockCreate(**data)) + store.create(NasdaqApiDataCreate(**data)) with SessionLocal() as session, session.begin(): res = session.query(NasdaqApiData).all() assert len(res) == 1 @@ -117,7 +121,7 @@ def test_NasdaqAPISQLAlchemyStore_create(SessionLocal): # noqa: F811 data2["symbol"] = "TEST2" data3 = copy.deepcopy(data) data3["symbol"] = "TEST3" - store.create([NasdaqStockCreate(**data2), NasdaqStockCreate(**data3)]) + store.create([NasdaqApiDataCreate(**data2), NasdaqApiDataCreate(**data3)]) with SessionLocal() as session, session.begin(): assert len(session.query(NasdaqApiData).all()) == 3 @@ -142,7 +146,7 @@ def test_NasdaqAPISQLAlchemyStore_delete(exchange, SessionLocal): # noqa: F811 } # Add item - store.create(NasdaqStockCreate(**data)) + store.create(NasdaqApiDataCreate(**data)) with SessionLocal() as session, session.begin(): res = session.query(NasdaqApiData).all() assert len(res) == 1