Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions ax/analysis/summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,14 @@
# pyre-strict


from typing import Iterable
from typing import Iterable, Sequence

from ax.adapter.base import Adapter

from ax.analysis.analysis import Analysis
from ax.analysis.analysis_card import AnalysisCard
from ax.core.experiment import Experiment
from ax.core.trial_status import TrialStatus
from ax.exceptions.core import UserInputError
from ax.generation_strategy.generation_strategy import GenerationStrategy
from pyre_extensions import override
Expand All @@ -36,16 +37,19 @@ class Summary(Analysis):
- **PARAMETER_NAME: The parameter value for the arm, for each parameter
Args:
trial_indices: If specified, only include these trial indices.
trial_status: If specified, only include trials whose status is one of these.
omit_empty_columns: If True, omit columns where every value is None.
"""

def __init__(
    self,
    trial_indices: Iterable[int] | None = None,
    trial_status: Sequence[TrialStatus] | None = None,
    omit_empty_columns: bool = True,
) -> None:
    """Store the filters that are applied when the summary is computed.

    Args:
        trial_indices: If specified, only include these trial indices.
        trial_status: If specified, only include trials whose status is
            one of these.
        omit_empty_columns: If True, omit columns where every value is None.
    """
    # Each option is stored exactly once; ``compute`` forwards all three to
    # ``experiment.to_df``.
    self.trial_indices = trial_indices
    self.trial_status = trial_status
    self.omit_empty_columns = omit_empty_columns

@override
def compute(
Expand All @@ -66,5 +70,6 @@ def compute(
df=experiment.to_df(
trial_indices=self.trial_indices,
omit_empty_columns=self.omit_empty_columns,
trial_status=self.trial_status,
),
)
55 changes: 55 additions & 0 deletions ax/analysis/tests/test_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from ax.analysis.summary import Summary
from ax.api.client import Client
from ax.api.configs import RangeParameterConfig
from ax.core.base_trial import TrialStatus
from ax.core.trial import Trial
from ax.exceptions.core import UserInputError
from ax.utils.common.testutils import TestCase
Expand Down Expand Up @@ -170,3 +171,57 @@ def test_trial_indices_filter(self) -> None:
# Test that changes to the experiment are reflected in the summary
client.get_next_trials(max_trials=1)
client.complete_trial(trial_index=1, raw_data={"foo": 2.0})

def test_trial_status_filter(self) -> None:
    """Summary only reports trials whose status was requested."""
    client = Client()
    client.configure_experiment(
        name="test_experiment",
        parameters=[
            RangeParameterConfig(
                name="x1",
                parameter_type="float",
                bounds=(0, 1),
            ),
        ],
    )
    client.configure_optimization(objective="foo")

    # Build one trial per status of interest: trial 0 is completed, trial 1
    # is marked failed, and trial 2 is generated but neither completed nor
    # failed (the assertions below expect it to be reported as RUNNING).
    client.get_next_trials(max_trials=1)
    client.complete_trial(trial_index=0, raw_data={"foo": 1.0})
    client.get_next_trials(max_trials=1)
    client.mark_trial_failed(trial_index=1)
    client.get_next_trials(max_trials=1)

    experiment = client._experiment

    # (status to filter on, expected trial index, expected status label)
    single_status_cases = (
        (TrialStatus.COMPLETED, 0, "COMPLETED"),
        (TrialStatus.FAILED, 1, "FAILED"),
        (TrialStatus.RUNNING, 2, "RUNNING"),
    )
    for status, expected_index, expected_label in single_status_cases:
        card = Summary(trial_status=[status]).compute(experiment=experiment)
        self.assertEqual(len(card.df), 1)
        self.assertEqual(card.df["trial_index"].iloc[0], expected_index)
        self.assertEqual(card.df["trial_status"].iloc[0], expected_label)

    # Requesting several statuses at once returns the union of the matches.
    requested = [TrialStatus.COMPLETED, TrialStatus.FAILED]
    card = Summary(trial_status=requested).compute(experiment=experiment)
    self.assertEqual(len(card.df), 2)
    for expected_index in (0, 1):
        self.assertIn(expected_index, card.df["trial_index"].values)
28 changes: 26 additions & 2 deletions ax/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -696,6 +696,18 @@ def compute_analyses(
def summarize(
self,
trial_indices: Iterable[int] | None = None,
trial_status: Sequence[
Literal[
"candidate",
"running",
"failed",
"completed",
"abandoned",
"early_stopped",
"staged",
]
]
| None = None,
) -> pd.DataFrame:
"""
Special convenience method for producing the ``DataFrame`` produced by the
Expand All @@ -716,9 +728,21 @@ def summarize(
Experiment's ``runner.run_metadata_report_keys`` field
- **METRIC_NAME: The observed mean of the metric specified, for each metric
- **PARAMETER_NAME: The parameter value for the arm, for each parameter
"""

card = Summary(trial_indices=trial_indices, omit_empty_columns=True).compute(
Args:
trial_indices: If specified, only include these trial indices.
trial_status: If specified, only include trials whose status is one of these.
"""
# Convert string literals to TrialStatus enum values
enum_trial_status = None
if trial_status is not None:
enum_trial_status = [TrialStatus[status.upper()] for status in trial_status]

card = Summary(
trial_indices=trial_indices,
trial_status=enum_trial_status,
omit_empty_columns=True,
).compute(
experiment=self._experiment,
generation_strategy=self._maybe_generation_strategy,
)
Expand Down
56 changes: 56 additions & 0 deletions ax/api/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -988,6 +988,62 @@ def test_summarize(self) -> None:
)
pd.testing.assert_frame_equal(summary_df_single, expected_single)

