Skip to content

Commit

Permalink
dont plot training scores for AUC in CatBoost (mljar#360)
Browse files Browse the repository at this point in the history
  • Loading branch information
pplonski committed Mar 31, 2021
1 parent ff82af8 commit c461107
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 19 deletions.
33 changes: 21 additions & 12 deletions supervised/algorithms/catboost.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,23 +230,30 @@ def fit(
self.best_ntree_limit = self.model.tree_count_

if log_to_file is not None:
train_scores = self.model.evals_result_["learn"][self.log_metric_name]
validation_scores = self.model.evals_result_["validation"][
train_scores = self.model.evals_result_["learn"].get(self.log_metric_name)
validation_scores = self.model.evals_result_["validation"].get(
self.log_metric_name
]
)
if model_init is not None:
train_scores = (
model_init.evals_result_["learn"][self.log_metric_name]
+ train_scores
)
validation_scores = (
model_init.evals_result_["validation"][self.log_metric_name]
+ validation_scores
)
if train_scores is not None:
train_scores = (
model_init.evals_result_["learn"].get(self.log_metric_name)
+ train_scores
)
if validation_scores is not None:
validation_scores = (
model_init.evals_result_["validation"].get(self.log_metric_name)
+ validation_scores
)
iteration = None
if train_scores is not None:
iteration = range(len(validation_scores))
elif validation_scores is not None:
iteration = range(len(validation_scores))

result = pd.DataFrame(
{
"iteration": range(len(train_scores)),
"iteration": iteration,
"train": train_scores,
"validation": validation_scores,
}
Expand Down Expand Up @@ -295,6 +302,8 @@ def get_metric_name(self):
return None
if metric == "Logloss":
return "logloss"
elif metric == "AUC":
return "auc"
elif metric == "MultiClass":
return "logloss"
elif metric == "RMSE":
Expand Down
16 changes: 9 additions & 7 deletions supervised/utils/learning_curves.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,13 +88,15 @@ def plot_iterations(
repeat_str = f" Reapeat {repeat+1}," if repeat is not None else ""
# if trees_in_iteration is not None:
# df.iteration = df.iteration * trees_in_iteration
plt.plot(
df.iteration,
df.train,
"--",
color=colors[fold],
label=f"Fold {fold+1},{repeat_str} train",
)
any_none = np.sum(pd.isnull(df.train))
if any_none == 0:
plt.plot(
df.iteration,
df.train,
"--",
color=colors[fold],
label=f"Fold {fold+1},{repeat_str} train",
)
any_none = np.sum(pd.isnull(df.test))
if any_none == 0:
plt.plot(
Expand Down

0 comments on commit c461107

Please sign in to comment.