Skip to content

Commit

Permalink
Bugfix/location order (#349)
Browse files Browse the repository at this point in the history
Bugfix/location order

Fixes a bug in the scrub_location function that caused locations not to be mapped to the correct location IDs.
- *Category*: Bugfix
- *JIRA issue*: https://jira.ihme.washington.edu/browse/MIC-5107

Changes and notes
- Fixes a bug where the order of locations was out of sync with the data; the function now maps location names to location IDs.
- Hotfix to pin numpy below 2.0.

Testing
All tests pass
  • Loading branch information
albrja authored Jun 17, 2024
1 parent edb21c3 commit e77fe0b
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 6 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
**5.0.1 - 06/13/24**

- Fix bug in scrub_location

**5.0.0 - 05/20/24**

- Pull GBD 2021 data
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
long_description = f.read()

install_requirements = [
"numpy",
"numpy<2.0.0",
"scipy",
"pandas",
"click",
Expand Down
8 changes: 8 additions & 0 deletions src/vivarium_inputs/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,12 @@ def get_relative_risk(
location_id: List[int],
years: Optional[Union[int, str, List[int]]] = None,
) -> pd.DataFrame:
if len(set(location_id)) > 1:
raise ValueError(
"Extracting relative risk only supports one location at a time. Provided "
f"{location_id}."
)

data = extract.extract_data(
entity,
"relative_risk",
Expand Down Expand Up @@ -523,6 +529,8 @@ def get_relative_risk(
data.loc[tmrel_mask, DRAW_COLUMNS] = data.loc[tmrel_mask, DRAW_COLUMNS].mask(
np.isclose(data.loc[tmrel_mask, DRAW_COLUMNS], 1.0), 1.0
)
# Coerce location_id from global to requested location - location_id is list of length 1
data["location_id"] = location_id[0]

return data

Expand Down
15 changes: 10 additions & 5 deletions src/vivarium_inputs/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,17 +35,22 @@ def scrub_location(data: pd.DataFrame, location: Union[int, List[str]]) -> pd.Da
# Coerce location names
if not isinstance(location, list):
location = [location]
location = [
location_names = [
utility_data.get_location_name(loc) if isinstance(loc, int) else loc
for loc in location
]
location_dict = {
utility_data.get_location_id(loc_name): loc_name for loc_name in location_names
}

if "location_id" in data.index.names:
data.index = data.index.rename("location", level="location_id").set_levels(
location, level="location"
)
index_cols = data.index.names
data = data.reset_index()
data["location_id"] = data["location_id"].map(location_dict)
data = data.set_index(index_cols)
data.index = data.index.rename("location", level="location_id")
else:
data = pd.concat([data], keys=location, names=["location"])
data = pd.concat([data], keys=list(location_names), names=["location"])
return data


Expand Down

0 comments on commit e77fe0b

Please sign in to comment.