5 changes: 3 additions & 2 deletions botorch/models/gpytorch.py
@@ -244,6 +244,7 @@ def condition_on_observations(
         """
         # pass the transformed data to get_fantasy_model below
         # (unless we've already transformed if BatchedMultiOutputGPyTorchModel)
+        X_original = X.clone()
         X = self.transform_inputs(X)
 
         Yvar = noise
@@ -270,9 +271,9 @@
         if hasattr(fantasy_model, "input_transform"):
             # Broadcast tensors to compatible shape before concatenating
             expand_shape = torch.broadcast_shapes(
-                X.shape[:-2], fantasy_model._original_train_inputs.shape[:-2]
+                X_original.shape[:-2], fantasy_model._original_train_inputs.shape[:-2]
             )
-            X_expanded = X.expand(expand_shape + X.shape[-2:])
+            X_expanded = X_original.expand(expand_shape + X_original.shape[-2:])
             orig_expanded = fantasy_model._original_train_inputs.expand(
                 expand_shape + fantasy_model._original_train_inputs.shape[-2:]
             )
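Why the untransformed inputs need to be kept around: when the model's input transform changes the feature dimension (e.g. one-hot encoding a categorical column, as in the new test below), the transformed `X` no longer has the same trailing size as `fantasy_model._original_train_inputs`, so expanding and concatenating the transformed tensor against the raw training inputs breaks. Below is a minimal standalone sketch of that shape mismatch using plain torch to mirror the test data; it is an illustration, not the BoTorch code path itself.

import torch
from torch.nn.functional import one_hot

# Two continuous features plus one categorical column taking values in {0, 1, 2}.
X = torch.cat([torch.rand(5, 2), torch.randint(0, 3, (5, 1)).to(torch.float)], dim=-1)

# One-hot encode the categorical column, as a dimension-changing input transform would.
X_transformed = torch.cat(
    [X[..., :2], one_hot(X[..., 2].long(), num_classes=3).to(X)], dim=-1
)

print(X.shape)              # torch.Size([5, 3]) -- matches _original_train_inputs
print(X_transformed.shape)  # torch.Size([5, 5]) -- cannot be stacked with the raw inputs

Cloning `X` before `transform_inputs` and expanding `X_original` instead keeps both tensors in the raw feature space, which is what `_original_train_inputs` stores.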
46 changes: 45 additions & 1 deletion test/models/test_gpytorch.py
@@ -6,6 +6,7 @@
 
 import itertools
 import warnings
+from functools import partial
 
 import torch
 from botorch.acquisition.objective import ScalarizedPosteriorTransform
@@ -24,7 +25,11 @@
 from botorch.models.model import FantasizeMixin
 from botorch.models.multitask import MultiTaskGP
 from botorch.models.transforms import Standardize
-from botorch.models.transforms.input import ChainedInputTransform, InputTransform
+from botorch.models.transforms.input import (
+    ChainedInputTransform,
+    InputTransform,
+    NumericToCategoricalEncoding,
+)
 from botorch.models.utils import fantasize
 from botorch.posteriors.gpytorch import GPyTorchPosterior
 from botorch.sampling.normal import SobolQMCNormalSampler
@@ -39,6 +44,8 @@
 from gpytorch.settings import trace_mode
 from torch import Tensor
 
+from torch.nn.functional import one_hot
+
 
 class SimpleInputTransform(InputTransform, torch.nn.Module):
     def __init__(self, transform_on_train: bool) -> None:
@@ -691,6 +698,43 @@ def test_condition_on_observations_model_list(self):
                 X=torch.rand(2, 1, **tkwargs), Y=torch.rand(2, 2, **tkwargs)
             )
 
+    def test_condition_on_observations_input_transform_shape_manipulation(self):
+        for dtype in (torch.float, torch.double):
+            tkwargs = {"device": self.device, "dtype": dtype}
+
+            # Create data
+            X = torch.rand(12, 2, **tkwargs) * 2
+            Y = 1 - (X - 0.5).norm(dim=-1, keepdim=True)
+            Y += 0.1 * torch.rand_like(Y)
+            # Add a categorical feature
+            new_col = torch.randint(0, 3, (X.shape[0], 1), **tkwargs)
+            X = torch.cat([X, new_col], dim=1)
+
+            train_X = X[:10]
+            train_Y = Y[:10]
+
+            condition_X = X[10:]
+            condition_Y = Y[10:]
+
+            # setup transform and model
+            input_transform = NumericToCategoricalEncoding(
+                dim=3,
+                categorical_features={2: 3},
+                encoders={2: partial(one_hot, num_classes=3)},
+            )
+
+            model = SimpleGPyTorchModel(
+                train_X, train_Y, input_transform=input_transform
+            )
+            model.eval()
+            _ = model.posterior(train_X)
+
+            conditioned_model = model.condition_on_observations(
+                condition_X, condition_Y
+            )
+            self.assertAllClose(conditioned_model._original_train_inputs, X)
+            self.assertAllClose(conditioned_model.train_inputs[0], input_transform(X))
+
     def test_condition_on_observations_input_transform_consistency(self):
         """Test that input transforms are applied consistently in
         condition_on_observations.