Skip to content

Commit

Permalink
fix bug
Browse files Browse the repository at this point in the history
  • Loading branch information
chenyangkang committed Nov 19, 2024
1 parent efbc168 commit d9570b3
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 10 deletions.
27 changes: 18 additions & 9 deletions stemflow/model/AdaSTEM.py
Original file line number Diff line number Diff line change
Expand Up @@ -889,9 +889,16 @@ def predict_proba(
warnings.warn(f"There are {nan_frac}% points ({nan_count} points) falling out of predictable range.")

if return_std:
return np.array([1-new_res["pred_mean"].values.flatten(), new_res["pred_mean"].values.flatten()]).T, new_res["pred_std"].values
if self.task=='classification':
return np.array([1-new_res["pred_mean"].values.flatten(), new_res["pred_mean"].values.flatten()]).T, new_res["pred_std"].values
else:
return new_res["pred_mean"].values.reshape(-1,1), new_res["pred_std"].values
else:
return np.array([1-new_res["pred_mean"].values.flatten(), new_res["pred_mean"].values.flatten()]).T
if self.task=='classification':
return np.array([1-new_res["pred_mean"].values.flatten(), new_res["pred_mean"].values.flatten()]).T
else:
return new_res["pred_mean"].values.reshape(-1,1)


@abstractmethod
def predict(
Expand Down Expand Up @@ -1406,7 +1413,9 @@ def predict(
predicted results. (pred_mean, pred_std) if return_std==true, and pred_mean if return_std==False.
"""

if return_by_separate_ensembles!=False:
raise AttributeError('If you want to return by separate ensembles in this classifier, use it in .predict_proba, instead of .predict.')

if return_std:
mean, std = self.predict_proba(
X_test,
Expand All @@ -1420,7 +1429,8 @@ def predict(
mean = mean[:,1]
mean = np.where(mean < cls_threshold, 0, mean)
mean = np.where(mean >= cls_threshold, 1, mean)
return mean, std
warnings.warn('This is a classification task. The standard deviation of the prediction is output at logit scale! The mean prediction is output at probability scale.')
return mean, std # notice! the std
else:
mean = self.predict_proba(
X_test,
Expand Down Expand Up @@ -1588,8 +1598,7 @@ def predict(
**base_model_prediction_param
)

if return_by_separate_ensembles:
return prediciton
else:
return prediciton[:,1] # the prediciton[:,0] won't make any sense -- it is a regressor, is should not have a method called "predict_proba". But we have it for completeness. So we should definately remove the first column when the prediction is done.

# if return_by_separate_ensembles, this will be the dataframe for ensemble
# if return_std, this wil be a tuple of mean and std of prediction
# if none of these, then it ill output the mean prediction
return prediciton
13 changes: 12 additions & 1 deletion stemflow/model/static_func_AdaSTEM.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,11 @@ def train_one_stixel(
sample_weights = class_weight.compute_sample_weight(
class_weight="balanced", y=np.where(sub_y_train > 0, 1, 0)
).astype('float32')
class_weights = class_weight.compute_class_weight(
class_weight="balanced", classes=np.array([0,1]), y=np.where(sub_y_train > 0, 1, 0)
).astype('float32')
trained_model.fit(sub_X_train[stixel_specific_x_names], sub_y_train, sample_weight=sample_weights)
trained_model.my_class_weights = class_weights

# try:
# trained_model.fit(sub_X_train[stixel_specific_x_names], sub_y_train, sample_weight=sample_weights)
Expand Down Expand Up @@ -466,7 +470,14 @@ def predict_one_stixel(
if task == "regression":
pred = model_x_names_tuple[0].predict(X_test_stixel[model_x_names_tuple[1]])
else:
pred = model_x_names_tuple[0].predict_proba(X_test_stixel[model_x_names_tuple[1]], **base_model_prediction_param)[:, 1]
pred = model_x_names_tuple[0].predict_proba(X_test_stixel[model_x_names_tuple[1]], **base_model_prediction_param)
if hasattr(model_x_names_tuple[0], 'my_class_weights'):
pred_r = pred * model_x_names_tuple[0].my_class_weights
pred_r = (pred_r / np.sum(pred_r, axis=1)[:,np.newaxis])
pred = pred_r

pred = pred[:,1]


res = pd.DataFrame({"index": list(X_test_stixel.index), "pred": np.array(pred).flatten()}).set_index("index")

Expand Down

0 comments on commit d9570b3

Please sign in to comment.