Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions ax/early_stopping/strategies/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,19 @@ def _lookup_and_validate_data(
map_df = data.map_df
# keep only relevant metrics
map_df = map_df[map_df["metric_signature"].isin(metric_signatures)].copy()

# Drop rows with NaN values in MAP_KEY column to prevent issues in
# align_partial_results which uses MAP_KEY as the pivot index
nan_mask = map_df[MAP_KEY].isna()
if nan_mask.any():
num_nan_rows = nan_mask.sum()
nan_trial_indices = map_df.loc[nan_mask, "trial_index"].unique().tolist()
logger.warning(
f"Dropped {num_nan_rows} row(s) with NaN values in the progression "
f"column ('{MAP_KEY}') for trial(s) {nan_trial_indices}."
)
map_df = map_df[~nan_mask]

if self.normalize_progressions:
values = map_df[MAP_KEY].astype(float)
map_df[MAP_KEY] = values / values.abs().max()
Expand Down
51 changes: 51 additions & 0 deletions ax/early_stopping/tests/test_strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,57 @@ def test_early_stopping_strategy(self) -> None:
# `BaseEarlyStoppingStrategy`.
BaseEarlyStoppingStrategy()

def test_nan_map_key_values_dropped_with_warning(self) -> None:
"""Test that NaN values in MAP_KEY column are dropped with a warning."""
experiment = get_test_map_data_experiment(
num_trials=3, num_fetches=5, num_complete=3
)
es_strategy = FakeStrategy()
metric_signature, _ = es_strategy._default_objective_and_direction(
experiment=experiment
)

# Get the data and introduce NaN values in MAP_KEY column
data = assert_is_instance(experiment.lookup_data(), MapData)
modified_df = data.map_df.copy()

# Set some MAP_KEY values to NaN for specific trials
# This simulates corrupted or missing progression data
# Use metric_signature to match the filter in _lookup_and_validate_data
trial_0_mask = (modified_df["trial_index"] == 0) & (
modified_df["metric_signature"] == metric_signature
)
# Set rows at the first index where trial_0_mask is True to have NaN in MAP_KEY
first_trial_0_idx = modified_df.loc[trial_0_mask].index[0]
modified_df.loc[first_trial_0_idx, MAP_KEY] = float("nan")

# Attach modified data with NaN values
modified_data = MapData(df=modified_df)
experiment.attach_data(data=modified_data)

# Verify warning is logged when NaN values are dropped
with patch.object(logger, "warning") as mock_warning:
result = es_strategy._lookup_and_validate_data(
experiment, metric_signatures=[metric_signature]
)

# Verify warning was called with appropriate message
mock_warning.assert_called_once()

warning_msg, *_ = mock_warning.call_args.args
self.assertRegex(
warning_msg,
r"Dropped 1 row\(s\) with NaN values in the progression column "
rf"\('{MAP_KEY}'\) for trial\(s\) \[0\]\.",
)

# Verify result is not None and NaN rows are dropped
self.assertIsNotNone(result)
result = assert_is_instance(result, MapData)

# Verify no NaN values remain in MAP_KEY column
self.assertFalse(result.map_df[MAP_KEY].isna().any())

def test_all_objectives_and_directions_raises_error_when_lower_is_better_is_none(
self,
) -> None:
Expand Down