diff --git a/aviary/predict.py b/aviary/predict.py index f50d1e2..fbe0ad8 100644 --- a/aviary/predict.py +++ b/aviary/predict.py @@ -15,7 +15,7 @@ import wandb.apis.public from torch.utils.data import DataLoader - from aviary.core import BaseModelClass + from aviary.core import BaseModelClass, Normalizer from aviary.data import InMemoryDataLoader __author__ = "Janosh Riebesell" @@ -90,13 +90,20 @@ def make_ensemble_predictions( model = model_cls(**model_params) model.to(device) - model.load_state_dict(checkpoint["model_state"]) + # some models save the state dict under a different key + state_dict_field = "model_state" if "model_state" in checkpoint else "state_dict" + model.load_state_dict(checkpoint[state_dict_field]) with torch.no_grad(): preds = np.concatenate( [model(*inputs)[0].cpu().numpy() for inputs, *_ in data_loader] ).squeeze() + # denormalize predictions if a normalizer was used during training + if "normalizer_dict" in checkpoint: + normalizer = Normalizer.from_state_dict(checkpoint["normalizer_dict"][target_name]) + preds = normalizer.denorm(preds) + pred_col = f"{target_col}_pred_{idx}" if target_col else f"pred_{idx}" if model.robust: