Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 43 additions & 15 deletions ax/adapter/transforms/log.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from ax.core.search_space import SearchSpace
from ax.core.types import TNumeric, TParamValue
from ax.generators.types import TConfig
from ax.utils.common.typeutils import assert_is_instance_of_tuple
from pyre_extensions import assert_is_instance

if TYPE_CHECKING:
Expand Down Expand Up @@ -49,7 +50,14 @@ def __init__(
self.transform_parameters: dict[str, ParameterType] = {
p_name: p.parameter_type
for p_name, p in search_space.parameters.items()
if isinstance(p, RangeParameter) and p.is_numeric and p.log_scale
if isinstance(p, (RangeParameter, ChoiceParameter)) and p.log_scale
}
# For choice parameters, store the original values so that we can
# match them exactly when untransforming.
self.original_values: dict[str, list[TParamValue]] = {
p_name: p.values
for p_name, p in search_space.parameters.items()
if isinstance(p, ChoiceParameter) and p.log_scale
}

def transform_observation_features(
Expand Down Expand Up @@ -81,29 +89,41 @@ def transform_search_space(self, search_space: SearchSpace) -> SearchSpace:
elif (
isinstance(p, RangeParameter)
and p.parameter_type == ParameterType.INT
):
# Convert integer valued RangeParameter to ChoiceParameter
lower = assert_is_instance(p.lower, int)
upper = assert_is_instance(p.upper, int)
values = list(range(lower, upper + 1))
) or isinstance(p, ChoiceParameter):
# Handle both int RangeParameter and ChoiceParameter
# by converting to log-transformed ChoiceParameter
if isinstance(p, RangeParameter):
lower = assert_is_instance(p.lower, int)
upper = assert_is_instance(p.upper, int)
values = list(range(lower, upper + 1))
is_ordered = True
sort_values = True
else: # ChoiceParameter
values = p.values
is_ordered = p.is_ordered
sort_values = p.sort_values

# Apply log10 transformation
transformed_values = [
assert_is_instance(math.log10(v), TParamValue) for v in values
assert_is_instance(math.log10(float(v)), TParamValue)
for v in values
]

target_value = p.target_value
if target_value is not None:
target_value = assert_is_instance(target_value, int)
target_value = math.log10(target_value)
target_value = math.log10(
assert_is_instance_of_tuple(target_value, (float, int))
)

# Create new ChoiceParameter to replace the RangeParameter
# Create new ChoiceParameter with transformed values.
choice_param = ChoiceParameter(
name=p.name,
parameter_type=ParameterType.FLOAT,
values=transformed_values,
is_ordered=True,
is_ordered=is_ordered,
is_fidelity=p.is_fidelity,
target_value=target_value,
sort_values=True,
sort_values=sort_values,
log_scale=False,
bypass_cardinality_check=True,
)

Expand All @@ -120,8 +140,16 @@ def untransform_observation_features(
param: float = assert_is_instance(obsf.parameters[p_name], float)
val = math.pow(10, param)

# Round to nearest integer if original parameter type is int
if p_type == ParameterType.INT:
# Match original values exactly for ChoiceParameter.
if p_name in self.original_values:
val = assert_is_instance_of_tuple(
min(
self.original_values[p_name], key=lambda x: abs(x - val)
),
(float, int),
)
# Round to nearest integer for integer-RangeParameter.
elif p_type == ParameterType.INT:
val = round(val)

obsf.parameters[p_name] = val
Expand Down
128 changes: 127 additions & 1 deletion ax/adapter/transforms/tests/test_log_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,18 @@
import numpy as np
from ax.adapter.base import DataLoaderConfig
from ax.adapter.data_utils import extract_experiment_data
from ax.adapter.torch import TorchAdapter
from ax.adapter.transforms.choice_encode import ChoiceToNumericChoice
from ax.adapter.transforms.log import Log
from ax.core.observation import ObservationFeatures
from ax.core.parameter import ChoiceParameter, ParameterType, RangeParameter
from ax.core.search_space import SearchSpace
from ax.generators.torch.botorch_modular.generator import BoTorchGenerator
from ax.utils.common.testutils import TestCase
from ax.utils.testing.core_stubs import get_experiment_with_observations
from ax.utils.testing.mock import mock_botorch_optimize
from pandas.testing import assert_frame_equal, assert_series_equal
from pyre_extensions import assert_is_instance
from pyre_extensions import assert_is_instance, none_throws


class LogTransformTest(TestCase):
Expand Down Expand Up @@ -185,3 +189,125 @@ def test_transform_experiment_data(self) -> None:
assert_frame_equal(
transformed_data.observation_data, experiment_data.observation_data
)

def test_log_scale_choice_parameter(self) -> None:
    """Log transform handles log-scale ChoiceParameters end to end.

    Covers: identification of transformable parameters, round-tripping of
    observation features (including exact recovery of int choice values),
    and search-space transformation for both float and int log-scale
    choice parameters.
    """
    # A float log-scale choice parameter and an int log-scale fidelity
    # parameter with a target value.
    space = SearchSpace(
        parameters=[
            ChoiceParameter(
                "z",
                parameter_type=ParameterType.FLOAT,
                values=[1.0, 10.0, 100.0, 1000.0],
                log_scale=True,
            ),
            ChoiceParameter(
                "w",
                parameter_type=ParameterType.INT,
                values=[2, 4, 8, 16, 32],
                log_scale=True,
                is_fidelity=True,
                target_value=32,
            ),
        ]
    )
    transform = Log(search_space=space)

    # Both log-scale choice parameters are identified, and their original
    # values are stored for exact matching on untransform.
    self.assertEqual(
        transform.transform_parameters,
        {"z": ParameterType.FLOAT, "w": ParameterType.INT},
    )
    self.assertEqual(
        transform.original_values,
        {"z": [1.0, 10.0, 100.0, 1000.0], "w": [2, 4, 8, 16, 32]},
    )

    # Transforming observation features applies log10 per parameter.
    original_features = [ObservationFeatures(parameters={"z": 100.0, "w": 8})]
    transformed = transform.transform_observation_features(
        deepcopy(original_features)
    )
    self.assertEqual(
        transformed,
        [
            ObservationFeatures(
                parameters={
                    "z": math.log10(100.0),
                    "w": math.log10(8),
                }
            )
        ],
    )

    # Untransforming recovers the original values exactly, preserving the
    # int type of "w" rather than returning a rounded float.
    roundtripped = transform.untransform_observation_features(transformed)
    self.assertEqual(roundtripped, original_features)
    self.assertTrue(isinstance(roundtripped[0].parameters["w"], int))

    # Transforming the search space yields float ChoiceParameters holding
    # log10-transformed values with log_scale cleared.
    transformed_space = transform.transform_search_space(deepcopy(space))

    param_z = assert_is_instance(
        transformed_space.parameters["z"], ChoiceParameter
    )
    self.assertEqual(param_z.parameter_type, ParameterType.FLOAT)
    self.assertEqual(
        param_z.values, [math.log10(v) for v in [1.0, 10.0, 100.0, 1000.0]]
    )
    self.assertFalse(param_z.log_scale)

    # Fidelity metadata carries through; the target value is transformed.
    param_w = assert_is_instance(
        transformed_space.parameters["w"], ChoiceParameter
    )
    self.assertEqual(param_w.parameter_type, ParameterType.FLOAT)
    self.assertEqual(param_w.values, [math.log10(v) for v in [2, 4, 8, 16, 32]])
    self.assertFalse(param_w.log_scale)
    self.assertTrue(param_w.is_fidelity)
    self.assertEqual(param_w.target_value, math.log10(32))

@mock_botorch_optimize
def test_log_scale_choice_with_adapter(self) -> None:
    """Test the Log transform end-to-end through a TorchAdapter.

    Generates one candidate and inspects the search space digest recorded
    by the generator to confirm the choice values reaching the model were
    log10-transformed.
    """
    search_space = SearchSpace(
        parameters=[
            ChoiceParameter(
                "z",
                parameter_type=ParameterType.FLOAT,
                values=[1.0, 10.0, 100.0, 1000.0],
                # log_scale=True is required for the Log transform to pick
                # this parameter up; without it the digest below would
                # contain the untransformed values.
                log_scale=True,
            ),
            ChoiceParameter(
                "w",
                parameter_type=ParameterType.INT,
                values=[2, 4, 8, 16, 32],
                log_scale=True,
            ),
        ]
    )
    experiment = get_experiment_with_observations(
        observations=[[1.0], [2.0], [3.0]], search_space=search_space
    )
    generator = BoTorchGenerator()
    adapter = TorchAdapter(
        experiment=experiment,
        generator=generator,
        transforms=[ChoiceToNumericChoice, Log],
    )
    gr = adapter.gen(n=1)
    self.assertEqual(len(gr.arms), 1)
    # Check the SSD to see if the parameters are log-transformed correctly.
    ssd = none_throws(generator.surrogate._last_search_space_digest)
    self.assertEqual(ssd.feature_names, ["z", "w"])
    self.assertEqual(
        ssd.discrete_choices,
        {
            0: [math.log10(v) for v in [1.0, 10.0, 100.0, 1000.0]],
            1: [math.log10(v) for v in [2, 4, 8, 16, 32]],
        },
    )
Loading