[Bug]: Some of the compute_analyses cards fail when using models with cuda #3813

@VMLC-PV

Description

What happened?

client.compute_analyses does not output all cards if the model runs on CUDA.
If I use the Modular BoTorch Interface and pass "torch_device": torch.device("cuda" if torch.cuda.is_available() else "cpu") in the model_kwargs, the optimization runs just fine on the GPU, but compute_analyses fails for some of the cards.
I tracked the issue down to sobol_measures.py and derivative_measures.py, which create Tensors in several places and pass them to the model without checking which device the model is on.
Note that if torch_device is the CPU, everything works just fine.

I can hack around it by first checking the model's device and then moving the tensors to that device, but that may not be the cleanest approach.

Any tips on how to proceed cleanly?
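To make the "check the model device, then move the tensors" idea concrete, here is a minimal sketch in plain PyTorch. The helper names `infer_module_device` and `to_module_device` are hypothetical and not part of Ax or BoTorch, and `nn.Linear` only stands in for the real surrogate model. Falling back to buffers and then to the CPU is an assumption made here to avoid a `StopIteration` on modules without parameters.

```python
import torch
from torch import Tensor, nn


def infer_module_device(module: nn.Module) -> torch.device:
    """Best-effort guess of the device a module lives on (hypothetical helper)."""
    # Check parameters first, then buffers; the first tensor found decides.
    for tensor in module.parameters():
        return tensor.device
    for tensor in module.buffers():
        return tensor.device
    # No tensors at all: assume CPU instead of raising StopIteration.
    return torch.device("cpu")


def to_module_device(x: Tensor, module: nn.Module) -> Tensor:
    """Move x to the module's device; a no-op if it is already there."""
    return x.to(infer_module_device(module))


# Stand-in model; on a CUDA machine one could call model.to("cuda") first.
model = nn.Linear(3, 1)
x = torch.rand(5, 3)
x_on_model_device = to_module_device(x, model)
```

A single helper like this would keep the device lookup in one place, so each call site only needs one extra line. It does assume the whole model sits on one device.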

Please provide a minimal, reproducible example of the unexpected behavior.

The following lines show my current workaround. Note that several other parts of the code need the same change for everything to run.

def input_function(x: Tensor) -> Tensor:
    with torch.no_grad():
        means, variances = [], []
        # Since we're only looking at mean & variance, we can freely
        # use mini-batches.
        # NEW: move x to the same device as the model.
        x = x.to(next(self.model.parameters()).device)
        for x_split in x.split(split_size=mini_batch_size):
            p = assert_is_instance(
                self.model.posterior(x_split),
                GPyTorchPosterior,
            )
            means.append(p.mean)
            variances.append(p.variance)

        cat_dim = 1 if is_ensemble(self.model) else 0
        return link_function(
            torch.cat(means, dim=cat_dim), torch.cat(variances, dim=cat_dim)
        )
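The same pattern can be tried in isolation, without Ax, as a rough sketch. `batched_forward` is a hypothetical function, `nn.Linear` stands in for the surrogate model, and returning the result on the caller's original device is an assumption about what downstream code expects, not something the Ax code is known to do.

```python
import torch
from torch import Tensor, nn


def batched_forward(model: nn.Module, x: Tensor, mini_batch_size: int = 128) -> Tensor:
    """Evaluate model on x in mini-batches, running on the model's device."""
    # Assumes the model has at least one parameter.
    model_device = next(model.parameters()).device
    input_device = x.device
    outputs = []
    with torch.no_grad():
        for x_split in x.split(mini_batch_size):
            # Move one mini-batch at a time to the model's device.
            outputs.append(model(x_split.to(model_device)))
    # Hand the result back on the device the caller supplied.
    return torch.cat(outputs, dim=0).to(input_device)


model = nn.Linear(4, 2)
x = torch.rand(10, 4)
y = batched_forward(model, x, mini_batch_size=3)
```

Moving each split separately, rather than the whole input at once, means only one mini-batch of inputs is on the GPU at a time. Whether that matters depends on how large the sampled inputs are.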

Please paste any relevant traceback/logs produced by the example provided.

Ax Version

1.0.0

Python Version

3.13.2

Operating System

Ubuntu

(Optional) Describe any potential fixes you've considered to the issue outlined above.

No response

Pull Request

None

Code of Conduct

  • I agree to follow Ax's Code of Conduct

Labels

bug: Something isn't working
