Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions ax/core/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -1986,6 +1986,56 @@ def to_df(self, omit_empty_columns: bool = True) -> pd.DataFrame:
df = df.loc[:, df.notnull().any()]
return df

def add_auxiliary_experiment(
self,
purpose: AuxiliaryExperimentPurpose,
auxiliary_experiment: AuxiliaryExperiment,
) -> None:
"""Add a (non-duplicated) auxiliary experiment to this experiment.

This method adds the auxiliary experiment as the first element in the list
of auxiliary experiments with the specified purpose. If the auxiliary is
already present, it is moved to the first position in the list.

Args:
purpose: The purpose of the auxiliary experiment.
auxiliary_experiment: The auxiliary experiment to add.
"""
if purpose not in self.auxiliary_experiments_by_purpose:
# if no aux experiment, make aux the first one
self.auxiliary_experiments_by_purpose[purpose] = [auxiliary_experiment]
return

# Add or move auxiliary_experiment to be the first element
# Adding to the first and use the order as a default tie-breaker when multiple
# auxiliary experiments are present but only one is going to be used.
self.auxiliary_experiments_by_purpose[purpose] = [auxiliary_experiment] + [
item
for item in self.auxiliary_experiments_by_purpose[purpose]
if item != auxiliary_experiment
]

def find_auxiliary_experiment_by_name(
self,
purpose: AuxiliaryExperimentPurpose,
auxiliary_experiment_name: str,
) -> AuxiliaryExperiment | None:
"""Find the aux experiment with the given name and purpose in the experiment.

Args:
purpose: The purpose of the aux experiment.
auxiliary_experiment_name: The name of the aux experiment.

Returns:
The aux experiment with the given name and purpose, or None if not found.
"""
if purpose not in self.auxiliary_experiments_by_purpose:
return None
for auxiliary_experiment in self.auxiliary_experiments_by_purpose[purpose]:
if auxiliary_experiment.experiment.name == auxiliary_experiment_name:
return auxiliary_experiment
return None

@property
def auxiliary_experiments_by_purpose_for_storage(
self,
Expand Down
58 changes: 58 additions & 0 deletions ax/core/tests/test_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -1653,6 +1653,64 @@ def test_experiment_with_aux_experiments(self) -> None:
B_auxiliary_experiment_1,
)

def test_auxiliary_experiment_operations(self) -> None:
"""Test the add_auxiliary_experiment method."""
# Create a base experiment
experiment = get_branin_experiment()

# Create an auxiliary experiment
aux_base_exp = get_branin_experiment()
aux_base_exp.name = "aux_exp"
aux_exp = AuxiliaryExperiment(experiment=aux_base_exp)

aux_exp_found = experiment.find_auxiliary_experiment_by_name(
purpose=AuxiliaryExperimentPurpose.PE_EXPERIMENT,
auxiliary_experiment_name="aux_exp",
)
self.assertIsNone(aux_exp_found)

# Add the auxiliary experiment
experiment.add_auxiliary_experiment(
purpose=AuxiliaryExperimentPurpose.PE_EXPERIMENT,
auxiliary_experiment=aux_exp,
)

# Verify it was added
self.assertEqual(
experiment.auxiliary_experiments_by_purpose[
AuxiliaryExperimentPurpose.PE_EXPERIMENT
][0],
aux_exp,
)

# Add the same auxiliary experiment again
experiment.add_auxiliary_experiment(
purpose=AuxiliaryExperimentPurpose.PE_EXPERIMENT,
auxiliary_experiment=aux_exp,
)

# Verify it wasn't duplicated (should still be just one)
self.assertEqual(
len(
experiment.auxiliary_experiments_by_purpose[
AuxiliaryExperimentPurpose.PE_EXPERIMENT
]
),
1,
)

aux_exp_found = experiment.find_auxiliary_experiment_by_name(
purpose=AuxiliaryExperimentPurpose.PE_EXPERIMENT,
auxiliary_experiment_name="aux_exp",
)
self.assertIs(aux_exp_found, aux_exp)

aux_exp_found = experiment.find_auxiliary_experiment_by_name(
purpose=AuxiliaryExperimentPurpose.BO_EXPERIMENT,
auxiliary_experiment_name="aux_exp",
)
self.assertIsNone(aux_exp_found)

def test_name_and_store_arm_if_not_exists_same_name_different_signature(
self,
) -> None:
Expand Down