Skip to content

Commit d7f2822

Browse files
committed
fixed predict with probs error
1 parent 841fb4d commit d7f2822

File tree

1 file changed

+2
-3
lines changed

1 file changed

+2
-3
lines changed

madewithml/predict.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from typing import Any, Dict, Iterable, List
33
from urllib.parse import urlparse
44

5+
import numpy as np
56
import pandas as pd
67
import ray
78
import torch
@@ -62,8 +63,6 @@ def predict_with_proba(
6263
"""
6364
preprocessor = predictor.get_preprocessor()
6465
z = predictor.predict(data=df)["predictions"]
65-
import numpy as np
66-
6766
y_prob = torch.tensor(np.stack(z)).softmax(dim=1).numpy()
6867
results = []
6968
for i, prob in enumerate(y_prob):
@@ -130,7 +129,7 @@ def predict(
130129

131130
# Predict
132131
sample_df = pd.DataFrame([{"title": title, "description": description, "tag": "other"}])
133-
results = predict_with_proba(df=sample_df, predictor=predictor, index_to_class=preprocessor.index_to_class)
132+
results = predict_with_proba(df=sample_df, predictor=predictor)
134133
logger.info(json.dumps(results, cls=NumpyEncoder, indent=2))
135134
return results
136135

0 commit comments

Comments
 (0)