Skip to content

Commit

Permalink
cast to cpu before numpy()
Browse files Browse the repository at this point in the history
avoids error if trying to call .numpy() on tensor that is not on cpu
  • Loading branch information
sammlapp committed Nov 11, 2024
1 parent 028051a commit b8161d4
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions opensoundscape/ml/shallow_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,13 +141,15 @@ def quick_fit(
)
try:
auroc = roc_auc_score(
validation_labels.detach().numpy(), val_outputs.detach().numpy()
validation_labels.detach().cpu().numpy(),
val_outputs.detach().cpu().numpy(),
)
except:
auroc = float("nan")
try:
map = average_precision_score(
validation_labels.detach().numpy(), val_outputs.detach().numpy()
validation_labels.detach().cpu().numpy(),
val_outputs.detach().cpu().numpy(),
)
except:
map = float("nan")
Expand Down

0 comments on commit b8161d4

Please sign in to comment.