How to specify perturbation function #2575
-
| I want to specify a function for the  
 It works fine when training the model. However, when optimizing the acquisition function, instead of passing an input of shape  For context, this is the function I am using for 

def perturb_input(x):
    # draw_sobol_normal_samples returns size (n, d)
    # Draw random samples and expand batch dimension
    x_perturb = draw_sobol_normal_samples(d=problem_dim, n=N_W, **tkwargs)[None, :, :]
    # size (batch, n_p, d)
    x_perturb = x_perturb.repeat(x.shape[0], 1, 1)
    # perturbn dim0
    dim0_std = 0.05 * x[..., 0:1] / 2
    x_perturb[:, :, 0] *= dim0_std
    # perturb dim1
    x_perturb[:, :, 1] *= 0.01
    return x_perturb

Here is the training function (mostly taken from the tutorial) and optimization function:

bounds = torch.stack([torch.zeros(problem_dim), torch.ones(problem_dim)]).to(**tkwargs)
def train_model(train_X: Tensor, train_Y: Tensor) -> SingleTaskGP:
    r"""Build a `SingleTaskGP` on `(train_X, train_Y)`, fit it, and return it.

    The model is constructed with an `InputPerturbation` input transform whose
    perturbation set is the `perturb_input` callable (together with the
    module-level `bounds`) and a `Standardize(m=1)` outcome transform. Fitting
    is done by handing an `ExactMarginalLogLikelihood` to `fit_gpytorch_mll`.

    Args:
        train_X: Training inputs passed to `SingleTaskGP`.
        train_Y: Training targets passed to `SingleTaskGP`.

    Returns:
        The fitted `SingleTaskGP` model.
    """
    # A fixed perturbation set could be passed instead of the callable, e.g.
    # draw_sobol_normal_samples(d=problem_dim, n=N_W, **tkwargs) * STD_DEV.
    gp = SingleTaskGP(
        train_X,
        train_Y,
        input_transform=InputPerturbation(
            perturbation_set=perturb_input,
            bounds=bounds,
        ),
        outcome_transform=Standardize(m=1),
    )
    # The marginal log likelihood object is only needed for fitting.
    fit_gpytorch_mll(ExactMarginalLogLikelihood(gp.likelihood, gp))
    return gp
# Objective handed to the acquisition function in
# `optimize_acqf_and_get_observation` (`objective=risk_measure`).
# `n_w` uses N_W, the same constant `perturb_input` uses for the number of
# perturbation draws -- presumably the two must agree; confirm against the
# BoTorch risk-measure docs.
risk_measure = VaR(alpha=ALPHA, n_w=N_W)
def optimize_acqf_and_get_observation():
    r"""Optimizes the acquisition function, and returns a new candidate and observation."""
    acqf = qNoisyExpectedImprovement(
        model=model,
        X_baseline=train_X,
        sampler=SobolQMCNormalSampler(sample_shape=torch.Size([128])),
        objective=risk_measure,
        prune_baseline=True,
    )
    candidate, _ = optimize_acqf(
        acq_function=acqf,
        bounds=bounds,
        q=BATCH_SIZE,
        num_restarts=NUM_RESTARTS,
        raw_samples=RAW_SAMPLES,
    )
    new_observations = fitness_fun(candidate)
    return candidate, new_observations

I would expect the input shape to  Here is the full error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[15], line 6
      4 print(f"Starting iteration {i}, total time: {time() - start_time:.3f} seconds.")
      5 # optimize the acquisition function and get the observations
----> 6 candidate, observations = optimize_acqf_and_get_observation()
      8 # update the model with new observations
      9 train_X = torch.cat([train_X, candidate], dim=0)
Cell In[14], line 13
      4 r"""Optimizes the acquisition function, and returns a new candidate and observation."""
      5 acqf = qNoisyExpectedImprovement(
      6     model=model,
      7     X_baseline=train_X,
   (...)
     10     prune_baseline=True,
     11 )
---> 13 candidate, _ = optimize_acqf(
     14     acq_function=acqf,
     15     bounds=bounds,
     16     q=BATCH_SIZE,
     17     num_restarts=NUM_RESTARTS,
     18     raw_samples=RAW_SAMPLES,
     19 )
     21 new_observations = fitness_fun(candidate)
     22 return candidate, new_observations
File c:\Users\username\AppData\Local\anaconda3\envs\envname\Lib\site-packages\botorch\optim\optimize.py:567, in optimize_acqf(acq_function, bounds, q, num_restarts, raw_samples, options, inequality_constraints, equality_constraints, nonlinear_inequality_constraints, fixed_features, post_processing_func, batch_initial_conditions, return_best_only, gen_candidates, sequential, ic_generator, timeout_sec, return_full_tree, retry_on_optimization_warning, **ic_gen_kwargs)
    544     gen_candidates = gen_candidates_scipy
    545 opt_acqf_inputs = OptimizeAcqfInputs(
    546     acq_function=acq_function,
    547     bounds=bounds,
   (...)
    565     ic_gen_kwargs=ic_gen_kwargs,
    566 )
