Skip to content

Commit

Permalink
minor change in PyTorch Wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
JHoelli committed Aug 23, 2024
1 parent dcaf0e6 commit a6e7b00
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion TSInterpret/Models/PyTorchModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,10 @@ def predict(self, item) -> List:
else:
item = torch.from_numpy(item)
out = self.model(item.float())
y_pred = torch.nn.functional.softmax(out).detach().numpy()
if out.shape[-1]>1:
y_pred = torch.nn.functional.softmax(out).detach().numpy()
else:
y_pred=out.detach().numpy()
return y_pred

def load_model(self, path):
Expand Down

0 comments on commit a6e7b00

Please sign in to comment.