Skip to content

Commit

Permalink
Merge pull request #171 from tsugumi-sys/feature/use-entities-instead…
Browse files Browse the repository at this point in the history
…-of-schemas-nasdaqapi

Techdebt: use entities instead of scheams nasdaqapi
  • Loading branch information
tsugumi-sys authored Jul 7, 2024
2 parents 47282d0 + a3ea040 commit 1a8ac72
Show file tree
Hide file tree
Showing 8 changed files with 118 additions and 107 deletions.
6 changes: 3 additions & 3 deletions stocklake/nasdaqapi/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from stocklake.core.base_data_loader import BaseDataLoader
from stocklake.core.constants import CACHE_DIR
from stocklake.nasdaqapi.constants import Exchange
from stocklake.nasdaqapi.entities import RawNasdaqApiSymbolData
from stocklake.nasdaqapi.entities import RawNasdaqApiData
from stocklake.nasdaqapi.utils import nasdaq_api_get_request
from stocklake.stores.artifact.local_artifact_repo import LocalArtifactRepository

Expand All @@ -29,7 +29,7 @@ def cache_artifact_path(self) -> str:
self._cache_artifact_repo.artifact_dir, self._cache_artifact_filename
)

def download(self) -> List[RawNasdaqApiSymbolData]:
def download(self) -> List[RawNasdaqApiData]:
logger.info(
f"Loading {self.exchange_name.upper()} symbols data from `https://www.nasdaq.com/`"
)
Expand All @@ -40,4 +40,4 @@ def download(self) -> List[RawNasdaqApiSymbolData]:
with open(local_file, "w") as f:
json.dump(data, f)
self._cache_artifact_repo.save_artifact(local_file)
return data
return [RawNasdaqApiData(**d) for d in data]
42 changes: 38 additions & 4 deletions stocklake/nasdaqapi/entities.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,25 @@
from typing import List, Optional, TypedDict

from pydantic import BaseModel, ConfigDict

class RawNasdaqApiSymbolData(TypedDict):

class TD_RawNasdaqApiData(TypedDict):
symbol: str
name: str
lastsale: str
netchange: str
pctchange: str
volume: str
marketCap: str
country: str
ipoyear: str
industry: str
sector: str
url: str


class RawNasdaqApiData(BaseModel):
model_config = ConfigDict(extra="forbid")
symbol: str
name: str
lastsale: str
Expand All @@ -16,7 +34,7 @@ class RawNasdaqApiSymbolData(TypedDict):
url: str


class NasdaqApiSymbolData(TypedDict):
class NasdaqApiDataBase(BaseModel):
symbol: str
exchange: str
name: str
Expand All @@ -32,10 +50,26 @@ class NasdaqApiSymbolData(TypedDict):
url: str


class PreprocessedNasdaqApiData(NasdaqApiDataBase):
pass


class NasdaqApiDataCreate(NasdaqApiDataBase):
pass


class NasdaqApiData(NasdaqApiDataBase):
model_config = ConfigDict(from_attributes=True)

id: int
created_at: int
updated_at: int


class _ResponseData(TypedDict):
asOf: str
headers: RawNasdaqApiSymbolData
rows: List[RawNasdaqApiSymbolData]
headers: TD_RawNasdaqApiData
rows: List[TD_RawNasdaqApiData]


class NasdaqAPIResponse(TypedDict):
Expand Down
51 changes: 24 additions & 27 deletions stocklake/nasdaqapi/preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,47 +2,44 @@

from stocklake.core.base_preprocessor import BasePreprocessor
from stocklake.nasdaqapi.constants import Exchange
from stocklake.nasdaqapi.entities import NasdaqApiSymbolData, RawNasdaqApiSymbolData
from stocklake.nasdaqapi.entities import PreprocessedNasdaqApiData, RawNasdaqApiData


class NASDAQSymbolsPreprocessor(BasePreprocessor):
def process(
self, exchange: Exchange, data: List[RawNasdaqApiSymbolData]
) -> List[NasdaqApiSymbolData]:
processed_data: List[NasdaqApiSymbolData] = []
for data_dic in data:
_data: NasdaqApiSymbolData = {
"symbol": data_dic["symbol"],
self, exchange: Exchange, data: List[RawNasdaqApiData]
) -> List[PreprocessedNasdaqApiData]:
processed_data = []
for d in data:
_data = {
"symbol": d.symbol,
"exchange": exchange,
"name": data_dic["name"],
"last_sale": float(
data_dic["lastsale"].replace("$", "").replace(",", "")
),
"name": d.name,
"last_sale": float(d.lastsale.replace("$", "").replace(",", "")),
"pct_change": (
None
if data_dic["pctchange"] == ""
else float(data_dic["pctchange"].replace("%", ""))
None if d.pctchange == "" else float(d.pctchange.replace("%", ""))
),
"net_change": float(data_dic["netchange"]),
"volume": float(data_dic["volume"]),
"marketcap": self._market_cap(data_dic),
"country": data_dic["country"],
"ipo_year": self._ipo_year(data_dic),
"industry": data_dic["industry"],
"sector": data_dic["sector"],
"url": data_dic["url"],
"net_change": float(d.netchange),
"volume": float(d.volume),
"marketcap": self._market_cap(d),
"country": d.country,
"ipo_year": self._ipo_year(d),
"industry": d.industry,
"sector": d.sector,
"url": d.url,
}
processed_data.append(_data)
# NOTE: We ignore arg-type mypy error here, because of this bug https://github.com/python/mypy/issues/5382.
processed_data.append(PreprocessedNasdaqApiData(**_data)) # type: ignore
return processed_data