--> 567 return _optimize_acqf(opt_acqf_inputs)
File c:\Users\username\AppData\Local\anaconda3\envs\envname\Lib\site-packages\botorch\optim\optimize.py:588, in _optimize_acqf(opt_inputs)
    585     return _optimize_acqf_sequential_q(opt_inputs=opt_inputs)
    587 # Batch optimization (including the case q=1)
--> 588 return _optimize_acqf_batch(opt_inputs=opt_inputs)
File c:\Users\username\AppData\Local\anaconda3\envs\envname\Lib\site-packages\botorch\optim\optimize.py:275, in _optimize_acqf_batch(opt_inputs)
    272     batch_initial_conditions = opt_inputs.batch_initial_conditions
    273 else:
    274     # pyre-ignore[28]: Unexpected keyword argument `acq_function` to anonymous call.
--> 275     batch_initial_conditions = opt_inputs.get_ic_generator()(
    276         acq_function=opt_inputs.acq_function,
    277         bounds=opt_inputs.bounds,
    278         q=opt_inputs.q,
    279         num_restarts=opt_inputs.num_restarts,
    280         raw_samples=opt_inputs.raw_samples,
    281         fixed_features=opt_inputs.fixed_features,
    282         options=options,
    283         inequality_constraints=opt_inputs.inequality_constraints,
    284         equality_constraints=opt_inputs.equality_constraints,
    285         **opt_inputs.ic_gen_kwargs,
    286     )
    288 batch_limit: int = options.get(
    289     "batch_limit",
    290     (
   (...)
    294     ),
    295 )
    297 def _optimize_batch_candidates() -> Tuple[Tensor, Tensor, List[Warning]]:
File c:\Users\username\AppData\Local\anaconda3\envs\envname\Lib\site-packages\botorch\optim\initializers.py:417, in gen_batch_initial_conditions(acq_function, bounds, q, num_restarts, raw_samples, fixed_features, options, inequality_constraints, equality_constraints, generator, fixed_X_fantasies)
    415 while start_idx < X_rnd.shape[0]:
    416     end_idx = min(start_idx + batch_limit, X_rnd.shape[0])
--> 417     Y_rnd_curr = acq_function(
    418         X_rnd[start_idx:end_idx].to(device=device)
    419     ).cpu()
    420     Y_rnd_list.append(Y_rnd_curr)
    421     start_idx += batch_limit
File c:\Users\username\AppData\Local\anaconda3\envs\envname\Lib\site-packages\torch\nn\modules\module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)
File c:\Users\username\AppData\Local\anaconda3\envs\envname\Lib\site-packages\torch\nn\modules\module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None
File c:\Users\username\AppData\Local\anaconda3\envs\envname\Lib\site-packages\botorch\utils\transforms.py:305, in concatenate_pending_points.<locals>.decorated(cls, X, **kwargs)
    303 if cls.X_pending is not None:
    304     X = torch.cat([X, match_batch_shape(cls.X_pending, X)], dim=-2)
--> 305 return method(cls, X, **kwargs)
File c:\Users\username\AppData\Local\anaconda3\envs\envname\Lib\site-packages\botorch\utils\transforms.py:259, in t_batch_mode_transform.<locals>.decorator.<locals>.decorated(acqf, X, *args, **kwargs)
    257 # add t-batch dim
    258 X = X if X.dim() > 2 else X.unsqueeze(0)
--> 259 output = method(acqf, X, *args, **kwargs)
    260 if hasattr(acqf, "model") and is_ensemble(acqf.model):
    261     # IDEA: this could be wrapped into SampleReducingMCAcquisitionFunction
    262     output = (
    263         output.mean(dim=-1) if not acqf._log else logmeanexp(output, dim=-1)
    264     )
File c:\Users\username\AppData\Local\anaconda3\envs\envname\Lib\site-packages\botorch\acquisition\monte_carlo.py:274, in SampleReducingMCAcquisitionFunction.forward(self, X)
    254 @concatenate_pending_points
    255 @t_batch_mode_transform()
    256 def forward(self, X: Tensor) -> Tensor:
    257     r"""Computes the acquisition value associated with the input `X`. Weighs the
    258     acquisition utility values by smoothed constraint indicators if `constraints`
    259     was passed to the constructor of the class. Applies `self.sample_reduction` and
   (...)
    272         batch shape of model and input `X`.
    273     """
--> 274     non_reduced_acqval = self._non_reduced_forward(X=X)
    275     return self._sample_reduction(self._q_reduction(non_reduced_acqval))
File c:\Users\username\AppData\Local\anaconda3\envs\envname\Lib\site-packages\botorch\acquisition\monte_carlo.py:287, in SampleReducingMCAcquisitionFunction._non_reduced_forward(self, X)
    277 def _non_reduced_forward(self, X: Tensor) -> Tensor:
    278     """Compute the constrained acquisition values at the MC-sample, q level.
    279 
    280     Args:
   (...)
    285         A Tensor with shape `sample_sample x batch_shape x q`.
    286     """
