Closed
Labels: bug
Description
🐛 Bug
When I use BoTorch's `SingleTaskGP` model without specifying an `outcome_transform`, the GP model in Google Colab appears to standardize the training targets automatically. In Colab, `gp.train_targets` differs from `Y`. On macOS, the GP model behaves as expected and `gp.train_targets` matches `Y`.
This issue might be related to Issue #2533, although that issue concerns `input_transform` rather than `outcome_transform`.
To reproduce
```python
import torch
import matplotlib.pyplot as plt
from botorch.models import SingleTaskGP
from botorch.fit import fit_gpytorch_mll
from gpytorch.mlls import ExactMarginalLogLikelihood

torch.manual_seed(124)
dim = 1
train_X = torch.rand(100, dim, dtype=torch.double) * 2
Y = 1 - torch.linalg.norm(train_X - 0.5, dim=-1, keepdim=True)
train_Yvar = torch.full_like(Y, 1e-4)  # known observation noise variance

# No outcome transform is passed explicitly
gp = SingleTaskGP(train_X=train_X, train_Y=Y, train_Yvar=train_Yvar)
mll = ExactMarginalLogLikelihood(gp.likelihood, gp)
fit_gpytorch_mll(mll)

print(gp.train_inputs[0][:5])
print(train_X[:5])
print(gp.train_targets[:5])
print(Y[:5])

# Generate test points for evaluation
test_x = torch.linspace(0, 2, 100, dtype=torch.double).unsqueeze(-1)

# GP predictions at the test points (computed once and reused for the plot)
with torch.no_grad():
    observed_pred = gp.likelihood(gp(test_x))
    lower, upper = observed_pred.confidence_region()

# A single panel for the GP surrogate model
fig, ax = plt.subplots(1, 1, figsize=(8, 5))

# Plot the GP mean, confidence region and training points
ax.plot(test_x.numpy(), observed_pred.mean.numpy(), 'b', label='GP mean')
ax.fill_between(test_x.numpy().flatten(), lower.numpy().flatten(), upper.numpy().flatten(), alpha=0.2, color='blue')
ax.plot(train_X.numpy(), Y.numpy(), 'ro', label='Training points')

# Labels and legend
ax.set_ylabel('y')
ax.set_title('Gaussian Process Surrogate Model')
ax.legend()
plt.show()
```

Stack trace/error message (observed output on Colab)
```
tensor([[0.4418],
        [1.4512],
        [0.1962],
        [0.9559],
        [1.7554]], dtype=torch.float64)
tensor([[0.4418],
        [1.4512],
        [0.1962],
        [0.9559],
        [1.7554]], dtype=torch.float64)
tensor([ 1.2871, -0.5817, 0.7731, 0.4548, -1.2182], dtype=torch.float64)
tensor([[ 0.9418],
        [ 0.0488],
        [ 0.6962],
        [ 0.5441],
        [-0.2554]], dtype=torch.float64)
```
Expected Behavior
```
tensor([[0.4418],
        [1.4512],
        [0.1962],
        [0.9559],
        [1.7554]], dtype=torch.float64)
tensor([[0.4418],
        [1.4512],
        [0.1962],
        [0.9559],
        [1.7554]], dtype=torch.float64)
tensor([ 0.9418, 0.0488, 0.6962, 0.5441, -0.2554], dtype=torch.float64)
tensor([[ 0.9418],
        [ 0.0488],
        [ 0.6962],
        [ 0.5441],
        [-0.2554]], dtype=torch.float64)
```
System information
- BoTorch Version (run `print(botorch.__version__)`):
  - Colab: 0.12.0
  - MacOS: 0.12.0 (though `botorch.__version__` returns `Unknown`)
- GPyTorch Version (run `print(gpytorch.__version__)`):
  - Colab: 1.13
  - MacOS: 1.13
- PyTorch Version (run `print(torch.__version__)`):
  - Colab: 2.4.1+cu121
  - MacOS: 2.4.1
- Python Version:
  - Colab: 3.10.12
  - MacOS: 3.11.9
- Computer OS:
  - Colab: Linux 6.1.85+
  - MacOS: Darwin 24.0.0
Thanks for looking into this!