58 changes: 32 additions & 26 deletions src/multicalibration/methods.py
@@ -31,6 +31,8 @@

logger: logging.Logger = logging.getLogger(__name__)

+# @oss-disable[end= ]: from multicalibration._compat import DeprecatedAttributesMixin
+

@dataclass(frozen=True, slots=True)
class MCBoostProcessedData:
@@ -60,7 +62,11 @@ class EstimationMethod(Enum):
AUTO = 3


-class BaseMCBoost(BaseCalibrator, ABC):
+class BaseMCBoost(
+    # @oss-disable[end= ]: DeprecatedAttributesMixin,
+    BaseCalibrator,
+    ABC,
+):
"""
Abstract base class for MCBoost models. This class hosts the common functionality for all MCBoost models and defines
an abstract interface that all MCBoost models must implement.
@@ -218,17 +224,17 @@ def __init__(
self._set_lightgbm_params(lightgbm_params)

self.encode_categorical_variables = encode_categorical_variables
-self.MONOTONE_T: bool = (
+self.monotone_t: bool = (
self.DEFAULT_HYPERPARAMS["monotone_t"] if monotone_t is None else monotone_t
)

-self.EARLY_STOPPING: bool = (
+self.early_stopping: bool = (
self.DEFAULT_HYPERPARAMS["early_stopping"]
if early_stopping is None
else early_stopping
)

-if not self.EARLY_STOPPING:
+if not self.early_stopping:
if patience is not None:
raise ValueError(
"`patience` must be None when argument `early_stopping` is disabled."
@@ -248,22 +254,22 @@ def __init__(
# Override the timeout when early stopping is disabled
early_stopping_timeout = None

-self.EARLY_STOPPING_ESTIMATION_METHOD: EstimationMethod
+self.early_stopping_estimation_method: EstimationMethod
if early_stopping_use_crossvalidation is True:
-self.EARLY_STOPPING_ESTIMATION_METHOD = EstimationMethod.CROSS_VALIDATION
+self.early_stopping_estimation_method = EstimationMethod.CROSS_VALIDATION
elif early_stopping_use_crossvalidation is None:
-self.EARLY_STOPPING_ESTIMATION_METHOD = EstimationMethod.AUTO
+self.early_stopping_estimation_method = EstimationMethod.AUTO
else:
-self.EARLY_STOPPING_ESTIMATION_METHOD = EstimationMethod.HOLDOUT
+self.early_stopping_estimation_method = EstimationMethod.HOLDOUT

-if self.EARLY_STOPPING_ESTIMATION_METHOD == EstimationMethod.HOLDOUT:
+if self.early_stopping_estimation_method == EstimationMethod.HOLDOUT:
if n_folds is not None:
raise ValueError(
"`n_folds` must be None when `early_stopping_use_crossvalidation` is disabled."
)

if num_rounds is None:
-if self.EARLY_STOPPING:
+if self.early_stopping:
num_rounds = self.MAX_NUM_ROUNDS_EARLY_STOPPING
else:
num_rounds = self.NUM_ROUNDS_DEFAULT_NO_EARLY_STOPPING
@@ -274,11 +280,11 @@ def __init__(
self.DEFAULT_HYPERPARAMS["patience"] if patience is None else patience
)

-self.EARLY_STOPPING_TIMEOUT: int | None = early_stopping_timeout
+self.early_stopping_timeout: int | None = early_stopping_timeout

-self.N_FOLDS: int = (
+self.n_folds: int = (
1 # Because we make a single train/test split when using holdout
-if (self.EARLY_STOPPING_ESTIMATION_METHOD == EstimationMethod.HOLDOUT)
+if (self.early_stopping_estimation_method == EstimationMethod.HOLDOUT)
else self.DEFAULT_HYPERPARAMS["n_folds"]
if n_folds is None
else n_folds
@@ -575,10 +581,10 @@ def fit(
preprocessed_val_data = None

num_rounds = self.NUM_ROUNDS
-if self.EARLY_STOPPING:
+if self.early_stopping:
timeout_msg = (
f" (timeout: {self.EARLY_STOPPING_TIMEOUT}s)"
if self.EARLY_STOPPING_TIMEOUT
f" (timeout: {self.early_stopping_timeout}s)"
if self.early_stopping_timeout
else ""
)
logger.info(
@@ -797,7 +803,7 @@ def _predict(

def _get_lgbm_params(self, x: npt.NDArray) -> dict[str, Any]:
lgb_params = self.lightgbm_params.copy()
-if self.MONOTONE_T:
+if self.monotone_t:
score_constraint = [1]
segment_feature_constraints = [0] * (x.shape[1] - 1)
lgb_params["monotone_constraints"] = (
@@ -879,7 +885,7 @@ def _determine_n_folds(
estimation_method: EstimationMethod,
) -> int:
if estimation_method == EstimationMethod.CROSS_VALIDATION:
-n_folds = self.N_FOLDS
+n_folds = self.n_folds
logger.info(f"Using {n_folds} folds for cross-validation.")
else:
n_folds = 1
@@ -919,11 +925,11 @@ def _determine_best_num_rounds(
log_add = " (input prediction for early stopping baseline)"
logger.info(f"Evaluating round {num_rounds}{log_add}")

-if self.EARLY_STOPPING_TIMEOUT is not None and self._get_elapsed_time(
+if self.early_stopping_timeout is not None and self._get_elapsed_time(
start_time
-) > cast(int, self.EARLY_STOPPING_TIMEOUT):
+) > cast(int, self.early_stopping_timeout):
logger.warning(
f"Stopping early stopping upon exceeding the {self.EARLY_STOPPING_TIMEOUT:,}-second timeout; "
f"Stopping early stopping upon exceeding the {self.early_stopping_timeout:,}-second timeout; "
+ "MCBoost results will likely improve by increasing `early_stopping_timeout` or setting it to None"
)
break
@@ -955,7 +961,7 @@ def _determine_best_num_rounds(
if fold_num not in mcboost_per_fold:
mcboost = self._create_instance_for_cv(
encode_categorical_variables=self.encode_categorical_variables,
-monotone_t=self.MONOTONE_T,
+monotone_t=self.monotone_t,
lightgbm_params=self.lightgbm_params,
early_stopping=False,
num_rounds=0,
@@ -1201,8 +1207,8 @@ def _determine_estimation_method(self, weights: npt.NDArray) -> EstimationMethod

:return: the estimation method to use.
"""
-if self.EARLY_STOPPING_ESTIMATION_METHOD != EstimationMethod.AUTO:
-return self.EARLY_STOPPING_ESTIMATION_METHOD
+if self.early_stopping_estimation_method != EstimationMethod.AUTO:
+return self.early_stopping_estimation_method

if self.early_stopping_score_func.name != "log_loss":
# Automatically infer the estimation method only when using the logistic loss, otherwise use k-fold.
@@ -1330,7 +1336,7 @@ def _check_labels(df_train: pd.DataFrame, label_column_name: str) -> None:
@property
def _cv_splitter(self) -> StratifiedKFold:
return StratifiedKFold(
-n_splits=self.N_FOLDS,
+n_splits=self.n_folds,
shuffle=True,
random_state=self._next_seed(),
)
@@ -1446,7 +1452,7 @@ def _check_labels(df_train: pd.DataFrame, label_column_name: str) -> None:
@property
def _cv_splitter(self) -> KFold:
return KFold(
-n_splits=self.N_FOLDS,
+n_splits=self.n_folds,
shuffle=True,
random_state=self._next_seed(),
)
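A note on the `DeprecatedAttributesMixin` referenced above: it is stripped from OSS builds (`@oss-disable`), so its implementation is not part of this diff. As a minimal, hypothetical sketch of the pattern such a mixin usually follows — forwarding the retired SCREAMING_CASE attributes to their snake_case replacements with a warning — consider (all names below are illustrative, not the internal implementation):

    import warnings

    class DeprecatedAttributesShim:
        """Hypothetical stand-in for the oss-disabled DeprecatedAttributesMixin."""

        # Old attribute name -> new attribute name (pairs taken from this diff).
        _DEPRECATED_ALIASES = {
            "MONOTONE_T": "monotone_t",
            "EARLY_STOPPING": "early_stopping",
            "EARLY_STOPPING_TIMEOUT": "early_stopping_timeout",
            "EARLY_STOPPING_ESTIMATION_METHOD": "early_stopping_estimation_method",
            "N_FOLDS": "n_folds",
        }

        def __getattr__(self, name: str):
            # Only reached when normal attribute lookup fails, i.e. for retired names.
            new_name = self._DEPRECATED_ALIASES.get(name)
            if new_name is not None:
                warnings.warn(
                    f"`{name}` is deprecated; use `{new_name}` instead.",
                    DeprecationWarning,
                    stacklevel=2,
                )
                return getattr(self, new_name)
            raise AttributeError(name)

With such a shim among the base classes, existing callers that still read `model.EARLY_STOPPING` keep working while emitting a `DeprecationWarning`.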
4 changes: 2 additions & 2 deletions src/multicalibration/plotting.py
@@ -695,7 +695,7 @@ def plot_learning_curve(
:param show_all: Whether to show all metrics in the learning curve. If False, only the metric specified in MCBoost's early_stopping_score_func is shown.
:returns: A Plotly Figure object representing the learning curve.
"""
-if not mcboost_model.EARLY_STOPPING:
+if not mcboost_model.early_stopping:
raise ValueError(
"Learning curve can only be plotted for models that have been trained with EARLY_STOPPING=True."
)
@@ -704,7 +704,7 @@
extra_evaluation_due_to_early_stopping = (
1
if (
-mcboost_model.EARLY_STOPPING
+mcboost_model.early_stopping
and len(mcboost_model.mr) < mcboost_model.NUM_ROUNDS
)
else 0
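The `extra_evaluation_due_to_early_stopping` expression above encodes an off-by-one that is easy to miss: when early stopping halts training before the configured maximum, the learning curve carries one extra evaluation point for the input-prediction baseline. A self-contained sketch of that logic, with illustrative values rather than a fitted model:

    # Illustrative values; in the library these come from the fitted model.
    num_rounds = 100       # mcboost_model.NUM_ROUNDS: the configured maximum
    effective_rounds = 37  # len(mcboost_model.mr): rounds actually completed
    early_stopping = True  # mcboost_model.early_stopping after this rename

    # One extra point is counted only when early stopping cut training short.
    extra_evaluation = 1 if (early_stopping and effective_rounds < num_rounds) else 0
    assert extra_evaluation == 1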
2 changes: 1 addition & 1 deletion src/multicalibration/tuning.py
@@ -167,7 +167,7 @@ def tune_mcboost_params(
assert df_val is not None

if (
-model.EARLY_STOPPING_ESTIMATION_METHOD
+model.early_stopping_estimation_method
== methods.EstimationMethod.CROSS_VALIDATION
and (pass_df_val_into_tuning or pass_df_val_into_final_fit)
):
4 changes: 2 additions & 2 deletions tests/test_methods.py
@@ -1516,7 +1516,7 @@ def dummy_score_func(
effective_num_rounds = len(mcboost.mr)
extra_evaluation_due_to_early_stopping = (
1
-if (mcboost.EARLY_STOPPING and effective_num_rounds < mcboost.NUM_ROUNDS)
+if (mcboost.early_stopping and effective_num_rounds < mcboost.NUM_ROUNDS)
else 0
)

@@ -2493,7 +2493,7 @@ def test_determine_n_folds_returns_correct_value(
# Assert: Verify correct n_folds is returned
# Special handling for CROSS_VALIDATION since N_FOLDS may be set differently
if estimation_method == methods.EstimationMethod.CROSS_VALIDATION:
-assert n_folds == model.N_FOLDS
+assert n_folds == model.n_folds
else:
assert n_folds == expected_n_folds

2 changes: 1 addition & 1 deletion tests/test_tuning.py
@@ -59,7 +59,7 @@ def sample_val_data(rng):
def mock_mcboost_model(rng):
model = Mock(spec=methods.MCBoost)
model.predict = Mock(return_value=rng.uniform(0.1, 0.9, 80))
-model.EARLY_STOPPING_ESTIMATION_METHOD = methods.EstimationMethod.HOLDOUT
+model.early_stopping_estimation_method = methods.EstimationMethod.HOLDOUT
return model


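The fixture change above is the test-side consequence of the rename: because the mock is built with `Mock(spec=methods.MCBoost)`, reading an attribute that no longer exists on the spec'd class raises `AttributeError`, so stale uppercase names fail loudly. A small illustration of that `unittest.mock` behavior, using a stand-in class rather than the real MCBoost:

    from unittest.mock import Mock

    class RenamedModel:
        """Stand-in exposing only the new snake_case attribute, like post-rename MCBoost."""
        early_stopping_estimation_method = None

    model = Mock(spec=RenamedModel)
    model.early_stopping_estimation_method = "holdout"  # fine: present on the spec

    try:
        model.EARLY_STOPPING_ESTIMATION_METHOD  # retired name, absent from the spec
    except AttributeError:
        print("old attribute name rejected by the spec'd mock")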