Skip to content

Commit

Permalink
Added normalization for predictions.
Browse files Browse the repository at this point in the history
  • Loading branch information
karannb committed Nov 10, 2024
1 parent 5afeb6a commit 90e47c9
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions aviary/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import wandb.apis.public
from torch.utils.data import DataLoader

from aviary.core import BaseModelClass
from aviary.core import BaseModelClass, Normalizer
from aviary.data import InMemoryDataLoader

__author__ = "Janosh Riebesell"
Expand Down Expand Up @@ -90,13 +90,20 @@ def make_ensemble_predictions(
model = model_cls(**model_params)
model.to(device)

model.load_state_dict(checkpoint["model_state"])
# some models save the state dict under a different key
state_dict_field = "model_state" if "model_state" in checkpoint else "state_dict"
model.load_state_dict(checkpoint[state_dict_field])

with torch.no_grad():
preds = np.concatenate(
[model(*inputs)[0].cpu().numpy() for inputs, *_ in data_loader]
).squeeze()

# denormalize predictions if a normalizer was used during training
if "normalizer_dict" in checkpoint:
normalizer = Normalizer.from_state_dict(checkpoint["normalizer_dict"][target_name])
preds = normalizer.denorm(preds)

pred_col = f"{target_col}_pred_{idx}" if target_col else f"pred_{idx}"

if model.robust:
Expand Down

0 comments on commit 90e47c9

Please sign in to comment.