Skip to content

Commit

Permalink
ROC curves for binary classification tasks (mljar#158)
Browse files Browse the repository at this point in the history
  • Loading branch information
pplonski committed Apr 15, 2021
1 parent 0a9ee65 commit 2ea3e37
Showing 1 changed file with 15 additions and 12 deletions.
27 changes: 15 additions & 12 deletions supervised/utils/additional_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,15 +325,16 @@ def save_binary_classification(
AdditionalMetrics.add_permutation_importance(
fout, model_path, fold_cnt, repeat_cnt
)
AdditionalMetrics.add_shap_importance(
fout, model_path, fold_cnt, repeat_cnt
)
AdditionalMetrics.add_shap_binary(fout, model_path, fold_cnt, repeat_cnt)

plots = additional_metrics.get("additional_plots")
if plots is not None:
AdditionalPlots.append(fout, model_path, plots)

AdditionalMetrics.add_shap_importance(
fout, model_path, fold_cnt, repeat_cnt
)
AdditionalMetrics.add_shap_binary(fout, model_path, fold_cnt, repeat_cnt)

fout.write("\n\n[<< Go back](../README.md)\n")

@staticmethod
Expand All @@ -355,17 +356,18 @@ def save_multiclass_classification(
AdditionalMetrics.add_permutation_importance(
fout, model_path, fold_cnt, repeat_cnt
)

plots = additional_metrics.get("additional_plots")
if plots is not None:
AdditionalPlots.append(fout, model_path, plots)

AdditionalMetrics.add_shap_importance(
fout, model_path, fold_cnt, repeat_cnt
)
AdditionalMetrics.add_shap_multiclass(
fout, model_path, fold_cnt, repeat_cnt
)

plots = additional_metrics.get("additional_plots")
if plots is not None:
AdditionalPlots.append(fout, model_path, plots)

fout.write("\n\n[<< Go back](../README.md)\n")

@staticmethod
Expand All @@ -386,17 +388,18 @@ def save_regression(
AdditionalMetrics.add_permutation_importance(
fout, model_path, fold_cnt, repeat_cnt
)

plots = additional_metrics.get("additional_plots")
if plots is not None:
AdditionalPlots.append(fout, model_path, plots)

AdditionalMetrics.add_shap_importance(
fout, model_path, fold_cnt, repeat_cnt
)
AdditionalMetrics.add_shap_regression(
fout, model_path, fold_cnt, repeat_cnt
)

plots = additional_metrics.get("additional_plots")
if plots is not None:
AdditionalPlots.append(fout, model_path, plots)

fout.write("\n\n[<< Go back](../README.md)\n")

@staticmethod
Expand Down

0 comments on commit 2ea3e37

Please sign in to comment.