Skip to content

Commit 7004de7

Browse files
authored
Merge pull request #237 from zero-sum-seattle/refactor/pydantic-model-migration
Refactor/pydantic model migration
2 parents 466589f + 99b4778 commit 7004de7

File tree

104 files changed

+5411
-5570
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

104 files changed

+5411
-5570
lines changed

README.md

Lines changed: 214 additions & 324 deletions
Large diffs are not rendered by default.

mlbstatsapi/mlb_api.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616
from mlbstatsapi.models.seasons import Season
1717
from mlbstatsapi.models.drafts import Round
1818
from mlbstatsapi.models.awards import Award
19-
from mlbstatsapi.models.gamepace import Gamepace
20-
from mlbstatsapi.models.homerunderby import Homerunderby
19+
from mlbstatsapi.models.gamepace import GamePace
20+
from mlbstatsapi.models.homerunderby import HomeRunDerby
2121
from mlbstatsapi.models.standings import Standings
2222

2323
from .mlb_dataadapter import MlbDataAdapter
@@ -1099,7 +1099,7 @@ def get_game_ids(self, date: str = None,
10991099

11001100
return game_ids
11011101

1102-
def get_gamepace(self, season: str, sport_id=1, **params) -> Union[Gamepace, None]:
1102+
def get_gamepace(self, season: str, sport_id=1, **params) -> Union[GamePace, None]:
11031103
"""
11041104
Get pace of game metrics for specific sport, league or team.
11051105
@@ -1164,7 +1164,7 @@ def get_gamepace(self, season: str, sport_id=1, **params) -> Union[Gamepace, Non
11641164
or 'leagues' in mlb_data.data and mlb_data.data['leagues']
11651165
or 'sports' in mlb_data.data and mlb_data.data['sports']):
11661166

1167-
return Gamepace(**mlb_data.data)
1167+
return GamePace(**mlb_data.data)
11681168

11691169
def get_venue(self, venue_id: int, **params) -> Union[Venue, None]:
11701170
"""
@@ -2016,7 +2016,7 @@ def get_awards(self, award_id: str, **params) -> List[Award]:
20162016

20172017
return awards_list
20182018

2019-
def get_homerun_derby(self, game_id, **params) -> Union[Homerunderby, None]:
2019+
def get_homerun_derby(self, game_id, **params) -> Union[HomeRunDerby, None]:
20202020
"""
20212021
The homerun derby endpoint on the Stats API allows for users to
20222022
request information from the MLB database pertaining to the
@@ -2036,7 +2036,7 @@ def get_homerun_derby(self, game_id, **params) -> Union[Homerunderby, None]:
20362036
20372037
Returns
20382038
-------
2039-
Homerunderby object
2039+
HomeRunDerby object
20402040
20412041
See Also
20422042
--------
@@ -2049,7 +2049,7 @@ def get_homerun_derby(self, game_id, **params) -> Union[Homerunderby, None]:
20492049
None
20502050

20512051
if 'status' in mlb_data.data and mlb_data.data['status']:
2052-
return Homerunderby(**mlb_data.data)
2052+
return HomeRunDerby(**mlb_data.data)
20532053

20542054

20552055
def get_team_stats(self, team_id: int, stats: list, groups: list, **params) -> dict:

mlbstatsapi/models/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .base import MLBBaseModel
2+
3+
__all__ = ["MLBBaseModel"]
Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,19 @@
1-
from dataclasses import dataclass, field
2-
from typing import Union, List
1+
from typing import List
2+
from pydantic import Field
3+
from mlbstatsapi.models.base import MLBBaseModel
34
from .attributes import AttendanceTotals, AttendanceRecords
45

5-
@dataclass
6-
class Attendance:
6+
7+
class Attendance(MLBBaseModel):
78
"""
89
A class to represent attendance.
10+
911
Attributes
1012
----------
11-
copyright : str
12-
Copyright message
1313
records : List[AttendanceRecords]
14-
List of attendance records
15-
aggregatetotals : AttendanceAggregateTotals
16-
Attendence aggregate total numbers for query
14+
List of attendance records.
15+
aggregate_totals : AttendanceTotals
16+
Attendance aggregate total numbers for query.
1717
"""
18-
aggregatetotals: Union[AttendanceTotals, dict]
19-
records: Union[List[AttendanceRecords], List[dict]] = field(default_factory=list)
20-
21-
def __post_init__(self):
22-
self.records = [AttendanceRecords(**record) for record in self.records if self.records]
23-
self.aggregatetotals = AttendanceTotals(**self.aggregatetotals)
18+
aggregate_totals: AttendanceTotals = Field(alias="aggregatetotals")
19+
records: List[AttendanceRecords] = []

0 commit comments

Comments
 (0)