Description
What happened?
`client.compute_analyses` does not output all cards when the model runs on CUDA.
If I use the Modular BoTorch Interface and pass `"torch_device": torch.device("cuda" if torch.cuda.is_available() else "cpu")` in the `model_kwargs`, the optimization runs just fine on the GPU, but `compute_analyses` fails for some of the cards.
I tracked the issue down to `sobol_measures.py` and `derivative_measures.py`, which create tensors in several places and pass them to the model without checking which device the model is on.
Note that with `torch_device = torch.device("cpu")`, everything works just fine.
I could hack my way around this by first checking the model's device and then moving the tensors there, but that might not be the cleanest way to do it.
Any tips on how to proceed cleanly?
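For context, the workaround boils down to reading the model's device from its parameters and moving any freshly created tensor there before calling the model. A minimal torch-only sketch (using a plain `nn.Linear` as a stand-in for the fitted BoTorch model):

```python
import torch
from torch import nn


def to_model_device(x: torch.Tensor, model: nn.Module) -> torch.Tensor:
    """Move x onto whatever device the model's parameters live on."""
    return x.to(next(model.parameters()).device)


# Stand-in model; in the real case this would be the fitted GP.
model = nn.Linear(3, 1)
if torch.cuda.is_available():
    model = model.cuda()

x = torch.randn(5, 3)          # created on CPU, like the tensors in sobol_measures.py
x = to_model_device(x, model)  # no-op on CPU, transfer to GPU otherwise
out = model(x)                 # no device-mismatch error either way
```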
Please provide a minimal, reproducible example of the unexpected behavior.
The following lines show how I worked around the issue for now. Note that several parts of the code need to be updated in this way to get it to run.
```python
def input_function(x: Tensor) -> Tensor:
    with torch.no_grad():
        means, variances = [], []
        # Since we're only looking at mean & variance, we can freely
        # use mini-batches.
        # Move x to the same device as the model.  <---- NEW
        x = x.to(next(self.model.parameters()).device)
        for x_split in x.split(split_size=mini_batch_size):
            p = assert_is_instance(
                self.model.posterior(x_split),
                GPyTorchPosterior,
            )
            means.append(p.mean)
            variances.append(p.variance)
        cat_dim = 1 if is_ensemble(self.model) else 0
        return link_function(
            torch.cat(means, dim=cat_dim), torch.cat(variances, dim=cat_dim)
        )
```
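An alternative to moving tensors after the fact would be to create them on the right device (and dtype) from the start, e.g. by reading a `tkwargs`-style dict off the model, as is common in BoTorch code. A hypothetical sketch, again with `nn.Linear` standing in for the fitted model:

```python
import torch
from torch import nn


def model_tkwargs(model: nn.Module) -> dict:
    """Read device and dtype from the model's first parameter (hypothetical helper)."""
    p = next(model.parameters())
    return {"device": p.device, "dtype": p.dtype}


model = nn.Linear(2, 1).double()  # stand-in; real code would use the fitted GP
tkwargs = model_tkwargs(model)

# Tensors are created with the matching device/dtype up front,
# so no later .to(...) call is needed.
grid = torch.rand(4, 2, **tkwargs)
out = model(grid)
```

This keeps the device handling in one place instead of sprinkling `.to(...)` calls through `sobol_measures.py` and `derivative_measures.py`.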
Please paste any relevant traceback/logs produced by the example provided.
Ax Version
1.0.0
Python Version
3.13.2
Operating System
Ubuntu
(Optional) Describe any potential fixes you've considered to the issue outlined above.
No response
Pull Request
None
Code of Conduct
- I agree to follow Ax's Code of Conduct