@@ -119,13 +119,15 @@ def __init__(
119119 self ,
120120 noise_std : None | float | list [float ] = None ,
121121 negate : bool = False ,
122+ dtype : torch .dtype = torch .double ,
122123 ) -> None :
123124 r"""
124125 Args:
125126 noise_std: Standard deviation of the observation noise.
126127 negate: If True, negate the objectives.
128+ dtype: The dtype that is used for the bounds of the function.
127129 """
128- super ().__init__ (noise_std = noise_std , negate = negate )
130+ super ().__init__ (noise_std = noise_std , negate = negate , dtype = dtype )
129131 self ._branin = Branin ()
130132
131133 def _rescaled_branin (self , X : Tensor ) -> Tensor :
@@ -179,12 +181,14 @@ def __init__(
179181 dim : int ,
180182 noise_std : None | float | list [float ] = None ,
181183 negate : bool = False ,
184+ dtype : torch .dtype = torch .double ,
182185 ) -> None :
183186 r"""
184187 Args:
185188 dim: The (input) dimension.
186189 noise_std: Standard deviation of the observation noise.
187190 negate: If True, negate the function.
191+ dtype: The dtype that is used for the bounds of the function.
188192 """
189193 if dim < self ._min_dim :
190194 raise ValueError (f"dim must be >= { self ._min_dim } , but got dim={ dim } !" )
@@ -194,7 +198,7 @@ def __init__(
194198 ]
195199 # max_hv is the area of the box minus the area of the curve formed by the PF.
196200 self ._max_hv = self ._ref_point [0 ] * self ._ref_point [1 ] - self ._area_under_curve
197- super ().__init__ (noise_std = noise_std , negate = negate )
201+ super ().__init__ (noise_std = noise_std , negate = negate , dtype = dtype )
198202
199203 @abstractmethod
200204 def _h (self , X : Tensor ) -> Tensor :
@@ -339,13 +343,15 @@ def __init__(
339343 num_objectives : int = 2 ,
340344 noise_std : None | float | list [float ] = None ,
341345 negate : bool = False ,
346+ dtype : torch .dtype = torch .double ,
342347 ) -> None :
343348 r"""
344349 Args:
345350 dim: The (input) dimension of the function.
346351 num_objectives: Must be less than dim.
347352 noise_std: Standard deviation of the observation noise.
348353 negate: If True, negate the function.
354+ dtype: The dtype that is used for the bounds of the function.
349355 """
350356 if dim <= num_objectives :
351357 raise ValueError (
@@ -356,7 +362,7 @@ def __init__(
356362 self .k = self .dim - self .num_objectives + 1
357363 self ._bounds = [(0.0 , 1.0 ) for _ in range (self .dim )]
358364 self ._ref_point = [self ._ref_val for _ in range (num_objectives )]
359- super ().__init__ (noise_std = noise_std , negate = negate )
365+ super ().__init__ (noise_std = noise_std , negate = negate , dtype = dtype )
360366
361367
362368class DTLZ1 (DTLZ ):
@@ -608,12 +614,14 @@ def __init__(
608614 noise_std : None | float | list [float ] = None ,
609615 negate : bool = False ,
610616 num_objectives : int = 2 ,
617+ dtype : torch .dtype = torch .double ,
611618 ) -> None :
612619 r"""
613620 Args:
614621 noise_std: Standard deviation of the observation noise.
615622 negate: If True, negate the objectives.
616623 num_objectives: The number of objectives.
624+ dtype: The dtype that is used for the bounds of the function.
617625 """
618626 if num_objectives not in (2 , 3 , 4 ):
619627 raise UnsupportedError ("GMM only currently supports 2 to 4 objectives." )
@@ -623,7 +631,7 @@ def __init__(
623631 if num_objectives > 3 :
624632 self ._ref_point .append (- 0.1866 )
625633 self .num_objectives = num_objectives
626- super ().__init__ (noise_std = noise_std , negate = negate )
634+ super ().__init__ (noise_std = noise_std , negate = negate , dtype = dtype )
627635 gmm_pos = torch .tensor (
628636 [
629637 [[0.2 , 0.2 ], [0.8 , 0.2 ], [0.5 , 0.7 ]],
@@ -935,13 +943,15 @@ def __init__(
935943 num_objectives : int = 2 ,
936944 noise_std : None | float | list [float ] = None ,
937945 negate : bool = False ,
946+ dtype : torch .dtype = torch .double ,
938947 ) -> None :
939948 r"""
940949 Args:
941950 dim: The (input) dimension of the function.
942951 num_objectives: Number of objectives. Must not be larger than dim.
943952 noise_std: Standard deviation of the observation noise.
944953 negate: If True, negate the function.
954+ dtype: The dtype that is used for the bounds of the function.
945955 """
946956 if num_objectives != 2 :
947957 raise NotImplementedError (
@@ -954,7 +964,7 @@ def __init__(
954964 self .num_objectives = num_objectives
955965 self .dim = dim
956966 self ._bounds = [(0.0 , 1.0 ) for _ in range (self .dim )]
957- super ().__init__ (noise_std = noise_std , negate = negate )
967+ super ().__init__ (noise_std = noise_std , negate = negate , dtype = dtype )
958968
959969 @staticmethod
960970 def _g (X : Tensor ) -> Tensor :
@@ -1246,15 +1256,17 @@ def __init__(
12461256 noise_std : None | float | list [float ] = None ,
12471257 constraint_noise_std : None | float | list [float ] = None ,
12481258 negate : bool = False ,
1259+ dtype : torch .dtype = torch .double ,
12491260 ) -> None :
12501261 r"""
12511262 Args:
12521263 noise_std: Standard deviation of the observation noise of the objectives.
12531264 constraint_noise_std: Standard deviation of the observation noise of the
12541265 constraint.
12551266 negate: If True, negate the function.
1267+ dtype: The dtype that is used for the bounds of the function.
12561268 """
1257- super ().__init__ (noise_std = noise_std , negate = negate )
1269+ super ().__init__ (noise_std = noise_std , negate = negate , dtype = dtype )
12581270 con_bounds = torch .tensor (self ._con_bounds , dtype = self .bounds .dtype ).transpose (
12591271 - 1 , - 2
12601272 )
@@ -1357,6 +1369,7 @@ def __init__(
13571369 noise_std : None | float | list [float ] = None ,
13581370 constraint_noise_std : None | float | list [float ] = None ,
13591371 negate : bool = False ,
1372+ dtype : torch .dtype = torch .double ,
13601373 ) -> None :
13611374 r"""
13621375 Args:
@@ -1365,12 +1378,13 @@ def __init__(
13651378 constraint_noise_std: Standard deviation of the observation noise of the
13661379 constraints.
13671380 negate: If True, negate the function.
1381+ dtype: The dtype that is used for the bounds of the function.
13681382 """
13691383 if dim < 2 :
13701384 raise ValueError ("dim must be greater than or equal to 2." )
13711385 self .dim = dim
13721386 self ._bounds = [(0.0 , 1.0 ) for _ in range (self .dim )]
1373- super ().__init__ (noise_std = noise_std , negate = negate )
1387+ super ().__init__ (noise_std = noise_std , negate = negate , dtype = dtype )
13741388 self .constraint_noise_std = constraint_noise_std
13751389
13761390 def LA2 (self , A , B , C , D , theta ) -> Tensor :
0 commit comments