Skip to content

Commit 11eb671

Browse files
James Wilsonfacebook-github-bot
James Wilson
authored andcommitted
fit_gyptorch_model refactor (#1134)
Summary: X-link: facebook/Ax#1134 Pull Request resolved: #1371 This commit updates `fit_gpytorch_model` and related methods, with the aim of fixing existing issues and improving extensibility. Key changes are as follow: - Replace `fit_gpytorch_model` with `fit_gpytorch_mll`, a `Dispatcher` backed reimplementation of the original model fitting pipeline. Note that `fit_gpytorch_mll` does **not** pass `kwargs` to `optimizer` and instead introduces an optional `optimizer_kwargs` argument. - Convert `fit_gpytorch_model` into a convenience method for calling `fit_gpytorch_mll` with (limited) support for legacy API. - Add validation for multioutput GP fitting routines based on decomposing a single model into a list of independent models. - Updated unit tests for relevant code paths. Reviewed By: Balandat Differential Revision: D38692173 fbshipit-source-id: 828cff264715cfa84ca4c4361db434574cf8fbf5
1 parent 970664c commit 11eb671

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

53 files changed

+1891
-738
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ For more details see our [Documentation](https://botorch.org/docs/introduction)
126126
```python
127127
import torch
128128
from botorch.models import SingleTaskGP
129-
from botorch.fit import fit_gpytorch_model
129+
from botorch.fit import fit_gpytorch_mll
130130
from gpytorch.mlls import ExactMarginalLogLikelihood
131131

132132
train_X = torch.rand(10, 2)
@@ -136,7 +136,7 @@ For more details see our [Documentation](https://botorch.org/docs/introduction)
136136

137137
gp = SingleTaskGP(train_X, train_Y)
138138
mll = ExactMarginalLogLikelihood(gp.likelihood, gp)
139-
fit_gpytorch_model(mll)
139+
fit_gpytorch_mll(mll)
140140
```
141141

142142
2. Construct an acquisition function

botorch/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,11 @@
1414
test_functions,
1515
)
1616
from botorch.cross_validation import batch_cross_validation
17-
from botorch.fit import fit_fully_bayesian_model_nuts, fit_gpytorch_model
17+
from botorch.fit import (
18+
fit_fully_bayesian_model_nuts,
19+
fit_gpytorch_mll,
20+
fit_gpytorch_model,
21+
)
1822
from botorch.generation.gen import (
1923
gen_candidates_scipy,
2024
gen_candidates_torch,
@@ -34,6 +38,7 @@
3438
"batch_cross_validation",
3539
"exceptions",
3640
"fit_fully_bayesian_model_nuts",
41+
"fit_gpytorch_mll",
3742
"fit_gpytorch_model",
3843
"gen_candidates_scipy",
3944
"gen_candidates_torch",

botorch/cross_validation.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from typing import Any, Dict, NamedTuple, Optional, Type
1414

1515
import torch
16-
from botorch.fit import fit_gpytorch_model
16+
from botorch.fit import fit_gpytorch_mll
1717
from botorch.models.gpytorch import GPyTorchModel
1818
from botorch.optim.utils import _filter_kwargs
1919
from botorch.posteriors.gpytorch import GPyTorchPosterior
@@ -119,7 +119,7 @@ def batch_cross_validation(
119119
internally. Note: Multi-task GPs are not currently supported.
120120
mll_cls: A MarginalLogLikelihood class.
121121
cv_folds: A CVFolds tuple.
122-
fit_args: Arguments passed along to fit_gpytorch_model
122+
fit_args: Arguments passed along to fit_gpytorch_mll.
123123
124124
Returns:
125125
A CVResults tuple with the following fields
@@ -153,7 +153,7 @@ def batch_cross_validation(
153153
model_cv = model_cls(**_filter_kwargs(model_cls, **kwargs))
154154
mll_cv = mll_cls(model_cv.likelihood, model_cv)
155155
mll_cv.to(cv_folds.train_X)
156-
mll_cv = fit_gpytorch_model(mll_cv, **fit_args)
156+
mll_cv = fit_gpytorch_mll(mll_cv, **fit_args)
157157

158158
# Evaluate on the hold-out set in batch mode
159159
with torch.no_grad():

botorch/exceptions/errors.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,3 +37,9 @@ class BotorchTensorDimensionError(BotorchError):
3737
r"""Exception raised when a tensor violates a botorch convention."""
3838

3939
pass
40+
41+
42+
class ModelFittingError(Exception):
43+
r"""Exception raised when attempts to fit a model terminate unsuccessfully."""
44+
45+
pass

0 commit comments

Comments
 (0)