Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/multicalibration/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -696,15 +696,15 @@ def plot_learning_curve(
"""
if not mcboost_model.early_stopping:
raise ValueError(
"Learning curve can only be plotted for models that have been trained with EARLY_STOPPING=True."
"Learning curve can only be plotted for models that have been trained with early_stopping=True."
)

performance_metrics = mcboost_model._performance_metrics
extra_evaluation_due_to_early_stopping = (
1
if (
mcboost_model.early_stopping
and len(mcboost_model.mr) < mcboost_model.NUM_ROUNDS
and len(mcboost_model.mr) < mcboost_model.num_rounds
)
else 0
)
Expand All @@ -713,8 +713,8 @@ def plot_learning_curve(
1
+ len(mcboost_model.mr)
+ extra_evaluation_due_to_early_stopping
+ mcboost_model.PATIENCE,
1 + mcboost_model.NUM_ROUNDS,
+ mcboost_model.patience,
1 + mcboost_model.num_rounds,
)
x_vals = np.arange(0, tot_num_rounds)
metric_names = [mcboost_model.early_stopping_score_func.name]
Expand Down
97 changes: 96 additions & 1 deletion tests/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import pandas as pd
import pytest

from multicalibration import metrics, plotting
from multicalibration import methods, metrics, plotting


@pytest.fixture
Expand Down Expand Up @@ -436,6 +436,101 @@ def test_plot_calibration_curve_by_segment_empty_data():
assert fig is not None


def test_plot_learning_curve_with_early_stopping(rng):
    """An early-stopped MCBoost model should produce a learning-curve figure."""
    sample_count = 200
    # Build the training frame inline; rng call order (rand, randint, choice)
    # matches dict insertion order, keeping the random stream identical.
    df = pd.DataFrame(
        {
            "prediction": rng.rand(sample_count),
            "label": rng.randint(0, 2, sample_count),
            "feature": rng.choice(["a", "b", "c"], sample_count),
        }
    )

    model = methods.MCBoost(
        num_rounds=3,
        early_stopping=True,
        patience=1,
        lightgbm_params={"max_depth": 2, "n_estimators": 2},
    )
    model.fit(
        df_train=df,
        prediction_column_name="prediction",
        label_column_name="label",
        categorical_feature_column_names=["feature"],
    )

    # The plotting helper must return a (non-None) matplotlib figure.
    assert plotting.plot_learning_curve(model) is not None


def test_plot_learning_curve_raises_without_early_stopping(rng):
    """Requesting a learning curve from a model trained without early
    stopping must raise a ValueError with an explanatory message."""
    sample_count = 100
    # rng call order (rand, randint, choice) mirrors the fixture's stream.
    df = pd.DataFrame(
        {
            "prediction": rng.rand(sample_count),
            "label": rng.randint(0, 2, sample_count),
            "feature": rng.choice(["a", "b"], sample_count),
        }
    )

    model = methods.MCBoost(
        num_rounds=2,
        early_stopping=False,
        lightgbm_params={"max_depth": 2, "n_estimators": 2},
    )
    model.fit(
        df_train=df,
        prediction_column_name="prediction",
        label_column_name="label",
        categorical_feature_column_names=["feature"],
    )

    expected_message = (
        "Learning curve can only be plotted for models that have been trained with early_stopping=True"
    )
    with pytest.raises(ValueError, match=expected_message):
        plotting.plot_learning_curve(model)


def test_plot_learning_curve_with_show_all(rng):
    """With saved training performance, show_all=True should still yield a figure."""
    sample_count = 200
    # Column data generated in the same rng order as the sibling tests
    # (rand, randint, choice) so the random stream is unchanged.
    df = pd.DataFrame(
        {
            "prediction": rng.rand(sample_count),
            "label": rng.randint(0, 2, sample_count),
            "feature": rng.choice(["a", "b", "c"], sample_count),
        }
    )

    # save_training_performance=True records the extra metrics that
    # show_all=True is expected to render.
    model = methods.MCBoost(
        num_rounds=3,
        early_stopping=True,
        patience=1,
        save_training_performance=True,
        lightgbm_params={"max_depth": 2, "n_estimators": 2},
    )
    model.fit(
        df_train=df,
        prediction_column_name="prediction",
        label_column_name="label",
        categorical_feature_column_names=["feature"],
    )

    assert plotting.plot_learning_curve(model, show_all=True) is not None


def test_plot_score_distribution_does_not_modify_input_dataframe(sample_data):
fig, ax = plt.subplots()
df = pd.DataFrame({"col": [1, 2, 3]})
Expand Down