Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 81 additions & 15 deletions ax/core/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -1325,7 +1325,8 @@ def __init__(
Args:
name: Name of the parameter.
parameter_type: Enum indicating the type of parameter value. Expects
"float", or "int". "bool" and "str" are not supported.
"float", or "int". "bool" and "str" are supported only for simple
copies (expression_str must be a single parameter name).
expression_str: A string expression of the derived parameter definition.
is_fidelity: Whether this parameter is a fidelity parameter.
target_value: Target value of this parameter if it is a fidelity.
Expand All @@ -1336,18 +1337,15 @@ def __init__(
raise UnsupportedError(
"Derived parameters do not support specifying a target value."
)
elif parameter_type not in (ParameterType.FLOAT, ParameterType.INT):
raise UserInputError(
"Derived parameters must be of type float or int, but got "
f"{parameter_type}."
)

self.set_expression_str(expression_str=expression_str)
self._name = name
self._parameter_type = parameter_type
self._parameter_type = parameter_type # Set first so validation works
self._is_fidelity = is_fidelity
self._target_value = target_value

# Parse expression and validate type constraint (reuses set_expression_str)
self.set_expression_str(expression_str)

def _parse_expression_str(self, expression_str: str) -> None:
"""Parse the expression str into parameter names and coefficients.

Expand All @@ -1362,6 +1360,8 @@ def _parse_expression_str(self, expression_str: str) -> None:
elif not isinstance(expression, (Add, Mul, Symbol)):
raise UnsupportedError("Only linear expressions are currently supported.")
coefficient_dict = expression.as_coefficients_dict()
# NOTE: the constant/intercept term is always stored with the integer 1 as its
# key, representing the "unit monomial" (x^0 = 1).
self._intercept = float(coefficient_dict.pop(1, 0.0))
parameter_names_to_weights = {}
for name, coef in coefficient_dict.items():
Expand All @@ -1375,6 +1375,9 @@ def _parse_expression_str(self, expression_str: str) -> None:
@property
def domain_repr(self) -> str:
"""Returns a string representation of the derived parameter."""
if self._is_simple_copy:
return f"value={self.source_parameter_name}"

terms = [
f"{weight} * {name}"
for name, weight in self._parameter_names_to_weights.items()
Expand All @@ -1391,9 +1394,39 @@ def parameter_names_to_weights(self) -> dict[str, float]:
def expression_str(self) -> str:
return self._expression_str

@property
def _is_simple_copy(self) -> bool:
"""Check if this derived parameter is a simple copy of another parameter.

A simple copy means the expression has exactly one source parameter with
coefficient 1.0 and no intercept (i.e., `derived_param = source_param`).
"""
return (
len(self._parameter_names_to_weights) == 1
and self._intercept == 0.0
and list(self._parameter_names_to_weights.values())[0] == 1.0
)

@property
def source_parameter_name(self) -> str | None:
"""Return the source parameter name if this is a simple copy, else None."""
if self._is_simple_copy:
return list(self._parameter_names_to_weights.keys())[0]
return None

def set_expression_str(self, expression_str: str) -> None:
self._expression_str = expression_str
# Parse expression first to determine if it's a simple copy
self._parse_expression_str(expression_str=expression_str)
# Re-validate: BOOL and STRING only allowed for simple copies
if self._parameter_type not in (ParameterType.FLOAT, ParameterType.INT):
if not self._is_simple_copy:
raise UserInputError(
f"Derived parameters of type {self._parameter_type.name} must be "
"simple copies (expression_str must be a single parameter name "
"with no arithmetic). For expressions with arithmetic, use FLOAT "
"or INT."
)

@property
def intercept(self) -> float:
Expand All @@ -1415,6 +1448,13 @@ def compute(self, parameters: TParameterization) -> TParamValue:
Returns:
The value of the derived parameter.
"""
if self._is_simple_copy:
# Direct copy - works for all parameter types
source_name = self.source_parameter_name
value = parameters[none_throws(source_name)]
return self.cast(value)

# Arithmetic expression - only for numeric types
return self.cast(
self._intercept
+ sum(
Expand Down Expand Up @@ -1452,13 +1492,20 @@ def validate(
)
return False
expected_value = self.compute(parameters=parameters)
is_valid = (
abs(
assert_is_instance(expected_value, TNumeric)
- assert_is_instance(value, TNumeric)

# For numeric types, use epsilon comparison; for others, use equality
if self._parameter_type in (ParameterType.FLOAT, ParameterType.INT):
is_valid = (
abs(
assert_is_instance(expected_value, TNumeric)
- assert_is_instance(value, TNumeric)
)
< EPS
)
< EPS
)
else:
# BOOL and STRING use exact equality
is_valid = expected_value == value

if raises and not is_valid:
raise UserInputError(
f"Value {value} is not equal to the expected derived"
Expand All @@ -1475,7 +1522,21 @@ def compute_array(self, df: PandasDataFrame) -> npt.NDArray:
Returns:
A NumPy array with the computed derived values. Rows with NaN values
for any constituent parameter will have NaN as the computed value.
For non-numeric types (BOOL, STRING), returns an object array.
"""
if self._is_simple_copy:
source_name = none_throws(self.source_parameter_name)
if source_name in df.columns:
# Return as object array for non-numeric types
if self._parameter_type in (ParameterType.BOOL, ParameterType.STRING):
return df[source_name].to_numpy()
return df[source_name].to_numpy(dtype=np.float64, na_value=np.nan)
# Missing column
if self._parameter_type in (ParameterType.FLOAT, ParameterType.INT):
return np.full(len(df), np.nan, dtype=np.float64)
return np.full(len(df), None, dtype=object)

# Arithmetic expression - only for numeric types
computed = np.full(len(df), self._intercept, dtype=np.float64)
for p_name, weight in self._parameter_names_to_weights.items():
if p_name in df.columns:
Expand Down Expand Up @@ -1506,7 +1567,12 @@ def validate_array(
if df is None:
return np.full(len(values), False, dtype=bool)
computed = self.compute_array(df)
return np.abs(values - computed) < EPS

if self._parameter_type in (ParameterType.FLOAT, ParameterType.INT):
return np.abs(values - computed) < EPS
else:
# For BOOL and STRING, use equality
return values == computed

def clone(self) -> DerivedParameter:
return DerivedParameter(
Expand Down
132 changes: 130 additions & 2 deletions ax/core/tests/test_parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from math import isinf
from typing import cast

import numpy as np
import pandas as pd
from ax.core.parameter import (
_get_parameter_type,
ChoiceParameter,
Expand Down Expand Up @@ -1021,8 +1023,8 @@ def test_invalid_inputs(self) -> None:
for parameter_type in (ParameterType.BOOL, ParameterType.STRING):
with self.assertRaisesRegex(
UserInputError,
"Derived parameters must be of type float or int, but got "
f"{parameter_type}.",
f"Derived parameters of type {parameter_type.name} must be simple "
"copies",
):
DerivedParameter(
name="x",
Expand Down Expand Up @@ -1171,6 +1173,132 @@ def test_cardinality(self) -> None:
):
self.param2.cardinality()

def test_simple_copy(self) -> None:
"""Test simple copy functionality for all parameter types including BOOL, and
STRING.
"""
# Test 1: Simple copy detection - _is_simple_copy returns True for single param
dp_float = DerivedParameter(
name="derived_x",
parameter_type=ParameterType.FLOAT,
expression_str="x",
)
self.assertTrue(dp_float._is_simple_copy)
self.assertEqual(dp_float.source_parameter_name, "x")

# Test 2: _is_simple_copy returns False for expressions with coefficients != 1
dp_scaled = DerivedParameter(
name="derived_scaled",
parameter_type=ParameterType.FLOAT,
expression_str="2 * x",
)
self.assertFalse(dp_scaled._is_simple_copy)
self.assertIsNone(dp_scaled.source_parameter_name)

# Test 3: _is_simple_copy returns False for expressions with intercepts
dp_offset = DerivedParameter(
name="derived_offset",
parameter_type=ParameterType.FLOAT,
expression_str="x + 1",
)
self.assertFalse(dp_offset._is_simple_copy)

# Test 4: _is_simple_copy returns False for multi-param expressions
dp_multi = DerivedParameter(
name="derived_multi",
parameter_type=ParameterType.FLOAT,
expression_str="x + y",
)
self.assertFalse(dp_multi._is_simple_copy)

# Test 5: BOOL derived parameter - compute and validate
dp_bool = DerivedParameter(
name="derived_bool",
parameter_type=ParameterType.BOOL,
expression_str="flag",
)
self.assertTrue(dp_bool._is_simple_copy)
self.assertEqual(dp_bool.compute({"flag": True}), True)
self.assertEqual(dp_bool.compute({"flag": False}), False)
self.assertTrue(dp_bool.validate(True, parameters={"flag": True}))
self.assertTrue(dp_bool.validate(False, parameters={"flag": False}))
self.assertFalse(dp_bool.validate(True, parameters={"flag": False}))

# Test 6: STRING derived parameter - compute and validate
dp_string = DerivedParameter(
name="derived_string",
parameter_type=ParameterType.STRING,
expression_str="category",
)
self.assertTrue(dp_string._is_simple_copy)
self.assertEqual(dp_string.compute({"category": "foo"}), "foo")
self.assertEqual(dp_string.compute({"category": "bar"}), "bar")
self.assertTrue(dp_string.validate("foo", parameters={"category": "foo"}))
self.assertFalse(dp_string.validate("foo", parameters={"category": "bar"}))

# Test 7: domain_repr for simple copy
self.assertEqual(dp_bool.domain_repr, "value=flag")
self.assertEqual(dp_string.domain_repr, "value=category")

# Test 8: Error case - BOOL with non-simple expression
with self.assertRaisesRegex(UserInputError, "simple copies"):
DerivedParameter(
name="bad_bool",
parameter_type=ParameterType.BOOL,
expression_str="2 * flag",
)

# Test 9: Error case - STRING with non-simple expression
with self.assertRaisesRegex(UserInputError, "simple copies"):
DerivedParameter(
name="bad_string",
parameter_type=ParameterType.STRING,
expression_str="cat + 1",
)

# Test 10: compute_array and validate_array for BOOL
df = pd.DataFrame(
{
"flag": [True, False, True],
"other": [False, False, False],
}
)
computed_bool = dp_bool.compute_array(df)
self.assertTrue(np.array_equal(computed_bool, np.array([True, False, True])))
self.assertTrue(
np.array_equal(
dp_bool.validate_array(np.array([True, False, True]), df),
np.array([True, True, True]),
)
)

# Test 11: compute_array and validate_array for STRING
df_str = pd.DataFrame({"category": ["foo", "bar", "baz"]})
computed_str = dp_string.compute_array(df_str)
self.assertTrue(np.array_equal(computed_str, np.array(["foo", "bar", "baz"])))

# Test 12: compute_array for numeric (FLOAT) simple copy with column present
# (Covers line 1533: return df[source_name].to_numpy(dtype=np.float64, ...))
df_float = pd.DataFrame({"x": [1.0, 2.0, 3.0]})
np.testing.assert_array_equal(
dp_float.compute_array(df_float), np.array([1.0, 2.0, 3.0])
)

# Test 13: compute_array for simple copy with missing column
# (Covers lines 1535-1537)
df_missing = pd.DataFrame({"other": [1.0, 2.0]})
self.assertTrue(np.all(np.isnan(dp_float.compute_array(df_missing)))) # numeric
self.assertTrue(
np.all(pd.isna(dp_bool.compute_array(df_missing)))
) # non-numeric

# Test 14: validate_array with df=None
# (Covers line 1568)
np.testing.assert_array_equal(
dp_float.validate_array(np.array([1.0, 2.0]), df=None),
np.array([False, False]),
)


class ParameterEqualityTest(TestCase):
def setUp(self) -> None:
Expand Down