--> 287     samples, obj = self._get_samples_and_objectives(X)
    288     samples = repeat_to_match_aug_dim(target_tensor=samples, reference_tensor=obj)
    289     acqval = self._sample_forward(obj)  # `sample_sample x batch_shape x q`
File c:\Users\username\AppData\Local\anaconda3\envs\envname\Lib\site-packages\botorch\acquisition\monte_carlo.py:597, in qNoisyExpectedImprovement._get_samples_and_objectives(self, X)
    594 X_full = torch.cat([match_batch_shape(self.X_baseline, X), X], dim=-2)
    595 # TODO: Implement more efficient way to compute posterior over both training and
    596 # test points in GPyTorch (https://github.com/cornellius-gp/gpytorch/issues/567)
--> 597 posterior = self.model.posterior(
    598     X_full, posterior_transform=self.posterior_transform
    599 )
    600 if not self._cache_root:
    601     samples_full = super().get_posterior_samples(posterior)
File c:\Users\username\AppData\Local\anaconda3\envs\envname\Lib\site-packages\botorch\models\gpytorch.py:385, in BatchedMultiOutputGPyTorchModel.posterior(self, X, output_indices, observation_noise, posterior_transform, **kwargs)
    382 self.eval()  # make sure model is in eval mode
    383 # input transforms are applied at `posterior` in `eval` mode, and at
    384 # `model.forward()` at the training time
--> 385 X = self.transform_inputs(X)
    386 with gpt_posterior_settings():
    387     # insert a dimension for the output dimension
    388     if self._num_outputs > 1:
File c:\Users\username\AppData\Local\anaconda3\envs\envname\Lib\site-packages\botorch\models\model.py:229, in Model.transform_inputs(self, X, input_transform)
    227     return input_transform(X)
    228 try:
--> 229     return self.input_transform(X)
    230 except AttributeError:
    231     return X
File c:\Users\username\AppData\Local\anaconda3\envs\envname\Lib\site-packages\torch\nn\modules\module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)
File c:\Users\username\AppData\Local\anaconda3\envs\envname\Lib\site-packages\torch\nn\modules\module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None
File c:\Users\username\AppData\Local\anaconda3\envs\envname\Lib\site-packages\botorch\models\transforms\input.py:78, in InputTransform.forward(self, X)
     76 elif self.transform_on_eval:
     77     if fantasize.off() or self.transform_on_fantasize:
---> 78         return self.transform(X)
     79 return X
File c:\Users\username\AppData\Local\anaconda3\envs\envname\Lib\site-packages\botorch\models\transforms\input.py:1417, in InputPerturbation.transform(self, X)
   1395 r"""Transform the inputs by adding `perturbation_set` to each input.
   1396 
   1397 For each `1 x d`-dim element in the input tensor, this will produce
   (...)
   1411     A `batch_shape x (q * n_p) x d`-dim tensor of perturbed inputs.
   1412 """
   1413 # NOTE: If we had access to n_p without evaluating _perturbations when the
   1414 # perturbation_set is a function, we could move this into `_transform`.
   1415 # Further, we could remove the two `transpose` calls below if one were
   1416 # willing to accept a different ordering of the transformed output.
-> 1417 self._perturbations = self._expanded_perturbations(X)
   1418 # make space for n_p dimension, switch n_p with n after transform, and flatten.
   1419 return self._transform(X.unsqueeze(-3)).transpose(-3, -2).flatten(-3, -2)
File c:\Users\username\AppData\Local\anaconda3\envs\envname\Lib\site-packages\botorch\models\transforms\input.py:1444, in InputPerturbation._expanded_perturbations(self, X)
   1442     p = p.expand(X.shape[-2], *p.shape)  # p is batch_shape x n x n_p x d
   1443 else:
-> 1444     p = p(X) if self.indices is None else p(X[..., self.indices])
   1445 return p.transpose(-3, -2)
Cell In[10], line 56
     54 # perturbn dim0
     55 dim0_std = 0.05 * x[..., 0:1] / 2
---> 56 x_perturb[:, :, 0] *= dim0_std
     58 # perturb dim1
     59 x_perturb[:, :, 1] *= 0.01
RuntimeError: The size of tensor a (128) must match the size of tensor b (5) at non-singleton dimension 1 | 
Beta Was this translation helpful? Give feedback.
Replies: 1 comment 3 replies
-
| Hi @samuelkim16. 
 In BoTorch, we use batch evaluations within acquisition functions. 
 The  I am also avoiding in-place modification of tensors here since it can lead to issues with autograd. | 
Beta Was this translation helpful? Give feedback.
Hi @samuelkim16.
In BoTorch, we use batch evaluations within acquisition functions.
The `batch x d` here refers to an arbitrary batch shape, which in your case is `128 x 5`. You should update `perturb_input` to support batch inputs. Here's an (untested) example o…