def _ipo_year(self, data_dic: RawNasdaqApiSymbolData) -> int:
ipo_year = data_dic["ipoyear"]
def _ipo_year(self, data: RawNasdaqApiData) -> int:
ipo_year = data.ipoyear
if ipo_year == "":
return 0
return int(ipo_year)

def _market_cap(self, data_dic: RawNasdaqApiSymbolData) -> float:
market_cap = data_dic["marketCap"].replace(",", "")
def _market_cap(self, data: RawNasdaqApiData) -> float:
market_cap = data.marketCap.replace(",", "")
if market_cap == "":
return 0.0
return float(market_cap)
16 changes: 10 additions & 6 deletions stocklake/nasdaqapi/stores.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@
from stocklake.core.base_store import BaseStore
from stocklake.core.constants import DATA_DIR
from stocklake.exceptions import StockLakeException
from stocklake.nasdaqapi import entities
from stocklake.nasdaqapi.constants import Exchange
from stocklake.nasdaqapi.entities import NasdaqApiSymbolData
from stocklake.stores.artifact.local_artifact_repo import LocalArtifactRepository
from stocklake.stores.constants import StoreType
from stocklake.stores.db import models, schemas
from stocklake.stores.db import models
from stocklake.stores.db.database import (
DATABASE_SESSION_TYPE,
local_session,
Expand All @@ -31,21 +31,23 @@ def save(
self,
store_type: StoreType,
exchange: Exchange,
data: List[NasdaqApiSymbolData],
data: List[entities.PreprocessedNasdaqApiData],
) -> str:
if store_type == StoreType.LOCAL_ARTIFACT:
repository = LocalArtifactRepository(SAVE_ARTIFACTS_DIR)
with tempfile.TemporaryDirectory() as tmpdir:
csv_file_path = os.path.join(tmpdir, f"{exchange}_data.csv")
save_data_to_csv(data, csv_file_path)
save_data_to_csv([d.model_dump() for d in data], csv_file_path)
repository.save_artifact(csv_file_path)
return repository.list_artifacts()[0].path
elif store_type == StoreType.POSTGRESQL:
if self.sqlalchemy_session is None:
raise StockLakeException("`sqlalchemy_session` is None.")
sqlstore = NasdaqApiSQLAlchemyStore(exchange, self.sqlalchemy_session)
sqlstore.delete()
sqlstore.create([schemas.NasdaqStockCreate(**d) for d in data])
sqlstore.create(
[entities.NasdaqApiDataCreate(**d.model_dump()) for d in data]
)
return os.path.join(
safe_database_url_from_sessionmaker(self.sqlalchemy_session),
models.NasdaqApiData.__tablename__,
Expand All @@ -59,7 +61,9 @@ def __init__(self, exchange: Exchange, session: DATABASE_SESSION_TYPE):
self.exchange = exchange
self.session = session

def create(self, data: schemas.NasdaqStockCreate | List[schemas.NasdaqStockCreate]):
def create(
self, data: entities.NasdaqApiDataCreate | List[entities.NasdaqApiDataCreate]
):
with self.session() as session, session.begin():
if isinstance(data, list):
session.add_all([models.NasdaqApiData(**d.model_dump()) for d in data])
Expand Down
28 changes: 0 additions & 28 deletions stocklake/stores/db/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,34 +3,6 @@
from pydantic import BaseModel, ConfigDict


class NasdaqStockBase(BaseModel):
symbol: str
exchange: str
name: str
last_sale: float
pct_change: Optional[float]
net_change: float
volume: float
marketcap: float
country: str
ipo_year: int
industry: str
sector: str
url: str


class NasdaqStockCreate(NasdaqStockBase):
pass


class NasdaqStock(NasdaqStockBase):
model_config = ConfigDict(from_attributes=True)

id: int
created_at: int
updated_at: int


class PolygonFinancialsDataBase(BaseModel):
# financials data
# - balance sheet
Expand Down
6 changes: 3 additions & 3 deletions tests/nasdaqapi/test_data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@ def MockNasdaqAPIServer():


@pytest.mark.parametrize("exchange_name", Exchange.exchanges())
def test_data_loader(exchange_name, tmpdir, MockNasdaqAPIServer):
def test_download(exchange_name, tmpdir, MockNasdaqAPIServer):
data_loader = NASDAQSymbolsDataLoader(exchange_name=exchange_name, cache_dir=tmpdir)
data = data_loader.download()
assert os.path.exists(data_loader.cache_artifact_path)
for row in data:
for d in data:
for col in expected_cols:
assert col in row
assert col in d.model_dump()
4 changes: 2 additions & 2 deletions tests/nasdaqapi/test_preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ def test_process(tmpdir, MockNasdaqAPIServer): # noqa: F811
data_loader = NASDAQSymbolsDataLoader(exchange_name=Exchange.AMEX, cache_dir=tmpdir)
preprocessor = NASDAQSymbolsPreprocessor()
data = preprocessor.process(exchange=Exchange.NASDAQ, data=data_loader.download())
for data_dic in data:
for key, val in data_dic.items():
for d in data:
for key, val in d.model_dump().items():
if key in ["last_sale", "net_change", "pct_change", "marketcap", "volume"]:
assert isinstance(val, float)
elif key in ["ipo_year"]:
Expand Down
72 changes: 38 additions & 34 deletions tests/nasdaqapi/test_stores.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@

from conftest import SessionLocal # noqa: F401
from stocklake.nasdaqapi.constants import Exchange
from stocklake.nasdaqapi.entities import NasdaqApiDataCreate
from stocklake.nasdaqapi.stores import (
SAVE_ARTIFACTS_DIR,
NasdaqApiSQLAlchemyStore,
NASDAQDataStore,
)
from stocklake.stores.constants import StoreType
from stocklake.stores.db.models import NasdaqApiData
from stocklake.stores.db.schemas import NasdaqStockCreate


def test_nasdaqdatastore_local_artifact():
Expand All @@ -22,21 +22,23 @@ def test_nasdaqdatastore_local_artifact():
StoreType.LOCAL_ARTIFACT,
exchange_name,
[
{
"symbol": "TEST",
"exchange": Exchange.NASDAQ,
"name": "Test Company",
"last_sale": 0.88,
"pct_change": 0.5,
"net_change": 0.35,
"volume": 100.5,
"marketcap": 0.75,
"country": "US",
"ipo_year": 1999,
"industry": "Tech",
"sector": "Health",
"url": "https://example.com",
}
NasdaqApiDataCreate(
**{
"symbol": "TEST",
"exchange": Exchange.NASDAQ,
"name": "Test Company",
"last_sale": 0.88,
"pct_change": 0.5,
"net_change": 0.35,
"volume": 100.5,
"marketcap": 0.75,
"country": "US",
"ipo_year": 1999,
"industry": "Tech",
"sector": "Health",
"url": "https://example.com",
}
)
],
)
assert os.path.exists(os.path.join(SAVE_ARTIFACTS_DIR, f"{exchange_name}_data.csv"))
Expand All @@ -49,21 +51,23 @@ def test_nasdaqdatastore_postgresql(SessionLocal): # noqa: F811
StoreType.POSTGRESQL,
exchange_name,
[
{
"symbol": "TEST",
"exchange": Exchange.NASDAQ,
"name": "Test Company",
"last_sale": 0.88,
"pct_change": 0.5,
"net_change": 0.35,
"volume": 100.5,
"marketcap": 0.75,
"country": "US",
"ipo_year": 1999,
"industry": "Tech",
"sector": "Health",
"url": "https://example.com",
}
NasdaqApiDataCreate(
**{
"symbol": "TEST",
"exchange": Exchange.NASDAQ,
"name": "Test Company",
"last_sale": 0.88,
"pct_change": 0.5,
"net_change": 0.35,
"volume": 100.5,
"marketcap": 0.75,
"country": "US",
"ipo_year": 1999,
"industry": "Tech",
"sector": "Health",
"url": "https://example.com",
}
)
],
)
with SessionLocal() as session, session.begin():
Expand Down Expand Up @@ -91,7 +95,7 @@ def test_NasdaqAPISQLAlchemyStore_create(SessionLocal): # noqa: F811
}

# Add item
store.create(NasdaqStockCreate(**data))
store.create(NasdaqApiDataCreate(**data))
with SessionLocal() as session, session.begin():
res = session.query(NasdaqApiData).all()
assert len(res) == 1
Expand All @@ -117,7 +121,7 @@ def test_NasdaqAPISQLAlchemyStore_create(SessionLocal): # noqa: F811
data2["symbol"] = "TEST2"
data3 = copy.deepcopy(data)
data3["symbol"] = "TEST3"
store.create([NasdaqStockCreate(**data2), NasdaqStockCreate(**data3)])
store.create([NasdaqApiDataCreate(**data2), NasdaqApiDataCreate(**data3)])
with SessionLocal() as session, session.begin():
assert len(session.query(NasdaqApiData).all()) == 3

Expand All @@ -142,7 +146,7 @@ def test_NasdaqAPISQLAlchemyStore_delete(exchange, SessionLocal): # noqa: F811
}

# Add item
store.create(NasdaqStockCreate(**data))
store.create(NasdaqApiDataCreate(**data))
with SessionLocal() as session, session.begin():
res = session.query(NasdaqApiData).all()
assert len(res) == 1
Expand Down

0 comments on commit 1a8ac72

Please sign in to comment.