22 changes: 14 additions & 8 deletions src/multicalibration/methods.py
@@ -31,6 +31,8 @@
 
 logger: logging.Logger = logging.getLogger(__name__)
 
+# @oss-disable[end= ]: from multicalibration._compat import DeprecatedAttributesMixin
+
 
 @dataclass(frozen=True, slots=True)
 class MCBoostProcessedData:
@@ -60,7 +62,11 @@ class EstimationMethod(Enum):
     AUTO = 3
 
 
-class BaseMCBoost(BaseCalibrator, ABC):
+class BaseMCBoost(
+    # @oss-disable[end= ]: DeprecatedAttributesMixin,
+    BaseCalibrator,
+    ABC,
+):
     """
     Abstract base class for MCBoost models. This class hosts the common functionality for all MCBoost models and defines
     an abstract interface that all MCBoost models must implement.
@@ -218,17 +224,17 @@ def __init__(
         self._set_lightgbm_params(lightgbm_params)
 
         self.encode_categorical_variables = encode_categorical_variables
-        self.MONOTONE_T: bool = (
+        self.monotone_t: bool = (
             self.DEFAULT_HYPERPARAMS["monotone_t"] if monotone_t is None else monotone_t
         )
 
-        self.EARLY_STOPPING: bool = (
+        self.early_stopping: bool = (
             self.DEFAULT_HYPERPARAMS["early_stopping"]
             if early_stopping is None
             else early_stopping
         )
 
-        if not self.EARLY_STOPPING:
+        if not self.early_stopping:
             if patience is not None:
                 raise ValueError(
                     "`patience` must be None when argument `early_stopping` is disabled."
@@ -263,7 +269,7 @@ def __init__(
                 )
 
         if num_rounds is None:
-            if self.EARLY_STOPPING:
+            if self.early_stopping:
                 num_rounds = self.MAX_NUM_ROUNDS_EARLY_STOPPING
             else:
                 num_rounds = self.NUM_ROUNDS_DEFAULT_NO_EARLY_STOPPING
@@ -575,7 +581,7 @@ def fit(
             preprocessed_val_data = None
 
         num_rounds = self.NUM_ROUNDS
-        if self.EARLY_STOPPING:
+        if self.early_stopping:
            timeout_msg = (
                 f" (timeout: {self.EARLY_STOPPING_TIMEOUT}s)"
                 if self.EARLY_STOPPING_TIMEOUT
@@ -797,7 +803,7 @@ def _predict(
 
     def _get_lgbm_params(self, x: npt.NDArray) -> dict[str, Any]:
         lgb_params = self.lightgbm_params.copy()
-        if self.MONOTONE_T:
+        if self.monotone_t:
             score_constraint = [1]
             segment_feature_constraints = [0] * (x.shape[1] - 1)
             lgb_params["monotone_constraints"] = (
@@ -955,7 +961,7 @@ def _determine_best_num_rounds(
             if fold_num not in mcboost_per_fold:
                 mcboost = self._create_instance_for_cv(
                     encode_categorical_variables=self.encode_categorical_variables,
-                    monotone_t=self.MONOTONE_T,
+                    monotone_t=self.monotone_t,
                     lightgbm_params=self.lightgbm_params,
                     early_stopping=False,
                     num_rounds=0,
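Beyond the lowercase rename, the only functional addition in methods.py is the oss-disabled import and base class DeprecatedAttributesMixin from multicalibration._compat, which presumably keeps the old uppercase attribute names readable for internal callers while OSS builds strip both lines. The _compat module is not part of this diff, so the snippet below is only a minimal sketch of what such a mixin could do; the attribute mapping, the __getattr__ approach, and the warning text are assumptions, not the actual implementation.

# Illustrative sketch only -- the real DeprecatedAttributesMixin lives in the
# non-OSS multicalibration._compat module and is not shown in this PR.
import warnings


class DeprecatedAttributesMixin:
    # Hypothetical mapping from the old uppercase attribute names to the new ones.
    _DEPRECATED_ATTRIBUTES: dict[str, str] = {
        "EARLY_STOPPING": "early_stopping",
        "MONOTONE_T": "monotone_t",
    }

    def __getattr__(self, name: str):
        # __getattr__ only runs when normal lookup fails, i.e. for the old names.
        new_name = self._DEPRECATED_ATTRIBUTES.get(name)
        if new_name is None:
            raise AttributeError(name)
        warnings.warn(
            f"{name} is deprecated; use {new_name} instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        return getattr(self, new_name)


class _Demo(DeprecatedAttributesMixin):
    def __init__(self) -> None:
        self.early_stopping = True


assert _Demo().EARLY_STOPPING is True  # old spelling still resolves, with a warning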
4 changes: 2 additions & 2 deletions src/multicalibration/plotting.py
@@ -695,7 +695,7 @@ def plot_learning_curve(
     :param show_all: Whether to show all metrics in the learning curve. If False, only the metric specified in MCBoost's early_stopping_score_func is shown.
     :returns: A Plotly Figure object representing the learning curve.
     """
-    if not mcboost_model.EARLY_STOPPING:
+    if not mcboost_model.early_stopping:
         raise ValueError(
             "Learning curve can only be plotted for models that have been trained with EARLY_STOPPING=True."
         )
@@ -704,7 +704,7 @@
     extra_evaluation_due_to_early_stopping = (
         1
         if (
-            mcboost_model.EARLY_STOPPING
+            mcboost_model.early_stopping
             and len(mcboost_model.mr) < mcboost_model.NUM_ROUNDS
         )
         else 0
2 changes: 1 addition & 1 deletion tests/test_methods.py
@@ -1516,7 +1516,7 @@ def dummy_score_func(
     effective_num_rounds = len(mcboost.mr)
     extra_evaluation_due_to_early_stopping = (
         1
-        if (mcboost.EARLY_STOPPING and effective_num_rounds < mcboost.NUM_ROUNDS)
+        if (mcboost.early_stopping and effective_num_rounds < mcboost.NUM_ROUNDS)
        else 0
     )
 
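Both plotting.py and the test above compute the same extra_evaluation_due_to_early_stopping quantity: the naming suggests that when early stopping halts training before NUM_ROUNDS, one evaluation beyond the rounds recorded in mr has already been scored. A self-contained restatement of that bookkeeping follows; the helper name and the asserts are illustrative only, not library code.

# Illustrative helper, not part of the library: mirrors the
# extra_evaluation_due_to_early_stopping expression used in plotting.py and the test.
def extra_evaluations(early_stopping: bool, rounds_trained: int, num_rounds: int) -> int:
    """One extra evaluation is scored when early stopping halted training early."""
    return 1 if early_stopping and rounds_trained < num_rounds else 0


assert extra_evaluations(True, 5, 10) == 1   # stopped early: one extra evaluation
assert extra_evaluations(True, 10, 10) == 0  # ran all rounds: nothing extra
assert extra_evaluations(False, 7, 10) == 0  # early stopping disabled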