Skip to content

Commit ad38736

Browse files
saitcakmak authored and facebook-github-bot committed
Support NaN max_reference_point in infer_reference_point (#1671)
Summary: Pull Request resolved: #1671

This helps support partial objective thresholds in Ax.

Reviewed By: Balandat

Differential Revision: D43210613

fbshipit-source-id: cbab579316994ef9246b40dd4d882b261a17f00c
1 parent a8efd76 commit ad38736

File tree

2 files changed

+100
-18
lines changed

2 files changed

+100
-18
lines changed

botorch/utils/multi_objective/hypervolume.py

Lines changed: 43 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,11 @@ def infer_reference_point(
3939
) -> Tensor:
4040
r"""Get reference point for hypervolume computations.
4141
42-
This sets the reference point to be `ref_point = nadir - 0.1 * range`
43-
when there is no pareto_Y that is better than the reference point.
42+
This sets the reference point to be `ref_point = nadir - scale * range`
43+
when there is no `pareto_Y` that is better than `max_ref_point`.
44+
If there's `pareto_Y` better than `max_ref_point`, the reference point
45+
will be set to `max_ref_point - scale * range` if `scale_max_ref_point`
46+
is true and to `max_ref_point` otherwise.
4447
4548
[Ishibuchi2011]_ find 0.1 to be a robust multiplier for scaling the
4649
nadir point.
@@ -50,6 +53,9 @@ def infer_reference_point(
5053
Args:
5154
pareto_Y: A `n x m`-dim tensor of Pareto-optimal points.
5255
max_ref_point: A `m` dim tensor indicating the maximum reference point.
56+
Some elements can be NaN, except when `pareto_Y` is empty,
57+
in which case these dimensions will be treated as if no
58+
`max_ref_point` was provided and set to `nadir - scale * range`.
5359
scale: A multiplier used to scale back the reference point based on the
5460
range of each objective.
5561
scale_max_ref_point: A boolean indicating whether to apply scaling to
@@ -58,20 +64,28 @@ def infer_reference_point(
5864
Returns:
5965
A `m`-dim tensor containing the reference point.
6066
"""
61-
6267
if pareto_Y.shape[0] == 0:
6368
if max_ref_point is None:
6469
raise BotorchError("Empty pareto set and no max ref point provided")
70+
if max_ref_point.isnan().any():
71+
raise BotorchError("Empty pareto set and max ref point includes NaN.")
6572
if scale_max_ref_point:
6673
return max_ref_point - scale * max_ref_point.abs()
6774
return max_ref_point
6875
if max_ref_point is not None:
69-
better_than_ref = (pareto_Y > max_ref_point).all(dim=-1)
76+
non_nan_idx = ~max_ref_point.isnan()
77+
# Count all points exceeding non-NaN reference point as being better.
78+
better_than_ref = (pareto_Y[:, non_nan_idx] > max_ref_point[non_nan_idx]).all(
79+
dim=-1
80+
)
7081
else:
71-
better_than_ref = torch.full(
72-
pareto_Y.shape[:1], 1, dtype=bool, device=pareto_Y.device
82+
non_nan_idx = torch.ones(
83+
pareto_Y.shape[-1], dtype=torch.bool, device=pareto_Y.device
7384
)
74-
if max_ref_point is not None and better_than_ref.any():
85+
better_than_ref = torch.ones(
86+
pareto_Y.shape[:1], dtype=torch.bool, device=pareto_Y.device
87+
)
88+
if max_ref_point is not None and better_than_ref.any() and non_nan_idx.all():
7589
Y_range = pareto_Y[better_than_ref].max(dim=0).values - max_ref_point
7690
if scale_max_ref_point:
7791
return max_ref_point - scale * Y_range
@@ -80,17 +94,28 @@ def infer_reference_point(
8094
# no points better than max_ref_point and only a single observation
8195
# subtract MIN_Y_RANGE to handle the case that pareto_Y is a singleton
8296
# with objective value of 0.
83-
return (pareto_Y - scale * pareto_Y.abs().clamp_min(MIN_Y_RANGE)).view(-1)
84-
# no points better than max_ref_point and multiple observations
85-
# make sure that each dimension of the nadir point is no greater than
86-
# the max_ref_point
87-
nadir = pareto_Y.min(dim=0).values
88-
if max_ref_point is not None:
89-
nadir = torch.min(nadir, max_ref_point)
90-
ideal = pareto_Y.max(dim=0).values
91-
# handle case where all values for one objective are the same
92-
Y_range = (ideal - nadir).clamp_min(MIN_Y_RANGE)
93-
return nadir - scale * Y_range
97+
Y_range = pareto_Y.abs().clamp_min(MIN_Y_RANGE).view(-1)
98+
ref_point = pareto_Y.view(-1) - scale * Y_range
99+
else:
100+
# no points better than max_ref_point and multiple observations
101+
# make sure that each dimension of the nadir point is no greater than
102+
# the max_ref_point
103+
nadir = pareto_Y.min(dim=0).values
104+
if max_ref_point is not None:
105+
nadir[non_nan_idx] = torch.min(
106+
nadir[non_nan_idx], max_ref_point[non_nan_idx]
107+
)
108+
ideal = pareto_Y.max(dim=0).values
109+
# handle case where all values for one objective are the same
110+
Y_range = (ideal - nadir).clamp_min(MIN_Y_RANGE)
111+
ref_point = nadir - scale * Y_range
112+
# Set not-nan indices - if any - to max_ref_point.
113+
if non_nan_idx.any() and not non_nan_idx.all() and better_than_ref.any():
114+
if scale_max_ref_point:
115+
ref_point[non_nan_idx] = (max_ref_point - scale * Y_range)[non_nan_idx]
116+
else:
117+
ref_point[non_nan_idx] = max_ref_point[non_nan_idx]
118+
return ref_point
94119

95120

96121
class Hypervolume:

test/utils/multi_objective/test_hypervolume.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,3 +243,60 @@ def test_infer_reference_point(self):
243243
ref_point = infer_reference_point(pareto_Y=Y, scale=0.2)
244244
self.assertAllClose(ref_point, expected_ref_point)
245245
ref_point = infer_reference_point(pareto_Y=Y)
246+
expected_ref_point = nadir - 0.1 * (ideal - nadir)
247+
self.assertAllClose(ref_point, expected_ref_point)
248+
249+
# Test all NaN max_ref_point.
250+
ref_point = infer_reference_point(
251+
pareto_Y=Y,
252+
max_ref_point=torch.tensor([float("nan"), float("nan")], **tkwargs),
253+
)
254+
self.assertAllClose(ref_point, expected_ref_point)
255+
# Test partial NaN, partial worse than nadir.
256+
expected_ref_point = nadir.clone()
257+
expected_ref_point[1] = -1e5
258+
ref_point = infer_reference_point(
259+
pareto_Y=Y,
260+
max_ref_point=torch.tensor([float("nan"), -1e5], **tkwargs),
261+
scale=0.0,
262+
)
263+
self.assertAllClose(ref_point, expected_ref_point)
264+
# Test partial NaN, partial better than nadir.
265+
expected_ref_point = nadir
266+
ref_point = infer_reference_point(
267+
pareto_Y=Y,
268+
max_ref_point=torch.tensor([float("nan"), 1e5], **tkwargs),
269+
scale=0.0,
270+
)
271+
self.assertAllClose(ref_point, expected_ref_point)
272+
# Test partial NaN, partial worse than nadir with scale_max_ref_point.
273+
expected_ref_point[1] = -1e5
274+
expected_ref_point = expected_ref_point - 0.2 * (ideal - expected_ref_point)
275+
ref_point = infer_reference_point(
276+
pareto_Y=Y,
277+
max_ref_point=torch.tensor([float("nan"), -1e5], **tkwargs),
278+
scale=0.2,
279+
scale_max_ref_point=True,
280+
)
281+
self.assertAllClose(ref_point, expected_ref_point)
282+
# Test with single point in Pareto_Y, worse than ref point.
283+
ref_point = infer_reference_point(
284+
pareto_Y=Y[:1],
285+
max_ref_point=torch.tensor([float("nan"), 1e5], **tkwargs),
286+
)
287+
expected_ref_point = Y[0] - 0.1 * Y[0].abs()
288+
self.assertTrue(torch.equal(expected_ref_point, ref_point))
289+
# Test with single point in Pareto_Y, better than ref point.
290+
ref_point = infer_reference_point(
291+
pareto_Y=Y[:1],
292+
max_ref_point=torch.tensor([float("nan"), -1e5], **tkwargs),
293+
scale_max_ref_point=True,
294+
)
295+
expected_ref_point[1] = -1e5 - 0.1 * Y[0, 1].abs()
296+
self.assertTrue(torch.equal(expected_ref_point, ref_point))
297+
# Empty pareto_Y with nan ref point.
298+
with self.assertRaisesRegex(BotorchError, "ref point includes NaN"):
299+
ref_point = infer_reference_point(
300+
pareto_Y=Y[:0],
301+
max_ref_point=torch.tensor([float("nan"), -1e5], **tkwargs),
302+
)

0 commit comments

Comments
 (0)