41 changes: 27 additions & 14 deletions ax/analysis/plotly/surface/contour.py
@@ -18,6 +18,7 @@
)
from ax.analysis.plotly.surface.utils import (
get_features_for_slice_or_contour,
get_fixed_values_for_slice_or_contour,
get_parameter_values,
is_axis_log_scale,
)
@@ -29,7 +30,7 @@
validate_experiment,
)
from ax.core.experiment import Experiment
from ax.core.parameter import DerivedParameter
from ax.core.parameter import DerivedParameter, TParamValue
from ax.core.trial_status import STATUSES_EXPECTING_DATA
from ax.core.utils import get_target_trial_index
from ax.exceptions.core import UserInputError
@@ -43,16 +44,19 @@
"These plots show the relationship between a metric and two parameters. They "
"show the predicted values of the metric (indicated by color) as a function of "
"the two parameters on the x- and y-axes while keeping all other parameters "
"fixed at their status_quo value (or mean value if status_quo is unavailable). "
"fixed at their status_quo value (if available), best trial value, or the "
"center of the search space."
)


@final
class ContourPlot(Analysis):
"""
Plot a 2D surface of the surrogate model's predicted outcomes for a given pair of
parameters, where all other parameters are held fixed at their status-quo value or
mean if no status quo is available.
parameters, where all other parameters are held fixed at their status_quo value
(if available and within the search space), otherwise at their values from the best
trial (for single-objective optimization), or at the center of the search space if
neither is available.

The DataFrame computed will contain the following columns:
- PARAMETER_NAME: The value of the x parameter specified
@@ -152,13 +156,20 @@ def compute(

metric_name = self.metric_name or select_metric(experiment=experiment)

# Get fixed parameter values and description
fixed_values, fixed_values_description = get_fixed_values_for_slice_or_contour(
experiment=experiment,
generation_strategy=generation_strategy,
)

df = _prepare_data(
experiment=experiment,
model=relevant_adapter,
x_parameter_name=self.x_parameter_name,
y_parameter_name=self.y_parameter_name,
metric_name=metric_name,
relativize=self.relativize,
fixed_values=fixed_values,
)

fig = _prepare_plot(
@@ -187,13 +198,12 @@ def compute(
subtitle=(
"The contour plot visualizes the predicted outcomes "
f"for {metric_name} across a two-dimensional parameter space, "
"with other parameters held fixed at their status_quo value "
"(or mean value if status_quo is unavailable). This plot helps "
"in identifying regions of optimal performance and understanding "
"how changes in the selected parameters influence the predicted "
"outcomes. Contour lines represent levels of constant predicted "
"values, providing insights into the gradient and potential optima "
"within the parameter space."
f"with other parameters held fixed at {fixed_values_description}. "
"This plot helps in identifying regions of optimal performance and "
"understanding how changes in the selected parameters influence the "
"predicted outcomes. Contour lines represent levels of constant "
"predicted values, providing insights into the gradient and potential "
"optima within the parameter space."
),
df=df,
fig=fig,
@@ -215,7 +225,8 @@ def compute_contour_adhoc(
a notebook setting.

Args:
parameter_name: The name of the parameter to plot on the x-axis.
x_parameter_name: The name of the parameter to plot on the x-axis.
y_parameter_name: The name of the parameter to plot on the y-axis.
experiment: The experiment to source data from.
generation_strategy: Optional. The generation strategy to extract the adapter
from.
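
For context, a usage sketch of this ad hoc entry point as it might be called from a notebook. The argument names follow the docstring above; the surrounding objects (experiment, generation_strategy) and the exact return type are assumptions for illustration and are not part of this diff.

# Notebook usage sketch; assumes an Experiment with observed data and,
# optionally, a GenerationStrategy already exist in scope.
from ax.analysis.plotly.surface.contour import compute_contour_adhoc

card = compute_contour_adhoc(
    x_parameter_name="x",
    y_parameter_name="y",
    experiment=experiment,
    generation_strategy=generation_strategy,  # optional; supplies the adapter
)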
@@ -247,6 +258,7 @@ def _prepare_data(
y_parameter_name: str,
metric_name: str,
relativize: bool,
fixed_values: dict[str, TParamValue],
) -> pd.DataFrame:
trials = experiment.extract_relevant_trials(trial_statuses=STATUSES_EXPECTING_DATA)
sampled = [
@@ -279,14 +291,15 @@
ys = [*[sample["y_parameter_name"] for sample in sampled], *unsampled_ys]

# Construct observation features for each parameter value previously chosen by
# fixing all other parameters to their status-quo value or mean.
features = features = [
# fixing all other parameters to values from get_fixed_values_for_slice_or_contour.
features = [
get_features_for_slice_or_contour(
parameters={
x_parameter_name: x,
y_parameter_name: y,
},
search_space=experiment.search_space,
fixed_values=fixed_values,
)
for x in xs
for y in ys
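
The new helper get_fixed_values_for_slice_or_contour lives in ax/analysis/plotly/surface/utils.py, whose diff is not shown here. Below is a minimal sketch of the fallback order it is documented to follow, assuming the (fixed_values, description) return shape used by compute above; the body is illustrative only and simplifies the best-trial, choice-parameter, and log-scale cases.

# Illustrative sketch only -- not the actual implementation in utils.py.
from ax.core.experiment import Experiment
from ax.core.parameter import RangeParameter, TParamValue


def fixed_values_sketch(
    experiment: Experiment,
) -> tuple[dict[str, TParamValue], str]:
    search_space = experiment.search_space

    # 1. Prefer the status_quo arm when it exists and lies inside the search space.
    status_quo = experiment.status_quo
    if status_quo is not None and search_space.check_membership(
        parameterization=status_quo.parameters
    ):
        return dict(status_quo.parameters), "their status_quo value"

    # 2. Otherwise the helper falls back to the parameterization of the best
    #    trial (single-objective optimization only); how the best trial is
    #    selected is elided from this sketch.

    # 3. Otherwise fall back to the center of the search space, e.g. the
    #    midpoint of each RangeParameter (other parameter types omitted here).
    center: dict[str, TParamValue] = {
        name: (p.lower + p.upper) / 2.0
        for name, p in search_space.parameters.items()
        if isinstance(p, RangeParameter)
    }
    return center, "the center of the search space"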
31 changes: 22 additions & 9 deletions ax/analysis/plotly/surface/slice.py
@@ -18,6 +18,7 @@
)
from ax.analysis.plotly.surface.utils import (
get_features_for_slice_or_contour,
get_fixed_values_for_slice_or_contour,
get_parameter_values,
is_axis_log_scale,
)
@@ -35,7 +36,7 @@
)
from ax.core.analysis_card import AnalysisCardBase
from ax.core.experiment import Experiment
from ax.core.parameter import DerivedParameter
from ax.core.parameter import DerivedParameter, TParamValue
from ax.core.trial_status import STATUSES_EXPECTING_DATA
from ax.core.utils import get_target_trial_index
from ax.exceptions.core import UserInputError
@@ -49,21 +50,24 @@
"These plots show the relationship between a metric and a parameter. They "
"show the predicted values of the metric on the y-axis as a function of the "
"parameter on the x-axis while keeping all other parameters fixed at their "
"status_quo value (or mean value if status_quo is unavailable). "
"status_quo value (if available), best trial value, or the center of the "
"search space."
)


@final
class SlicePlot(Analysis):
"""
Plot a 1D "slice" of the surrogate model's predicted outcomes for a given
parameter, where all other parameters are held fixed at their status-quo value or
mean if no status quo is available.
parameter, where all other parameters are held fixed at their status_quo value
(if available and within the search space), otherwise at their values from the best
trial (for single-objective optimization), or at the center of the search space if
neither is available.

The DataFrame computed will contain the following columns:
- PARAMETER_NAME: The value of the parameter specified
- METRIC_NAME_mean: The predected mean of the metric specified
- METRIC_NAME_sem: The predected sem of the metric specified
- METRIC_NAME_mean: The predicted mean of the metric specified
- METRIC_NAME_sem: The predicted sem of the metric specified
- sampled: Whether the parameter value was sampled in at least one trial
"""

@@ -147,12 +151,19 @@ def compute(

metric_name = self.metric_name or select_metric(experiment=experiment)

# Get fixed parameter values and description
fixed_values, fixed_values_description = get_fixed_values_for_slice_or_contour(
experiment=experiment,
generation_strategy=generation_strategy,
)

df = _prepare_data(
experiment=experiment,
model=relevant_adapter,
parameter_name=self.parameter_name,
metric_name=metric_name,
relativize=self.relativize,
fixed_values=fixed_values,
)

fig = _prepare_plot(
@@ -172,8 +183,8 @@
subtitle=(
"The slice plot provides a one-dimensional view of predicted "
f"outcomes for {metric_name} as a function of a single parameter, "
"while keeping all other parameters fixed at their status_quo "
"value (or mean value if status_quo is unavailable). "
f"while keeping all other parameters fixed at "
f"{fixed_values_description}. "
"This visualization helps in understanding the sensitivity and "
"impact of changes in the selected parameter on the predicted "
"metric outcomes."
@@ -229,6 +240,7 @@ def _prepare_data(
parameter_name: str,
metric_name: str,
relativize: bool,
fixed_values: dict[str, TParamValue],
) -> pd.DataFrame:
trials = experiment.extract_relevant_trials(trial_statuses=STATUSES_EXPECTING_DATA)
sampled_xs = [
@@ -252,11 +264,12 @@
xs = [*[sample["parameter_value"] for sample in sampled_xs], *unsampled_xs]

# Construct observation features for each parameter value previously chosen by
# fixing all other parameters to their status-quo value or mean.
# fixing all other parameters to values from get_fixed_values_for_slice_or_contour.
features = [
get_features_for_slice_or_contour(
parameters={parameter_name: x},
search_space=experiment.search_space,
fixed_values=fixed_values,
)
for x in xs
]
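
On the consumption side, get_features_for_slice_or_contour now receives the shared fixed_values alongside the swept parameter(s). A rough sketch of what that combination plausibly looks like, under the assumption that the plotted parameter simply overrides the corresponding fixed value; the real helper in utils.py is not part of this diff and may differ.

# Illustrative sketch only -- the real helper lives in utils.py.
from ax.core.observation import ObservationFeatures
from ax.core.parameter import TParamValue
from ax.core.search_space import SearchSpace


def features_sketch(
    parameters: dict[str, TParamValue],
    search_space: SearchSpace,
    fixed_values: dict[str, TParamValue],
) -> ObservationFeatures:
    # Start from the fixed values for every non-plotted parameter, then let the
    # swept parameter(s) on the x- (and y-) axis override them.
    parameterization = {**fixed_values, **parameters}
    # Keep only parameters that are actually in the search space so the
    # adapter can predict on the resulting features.
    parameterization = {
        name: value
        for name, value in parameterization.items()
        if name in search_space.parameters
    }
    return ObservationFeatures(parameters=parameterization)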
18 changes: 7 additions & 11 deletions ax/analysis/plotly/surface/tests/test_contour.py
@@ -55,17 +55,11 @@ def setUp(self) -> None:
"bar": parameterization["x"] ** 2 + parameterization["y"] ** 2
},
)
self.expected_subtitle = (
self.expected_subtitle_contains = [
"The contour plot visualizes the predicted outcomes "
"for bar across a two-dimensional parameter space, "
"with other parameters held fixed at their status_quo value "
"(or mean value if status_quo is unavailable). This plot helps "
"in identifying regions of optimal performance and understanding "
"how changes in the selected parameters influence the predicted "
"outcomes. Contour lines represent levels of constant predicted "
"values, providing insights into the gradient and potential optima "
"within the parameter space."
)
"with other parameters held fixed at their best trial value",
]
self.expected_title = "bar (Mean) vs. x, y"
self.expected_name = "ContourPlot"
self.expected_cols = {
@@ -101,7 +95,8 @@ def test_compute(self) -> None:
self.expected_name,
)
self.assertEqual(card.title, self.expected_title)
self.assertEqual(card.subtitle, self.expected_subtitle)
for expected_text in self.expected_subtitle_contains:
self.assertIn(expected_text, card.subtitle)
self.assertEqual(
{*card.df.columns},
self.expected_cols,
@@ -144,7 +139,8 @@ def test_compute_adhoc(self) -> None:
self.expected_name,
)
self.assertEqual(card.title, self.expected_title)
self.assertEqual(card.subtitle, self.expected_subtitle)
for expected_text in self.expected_subtitle_contains:
self.assertIn(expected_text, card.subtitle)
self.assertEqual({*card.df.columns}, self.expected_cols)
self.assertIsNotNone(card.blob)

26 changes: 6 additions & 20 deletions ax/analysis/plotly/surface/tests/test_slice.py
@@ -66,17 +66,10 @@ def test_compute(self) -> None:
"SlicePlot",
)
self.assertEqual(card.title, "bar vs. x")
self.assertEqual(
# Subtitle should mention "their best trial value"
self.assertIn(
"while keeping all other parameters fixed at their best trial value",
card.subtitle,
(
"The slice plot provides a one-dimensional view of predicted "
"outcomes for bar as a function of a single parameter, "
"while keeping all other parameters fixed at their status_quo "
"value (or mean value if status_quo is unavailable). "
"This visualization helps in understanding the sensitivity and "
"impact of changes in the selected parameter on the predicted "
"metric outcomes."
),
)
self.assertEqual(
{*card.df.columns},
@@ -109,17 +102,10 @@ def test_compute_adhoc(self) -> None:
"SlicePlot",
)
self.assertEqual(card.title, "bar vs. x")
self.assertEqual(
# Subtitle should mention "their best trial value"
self.assertIn(
"while keeping all other parameters fixed at their best trial value",
card.subtitle,
(
"The slice plot provides a one-dimensional view of predicted "
"outcomes for bar as a function of a single parameter, "
"while keeping all other parameters fixed at their status_quo "
"value (or mean value if status_quo is unavailable). "
"This visualization helps in understanding the sensitivity and "
"impact of changes in the selected parameter on the predicted "
"metric outcomes."
),
)
self.assertEqual(
{*card.df.columns},