# Test with trial_status parameter
summary_df_completed = client.summarize(trial_status=["completed"])
expected_completed = pd.DataFrame(
{
"trial_index": {0: 1},
"arm_name": {0: "1_0"},
"trial_status": {0: "COMPLETED"},
"generation_node": {0: "CenterOfSearchSpace"},
"foo": {0: 1.0},
"bar": {0: 2.0},
"x1": {0: trial_1_parameters["x1"]},
"x2": {0: trial_1_parameters["x2"]},
}
)
pd.testing.assert_frame_equal(summary_df_completed, expected_completed)

# Test with trial_status parameter for running trials
summary_df_running = client.summarize(trial_status=["running"])
expected_running = pd.DataFrame(
{
"trial_index": {0: 0},
"arm_name": {0: "manual"},
"trial_status": {0: "RUNNING"},
"foo": {0: 0.0},
"bar": {0: 0.5},
"x1": {0: trial_0_parameters["x1"]},
"x2": {0: trial_0_parameters["x2"]},
}
)

assert summary_df_running.equals(expected_running)

# Test with multiple trial_status values
summary_df_multi_status = client.summarize(
trial_status=["completed", "running"]
)
expected_multi_status = pd.DataFrame(
{
"trial_index": {0: 0, 1: 1},
"arm_name": {0: "manual", 1: "1_0"},
"trial_status": {0: "RUNNING", 1: "COMPLETED"},
"generation_node": {0: None, 1: "CenterOfSearchSpace"},
"foo": {0: 0.0, 1: 1.0},
"bar": {0: 0.5, 1: 2.0},
"x1": {
0: trial_0_parameters["x1"],
1: trial_1_parameters["x1"],
},
"x2": {
0: trial_0_parameters["x2"],
1: trial_1_parameters["x2"],
},
}
)
self.assertTrue(summary_df_multi_status.equals(expected_multi_status))

def test_compute_analyses(self) -> None:
client = Client()

Expand Down
8 changes: 7 additions & 1 deletion ax/core/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import re
import warnings
from collections import defaultdict, OrderedDict
from collections.abc import Hashable, Iterable, Mapping
from collections.abc import Hashable, Iterable, Mapping, Sequence
from datetime import datetime
from functools import partial, reduce
from typing import Any, cast, Union
Expand Down Expand Up @@ -1937,6 +1937,7 @@ def metric_config_summary_df(self) -> pd.DataFrame:
def to_df(
self,
trial_indices: Iterable[int] | None = None,
trial_status: Sequence[TrialStatus] | None = None,
omit_empty_columns: bool = True,
) -> pd.DataFrame:
"""
Expand All @@ -1958,6 +1959,7 @@ def to_df(
Args:
trial_indices: If specified, only include these trial indices.
omit_empty_columns: If True, omit columns where every value is None.
trial_status: If specified, only include trials whose status is one of these.
"""

records = []
Expand All @@ -1967,6 +1969,10 @@ def to_df(
if trial_indices
else self.trials.values()
)

# Filter trials by status if specified
if trial_status is not None:
trials = [trial for trial in trials if trial.status in trial_status]
# Iterate through trials, and for each trial, iterate through its arms
# and add a record for each arm.
for trial in trials:
Expand Down
66 changes: 66 additions & 0 deletions ax/core/tests/test_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -1392,6 +1392,72 @@ def test_to_df(self) -> None:
)
self.assertTrue(df_filtered.equals(expected_filtered_df))

# Test the trial_status parameter
df_status_filtered = experiment.to_df(trial_status=[TrialStatus.COMPLETED])
expected_status_filtered_df = pd.DataFrame.from_dict(
{
"trial_index": [0, 1],
"arm_name": ["0_0", "1_0"],
"trial_status": ["COMPLETED", "COMPLETED"],
"name": ["0", "1"], # the metadata
"m1": [1.0, 3.0],
"m2": [2.0, 4.0],
"x": xs[:2],
"y": ys[:2],
}
)
self.assertTrue(df_status_filtered.equals(expected_status_filtered_df))

# Test with both trial_indices and trial_status parameters
df_both_filtered = experiment.to_df(
trial_indices=[0], trial_status=[TrialStatus.COMPLETED]
)
expected_both_filtered_df = pd.DataFrame.from_dict(
{
"trial_index": [0],
"arm_name": ["0_0"],
"trial_status": ["COMPLETED"],
"name": ["0"], # the metadata
"m1": [1.0],
"m2": [2.0],
"x": [xs[0]],
"y": [ys[0]],
}
)
self.assertTrue(df_both_filtered.equals(expected_both_filtered_df))

# Test the trial_status parameter
# Change the status of trial 2 to RUNNING
experiment.trials[2].mark_running(no_runner_required=True)

# Filter by RUNNING status
df_status_filtered = experiment.to_df(trial_status=[TrialStatus.RUNNING])
expected_status_filtered_df = pd.DataFrame.from_dict(
{
"trial_index": [2],
"arm_name": ["0_0"],
"trial_status": ["RUNNING"],
"x": [xs[2]],
"y": [ys[2]],
}
)
self.assertTrue(df_status_filtered.equals(expected_status_filtered_df))
# Filter by COMPLETED status
df_completed = experiment.to_df(trial_status=[TrialStatus.COMPLETED])
expected_completed_df = pd.DataFrame.from_dict(
{
"trial_index": [0, 1],
"arm_name": ["0_0", "1_0"],
"trial_status": ["COMPLETED", "COMPLETED"],
"name": ["0", "1"], # the metadata
"m1": [1.0, 3.0],
"m2": [2.0, 4.0],
"x": xs[:2],
"y": ys[:2],
}
)
self.assertTrue(df_completed.equals(expected_completed_df))


class ExperimentWithMapDataTest(TestCase):
def setUp(self) -> None:
Expand Down