Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions ax/adapter/transforms/choice_encode.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,16 @@ def untransform_observation_features(
obsf.parameters[p_name] = reverse_transform[pval]
return observation_features

def transform_experiment_data(
self, experiment_data: ExperimentData
) -> ExperimentData:
return ExperimentData(
arm_data=experiment_data.arm_data.replace(
to_replace=self.encoded_parameters
),
observation_data=experiment_data.observation_data,
)


class ChoiceEncode(DeprecatedTransformMixin, ChoiceToNumericChoice):
"""Deprecated alias for ChoiceToNumericChoice."""
Expand Down
23 changes: 23 additions & 0 deletions ax/adapter/transforms/fill_missing_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from ax.core.observation import Observation, ObservationFeatures
from ax.core.search_space import SearchSpace
from ax.core.types import TParameterization
from ax.exceptions.core import UnsupportedError
from ax.generators.types import TConfig
from pyre_extensions import assert_is_instance, none_throws

Expand Down Expand Up @@ -68,3 +69,25 @@ def transform_observation_features(
}
obsf.parameters.update(fill_params)
return observation_features

def transform_experiment_data(
self, experiment_data: ExperimentData
) -> ExperimentData:
if self.fill_values is None:
return experiment_data
if self.fill_None is False:
# This shouldn't be relevant in regular usage. We add both
# FillMissingParameters and Cast as default transfroms in
# Adapter. Cast will drop parameterizations with missing / None
# values, so not filling None will just lead to it being dropped.
# The exception is added here for completeness.
raise UnsupportedError(
"Transforming `ExperimentData` is not supported for "
"FillMissingParameters with fill_None=False. "
"We cannot distinguish between parameters that are missing "
"and those that are None in `ExperimentData`. "
)
return ExperimentData(
arm_data=experiment_data.arm_data.fillna(value=self.fill_values),
observation_data=experiment_data.observation_data,
)
10 changes: 10 additions & 0 deletions ax/adapter/transforms/logit.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,3 +81,13 @@ def untransform_observation_features(
param: float = obsf.parameters[p_name] # pyre-ignore [9]
obsf.parameters[p_name] = expit(param).item()
return observation_features

def transform_experiment_data(
self, experiment_data: ExperimentData
) -> ExperimentData:
arm_data = experiment_data.arm_data
for p_name in self.transform_parameters:
arm_data[p_name] = logit(arm_data[p_name])
return ExperimentData(
arm_data=arm_data, observation_data=experiment_data.observation_data
)
8 changes: 8 additions & 0 deletions ax/adapter/transforms/remove_fixed.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,3 +87,11 @@ def untransform_observation_features(
for p_name, p in self.fixed_parameters.items():
obsf.parameters[p_name] = p.value
return observation_features

def transform_experiment_data(
self, experiment_data: ExperimentData
) -> ExperimentData:
return ExperimentData(
arm_data=experiment_data.arm_data.drop(columns=list(self.fixed_parameters)),
observation_data=experiment_data.observation_data,
)
87 changes: 82 additions & 5 deletions ax/adapter/transforms/tests/test_choice_encode_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from copy import deepcopy

import numpy as np
from ax.adapter.base import DataLoaderConfig
from ax.adapter.data_utils import extract_experiment_data
from ax.adapter.transforms.choice_encode import (
ChoiceToNumericChoice,
OrderedChoiceEncode,
Expand All @@ -20,7 +22,12 @@
from ax.core.parameter_constraint import ParameterConstraint
from ax.core.search_space import RobustSearchSpace, SearchSpace
from ax.utils.common.testutils import TestCase
from ax.utils.testing.core_stubs import get_robust_search_space
from ax.utils.testing.core_stubs import (
get_experiment_with_observations,
get_robust_search_space,
)
from pandas import DataFrame
from pandas.testing import assert_frame_equal
from pyre_extensions import assert_is_instance


Expand Down Expand Up @@ -66,10 +73,7 @@ def setUp(self) -> None:
ParameterConstraint(constraint_dict={"x": -0.5, "a": 1}, bound=0.5)
],
)
self.t = self.t_class(
search_space=self.search_space,
observations=[],
)
self.t = self.t_class(search_space=self.search_space)
self.observation_features = [
ObservationFeatures(
parameters={"x": 2.2, "a": 2, "b": 10.0, "c": 10.0, "d": "r"}
Expand Down Expand Up @@ -207,6 +211,79 @@ def test_with_parameter_distributions(self) -> None:
self.assertEqual(rss._environmental_variables, rss_new._environmental_variables)
self.assertEqual(rss_new.parameters["c"].parameter_type, ParameterType.INT)

def test_transform_experiment_data(self) -> None:
parameterizations = [
{"x": 2.2, "a": 2, "b": 10.0, "c": 10.0, "d": "r", "e": "q"},
{"x": 1.0, "a": 1, "b": 1.0, "c": 100.0, "d": "q", "e": "z"},
{"x": 1.2, "a": 2, "b": 100.0, "c": 1000.0, "d": "z", "e": "r"},
]
experiment = get_experiment_with_observations(
observations=[[1.0], [2.0], [3.0]],
search_space=self.search_space,
parameterizations=parameterizations,
)
experiment_data = extract_experiment_data(
experiment=experiment, data_loader_config=DataLoaderConfig()
)
transformed_data = self.t.transform_experiment_data(
experiment_data=deepcopy(experiment_data)
)

# Check that values in arm_data are transformed as expected.
if self.t_class is ChoiceToNumericChoice:
expected_values = zip(
[2.2, 1.0, 1.2],
[2, 1, 2],
normalize_values([10.0, 1.0, 100.0]),
normalize_values([10.0, 100.0, 1000.0]),
[1, 0, 2],
[1, 2, 0],
)
elif self.t_class is OrderedChoiceToIntegerRange:
expected_values = zip(
[2.2, 1.0, 1.2],
[2, 1, 2],
[1.0, 0.0, 2.0],
[0.0, 1.0, 2.0],
["r", "q", "z"],
["q", "z", "r"],
)
else:
raise NotImplementedError
expected_arm_data = DataFrame(
[
{"x": x, "a": a, "b": b, "c": c, "d": d, "e": e}
for x, a, b, c, d, e in expected_values
],
index=experiment_data.arm_data.index,
)
assert_frame_equal(
transformed_data.arm_data.drop(columns="metadata"), expected_arm_data
)

# Check that observation data is unchanged.
assert_frame_equal(
transformed_data.observation_data, experiment_data.observation_data
)

# Test with no parameters transformed.
# Setting `encoded_parameters` directly to simplify testing.
self.t.encoded_parameters = {}
copy_experiment_data = deepcopy(experiment_data)
transformed_data = self.t.transform_experiment_data(
experiment_data=copy_experiment_data
)
# Arm data is same as before but it is not the same object.
assert_frame_equal(transformed_data.arm_data, experiment_data.arm_data)
self.assertIsNot(transformed_data.arm_data, copy_experiment_data.arm_data)
# Observation data is the same object.
assert_frame_equal(
transformed_data.observation_data, experiment_data.observation_data
)
self.assertIs(
transformed_data.observation_data, copy_experiment_data.observation_data
)


class OrderedChoiceToIntegerRangeTransformTest(ChoiceEncodeTransformTest):
t_class = OrderedChoiceToIntegerRange
Expand Down
48 changes: 47 additions & 1 deletion ax/adapter/transforms/tests/test_fill_missing_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,14 @@

from copy import deepcopy

from ax.adapter.base import DataLoaderConfig
from ax.adapter.data_utils import extract_experiment_data
from ax.adapter.transforms.fill_missing_parameters import FillMissingParameters

from ax.core.observation import ObservationFeatures
from ax.exceptions.core import UnsupportedError
from ax.utils.common.testutils import TestCase
from ax.utils.testing.core_stubs import get_experiment_with_observations
from pandas.testing import assert_frame_equal


class FillMissingParametersTransformTest(TestCase):
Expand Down Expand Up @@ -52,3 +56,45 @@ def test_TransformObservationFeatures(self) -> None:
t = FillMissingParameters(config={})
obs_ft3 = t.transform_observation_features(deepcopy(observation_features))
self.assertEqual(obs_ft3, observation_features)

def test_transform_experiment_data(self) -> None:
parameterizations = [
{"x": 0.0},
{"x": 1.0, "y": 0.0},
{"x": None, "y": None},
]
experiment = get_experiment_with_observations(
observations=[[1.0], [2.0], [3.0]],
parameterizations=parameterizations,
)
experiment_data = extract_experiment_data(
experiment=experiment, data_loader_config=DataLoaderConfig()
)
# Check that arm_data has NaNs as expected.
self.assertEqual(experiment_data.arm_data["x"].isna().sum(), 1)
self.assertEqual(experiment_data.arm_data["y"].isna().sum(), 2)

# Transform and see that NaNs are filled.
t = FillMissingParameters(config={"fill_values": {"x": 2.0, "y": 1.0}})
transformed_data = t.transform_experiment_data(
experiment_data=deepcopy(experiment_data)
)
self.assertEqual(transformed_data.arm_data["x"].tolist(), [0.0, 1.0, 2.0])
self.assertEqual(transformed_data.arm_data["y"].tolist(), [1.0, 0.0, 1.0])
assert_frame_equal(
transformed_data.observation_data, experiment_data.observation_data
)

# Nothing happens if no fill values are given.
t = FillMissingParameters(config={})
transformed_data = t.transform_experiment_data(
experiment_data=deepcopy(experiment_data)
)
self.assertEqual(transformed_data, experiment_data)

# Check for error if fill_None is False.
t = FillMissingParameters(
config={"fill_values": {"x": 2.0, "y": 1.0}, "fill_None": False}
)
with self.assertRaisesRegex(UnsupportedError, "ExperimentData"):
t.transform_experiment_data(experiment_data=experiment_data)
51 changes: 46 additions & 5 deletions ax/adapter/transforms/tests/test_logit_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,19 @@

from copy import deepcopy

from ax.adapter.base import DataLoaderConfig
from ax.adapter.data_utils import extract_experiment_data
from ax.adapter.transforms.logit import Logit

from ax.core.observation import ObservationFeatures
from ax.core.parameter import ChoiceParameter, ParameterType, RangeParameter
from ax.core.search_space import SearchSpace
from ax.exceptions.core import UnsupportedError, UserInputError
from ax.utils.common.testutils import TestCase
from ax.utils.testing.core_stubs import get_robust_search_space
from ax.utils.testing.core_stubs import (
get_experiment_with_observations,
get_robust_search_space,
)
from pandas.testing import assert_frame_equal, assert_series_equal
from scipy.special import expit, logit


Expand Down Expand Up @@ -55,9 +60,9 @@ def setUp(self) -> None:
]
)

# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def _create_logit_parameter(self, lower, upper, log_scale=False):
def _create_logit_parameter(
self, lower: float, upper: float, log_scale: bool = False
) -> RangeParameter:
return RangeParameter(
"x",
lower=lower,
Expand Down Expand Up @@ -142,3 +147,39 @@ def test_w_parameter_distributions(self) -> None:
)
with self.assertRaisesRegex(UnsupportedError, "transform is not supported"):
t.transform_search_space(rss)

def test_transform_experiment_data(self) -> None:
parameterizations = [
{"x": 0.2, "a": 1, "b": "a"},
{"x": 0.5, "a": 2, "b": "b"},
{"x": 0.7, "a": 3, "b": "c"},
]
experiment = get_experiment_with_observations(
observations=[[1.0], [2.0], [3.0]],
search_space=self.search_space,
parameterizations=parameterizations,
)
experiment_data = extract_experiment_data(
experiment=experiment, data_loader_config=DataLoaderConfig()
)
transformed_data = self.t.transform_experiment_data(
experiment_data=deepcopy(experiment_data)
)

# Check that `x` has been log-transformed.
assert_series_equal(
transformed_data.arm_data["x"], logit(experiment_data.arm_data["x"])
)

# Check that other columns remain unchanged.
assert_series_equal(
transformed_data.arm_data["a"], experiment_data.arm_data["a"]
)
assert_series_equal(
transformed_data.arm_data["b"], experiment_data.arm_data["b"]
)

# Check that observation data is unchanged.
assert_frame_equal(
transformed_data.observation_data, experiment_data.observation_data
)
Loading