@@ -39,8 +39,11 @@ def infer_reference_point(
39
39
) -> Tensor :
40
40
r"""Get reference point for hypervolume computations.
41
41
42
- This sets the reference point to be `ref_point = nadir - 0.1 * range`
43
- when there is no pareto_Y that is better than the reference point.
42
+ This sets the reference point to be `ref_point = nadir - scale * range`
43
+ when there is no `pareto_Y` that is better than `max_ref_point`.
44
+ If there's `pareto_Y` better than `max_ref_point`, the reference point
45
+ will be set to `max_ref_point - scale * range` if `scale_max_ref_point`
46
+ is true and to `max_ref_point` otherwise.
44
47
45
48
[Ishibuchi2011]_ find 0.1 to be a robust multiplier for scaling the
46
49
nadir point.
@@ -50,6 +53,9 @@ def infer_reference_point(
50
53
Args:
51
54
pareto_Y: A `n x m`-dim tensor of Pareto-optimal points.
52
55
max_ref_point: A `m` dim tensor indicating the maximum reference point.
56
+ Some elements can be NaN, except when `pareto_Y` is empty,
57
+ in which case these dimensions will be treated as if no
58
+ `max_ref_point` was provided and set to `nadir - scale * range`.
53
59
scale: A multiplier used to scale back the reference point based on the
54
60
range of each objective.
55
61
scale_max_ref_point: A boolean indicating whether to apply scaling to
@@ -58,20 +64,28 @@ def infer_reference_point(
58
64
Returns:
59
65
A `m`-dim tensor containing the reference point.
60
66
"""
61
-
62
67
if pareto_Y .shape [0 ] == 0 :
63
68
if max_ref_point is None :
64
69
raise BotorchError ("Empty pareto set and no max ref point provided" )
70
+ if max_ref_point .isnan ().any ():
71
+ raise BotorchError ("Empty pareto set and max ref point includes NaN." )
65
72
if scale_max_ref_point :
66
73
return max_ref_point - scale * max_ref_point .abs ()
67
74
return max_ref_point
68
75
if max_ref_point is not None :
69
- better_than_ref = (pareto_Y > max_ref_point ).all (dim = - 1 )
76
+ non_nan_idx = ~ max_ref_point .isnan ()
77
+ # Count all points exceeding non-NaN reference point as being better.
78
+ better_than_ref = (pareto_Y [:, non_nan_idx ] > max_ref_point [non_nan_idx ]).all (
79
+ dim = - 1
80
+ )
70
81
else :
71
- better_than_ref = torch .full (
72
- pareto_Y .shape [: 1 ], 1 , dtype = bool , device = pareto_Y .device
82
+ non_nan_idx = torch .ones (
83
+ pareto_Y .shape [- 1 ], dtype = torch . bool , device = pareto_Y .device
73
84
)
74
- if max_ref_point is not None and better_than_ref .any ():
85
+ better_than_ref = torch .ones (
86
+ pareto_Y .shape [:1 ], dtype = torch .bool , device = pareto_Y .device
87
+ )
88
+ if max_ref_point is not None and better_than_ref .any () and non_nan_idx .all ():
75
89
Y_range = pareto_Y [better_than_ref ].max (dim = 0 ).values - max_ref_point
76
90
if scale_max_ref_point :
77
91
return max_ref_point - scale * Y_range
@@ -80,17 +94,28 @@ def infer_reference_point(
80
94
# no points better than max_ref_point and only a single observation
81
95
# subtract MIN_Y_RANGE to handle the case that pareto_Y is a singleton
82
96
# with objective value of 0.
83
- return (pareto_Y - scale * pareto_Y .abs ().clamp_min (MIN_Y_RANGE )).view (- 1 )
84
- # no points better than max_ref_point and multiple observations
85
- # make sure that each dimension of the nadir point is no greater than
86
- # the max_ref_point
87
- nadir = pareto_Y .min (dim = 0 ).values
88
- if max_ref_point is not None :
89
- nadir = torch .min (nadir , max_ref_point )
90
- ideal = pareto_Y .max (dim = 0 ).values
91
- # handle case where all values for one objective are the same
92
- Y_range = (ideal - nadir ).clamp_min (MIN_Y_RANGE )
93
- return nadir - scale * Y_range
97
+ Y_range = pareto_Y .abs ().clamp_min (MIN_Y_RANGE ).view (- 1 )
98
+ ref_point = pareto_Y .view (- 1 ) - scale * Y_range
99
+ else :
100
+ # no points better than max_ref_point and multiple observations
101
+ # make sure that each dimension of the nadir point is no greater than
102
+ # the max_ref_point
103
+ nadir = pareto_Y .min (dim = 0 ).values
104
+ if max_ref_point is not None :
105
+ nadir [non_nan_idx ] = torch .min (
106
+ nadir [non_nan_idx ], max_ref_point [non_nan_idx ]
107
+ )
108
+ ideal = pareto_Y .max (dim = 0 ).values
109
+ # handle case where all values for one objective are the same
110
+ Y_range = (ideal - nadir ).clamp_min (MIN_Y_RANGE )
111
+ ref_point = nadir - scale * Y_range
112
+ # Set not-nan indices - if any - to max_ref_point.
113
+ if non_nan_idx .any () and not non_nan_idx .all () and better_than_ref .any ():
114
+ if scale_max_ref_point :
115
+ ref_point [non_nan_idx ] = (max_ref_point - scale * Y_range )[non_nan_idx ]
116
+ else :
117
+ ref_point [non_nan_idx ] = max_ref_point [non_nan_idx ]
118
+ return ref_point
94
119
95
120
96
121
class Hypervolume :
0 commit comments