diff --git a/TSInterpret/Models/PyTorchModel.py b/TSInterpret/Models/PyTorchModel.py index 4bb647c..0514113 100644 --- a/TSInterpret/Models/PyTorchModel.py +++ b/TSInterpret/Models/PyTorchModel.py @@ -30,7 +30,10 @@ def predict(self, item) -> List: else: item = torch.from_numpy(item) out = self.model(item.float()) - y_pred = torch.nn.functional.softmax(out).detach().numpy() + if out.shape[-1]>1: + y_pred = torch.nn.functional.softmax(out).detach().numpy() + else: + y_pred=out.detach().numpy() return y_pred def load_model(self, path):