
Commit 8f9004c

esantorella authored and facebook-github-bot committed
Bring getting-started documentation in line with best practices (#2425)
Summary:

## Motivation

Getting-started documentation should be as simple as possible but no simpler. See #2421 for context.

### Have you read the [Contributing Guidelines on pull requests](https://github.com/pytorch/botorch/blob/main/CONTRIBUTING.md#pull-requests)?

Yes

Pull Request resolved: #2425

Test Plan:
- ran the example code to ensure it works
- built and ran the website locally and checked it looked okay
- previewed Markdown in Markdown viewer

Reviewed By: Balandat

Differential Revision: D59606022

Pulled By: esantorella

fbshipit-source-id: 201361879feb26cce309e3c49e8c7a497fabce5e
1 parent 0455dc3 commit 8f9004c

3 files changed (+44, -30 lines)

README.md (12 additions & 7 deletions)

````diff
@@ -150,35 +150,40 @@ For more details see our [Documentation](https://botorch.org/docs/introduction)
 ```python
 import torch
 from botorch.models import SingleTaskGP
+from botorch.models.transforms import Normalize, Standardize
 from botorch.fit import fit_gpytorch_mll
 from gpytorch.mlls import ExactMarginalLogLikelihood

 # Double precision is highly recommended for GPs.
 # See https://github.com/pytorch/botorch/discussions/1444
-train_X = torch.rand(10, 2, dtype=torch.double)
+train_X = torch.rand(10, 2, dtype=torch.double) * 2
 Y = 1 - (train_X - 0.5).norm(dim=-1, keepdim=True) # explicit output dimension
 Y += 0.1 * torch.rand_like(Y)
-train_Y = (Y - Y.mean()) / Y.std()

-gp = SingleTaskGP(train_X, train_Y)
+gp = SingleTaskGP(
+    train_X=train_X,
+    train_Y=Y,
+    input_transform=Normalize(d=2),
+    outcome_transform=Standardize(m=1),
+)
 mll = ExactMarginalLogLikelihood(gp.likelihood, gp)
 fit_gpytorch_mll(mll)
 ```

 2. Construct an acquisition function
 ```python
-from botorch.acquisition import UpperConfidenceBound
+from botorch.acquisition import LogExpectedImprovement

-UCB = UpperConfidenceBound(gp, beta=0.1)
+logNEI = LogExpectedImprovement(model=gp, best_f=Y.max())
 ```

 3. Optimize the acquisition function
 ```python
 from botorch.optim import optimize_acqf

-bounds = torch.stack([torch.zeros(2), torch.ones(2)])
+bounds = torch.stack([torch.zeros(2), torch.ones(2)]).to(torch.double)
 candidate, acq_value = optimize_acqf(
-    UCB, bounds=bounds, q=1, num_restarts=5, raw_samples=20,
+    logNEI, bounds=bounds, q=1, num_restarts=5, raw_samples=20,
 )
 ```
````
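Assembled for reference from the post-change lines above, here is a single runnable sketch of the updated README quickstart; the only difference relative to the hunks is that the acquisition and optimization imports are gathered at the top.

```python
import torch
from botorch.models import SingleTaskGP
from botorch.models.transforms import Normalize, Standardize
from botorch.fit import fit_gpytorch_mll
from botorch.acquisition import LogExpectedImprovement
from botorch.optim import optimize_acqf
from gpytorch.mlls import ExactMarginalLogLikelihood

# Double precision is highly recommended for GPs.
# See https://github.com/pytorch/botorch/discussions/1444
train_X = torch.rand(10, 2, dtype=torch.double) * 2
Y = 1 - (train_X - 0.5).norm(dim=-1, keepdim=True)  # explicit output dimension
Y += 0.1 * torch.rand_like(Y)

# 1. Fit a GP model, letting the transforms handle input/outcome scaling.
gp = SingleTaskGP(
    train_X=train_X,
    train_Y=Y,
    input_transform=Normalize(d=2),
    outcome_transform=Standardize(m=1),
)
mll = ExactMarginalLogLikelihood(gp.likelihood, gp)
fit_gpytorch_mll(mll)

# 2. Construct an acquisition function.
logNEI = LogExpectedImprovement(model=gp, best_f=Y.max())

# 3. Optimize the acquisition function over the bounds box.
bounds = torch.stack([torch.zeros(2), torch.ones(2)]).to(torch.double)
candidate, acq_value = optimize_acqf(
    logNEI, bounds=bounds, q=1, num_restarts=5, raw_samples=20,
)
```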

docs/getting_started.md (12 additions & 7 deletions)

````diff
@@ -40,34 +40,39 @@ Here's a quick run down of the main components of a Bayesian Optimization loop.
 ```python
 import torch
 from botorch.models import SingleTaskGP
+from botorch.models.transforms import Normalize, Standardize
 from botorch.fit import fit_gpytorch_mll
 from gpytorch.mlls import ExactMarginalLogLikelihood
-from botorch.models.transforms.outcome import Standardize

-train_X = torch.rand(10, 2, dtype=torch.float64)
+train_X = torch.rand(10, 2, dtype=torch.double) * 2
 # explicit output dimension -- Y is 10 x 1
 train_Y = 1 - (train_X - 0.5).norm(dim=-1, keepdim=True)
 train_Y += 0.1 * torch.rand_like(train_Y)

-gp = SingleTaskGP(train_X, train_Y, outcome_transform=Standardize(m=1))
+gp = SingleTaskGP(
+    train_X=train_X,
+    train_Y=Y,
+    input_transform=Normalize(d=2),
+    outcome_transform=Standardize(m=1),
+)
 mll = ExactMarginalLogLikelihood(gp.likelihood, gp)
 fit_gpytorch_mll(mll)
 ```

 2. Construct an acquisition function
 ```python
-from botorch.acquisition import UpperConfidenceBound
+from botorch.acquisition import LogExpectedImprovement

-UCB = UpperConfidenceBound(gp, beta=0.1)
+logNEI = LogExpectedImprovement(model=gp, best_f=Y.max())
 ```

 3. Optimize the acquisition function
 ```python
 from botorch.optim import optimize_acqf

-bounds = torch.stack([torch.zeros(2), torch.ones(2)])
+bounds = torch.stack([torch.zeros(2), torch.ones(2)]).to(torch.double)
 candidate, acq_value = optimize_acqf(
-    UCB, bounds=bounds, q=1, num_restarts=5, raw_samples=20,
+    logNEI, bounds=bounds, q=1, num_restarts=5, raw_samples=20,
 )
 ```
````
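The page frames the hunk above as one pass through the Bayesian optimization loop. Purely as an illustration (not part of this diff), here is a minimal sketch of closing that loop, continuing from the snippet above with its imports in scope and assuming the training tensors keep the names `train_X` and `train_Y` from the unchanged lines:

```python
# Hypothetical continuation: score the suggested candidate with the same toy
# objective used above, append it to the training data, and refit the GP.
new_Y = 1 - (candidate - 0.5).norm(dim=-1, keepdim=True)
new_Y += 0.1 * torch.rand_like(new_Y)

train_X = torch.cat([train_X, candidate])
train_Y = torch.cat([train_Y, new_Y])

gp = SingleTaskGP(
    train_X=train_X,
    train_Y=train_Y,
    input_transform=Normalize(d=2),
    outcome_transform=Standardize(m=1),
)
mll = ExactMarginalLogLikelihood(gp.likelihood, gp)
fit_gpytorch_mll(mll)
```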

website/pages/en/index.js (20 additions & 16 deletions)

````diff
@@ -19,8 +19,8 @@ const bash = (...args) => `~~~bash\n${String.raw(...args)}\n~~~`;

 class HomeSplash extends React.Component {
   render() {
-    const {siteConfig, language = ''} = this.props;
-    const {baseUrl, docsUrl} = siteConfig;
+    const { siteConfig, language = '' } = this.props;
+    const { baseUrl, docsUrl } = siteConfig;
     const docsPart = `${docsUrl ? `${docsUrl}/` : ''}`;
     const langPart = `${language ? `${language}/` : ''}`;
     const docUrl = doc => `${baseUrl}${docsPart}${langPart}${doc}`;
@@ -79,8 +79,8 @@ class HomeSplash extends React.Component {

 class Index extends React.Component {
   render() {
-    const {config: siteConfig, language = ''} = this.props;
-    const {baseUrl} = siteConfig;
+    const { config: siteConfig, language = '' } = this.props;
+    const { baseUrl } = siteConfig;

     const Block = props => (
       <Container
@@ -114,34 +114,38 @@ class Index extends React.Component {
     const modelFitCodeExample = `${pre}python
 import torch
 from botorch.models import SingleTaskGP
+from botorch.models.transforms import Normalize, Standardize
 from botorch.fit import fit_gpytorch_mll
-from botorch.utils import standardize
 from gpytorch.mlls import ExactMarginalLogLikelihood

-train_X = torch.rand(10, 2, dtype=torch.double)
+train_X = torch.rand(10, 2, dtype=torch.double) * 2
 Y = 1 - torch.linalg.norm(train_X - 0.5, dim=-1, keepdim=True)
 Y = Y + 0.1 * torch.randn_like(Y) # add some noise
-train_Y = standardize(Y)

-gp = SingleTaskGP(train_X, train_Y)
+gp = SingleTaskGP(
+    train_X=train_X,
+    train_Y=Y,
+    input_transform=Normalize(d=2),
+    outcome_transform=Standardize(m=1),
+)
 mll = ExactMarginalLogLikelihood(gp.likelihood, gp)
 fit_gpytorch_mll(mll)
 `;
     // Example for defining an acquisition function
     const constrAcqFuncExample = `${pre}python
-from botorch.acquisition import UpperConfidenceBound
+from botorch.acquisition import LogExpectedImprovement

-UCB = UpperConfidenceBound(gp, beta=0.1)
+logNEI = LogExpectedImprovement(model=gp, best_f=Y.max())
 `;
     // Example for optimizing candidates
     const optAcqFuncExample = `${pre}python
 from botorch.optim import optimize_acqf

-bounds = torch.stack([torch.zeros(2), torch.ones(2)])
+bounds = torch.stack([torch.zeros(2), torch.ones(2)]).to(torch.double)
 candidate, acq_value = optimize_acqf(
-    UCB, bounds=bounds, q=1, num_restarts=5, raw_samples=20,
+    logNEI, bounds=bounds, q=1, num_restarts=5, raw_samples=20,
 )
-candidate # tensor([0.4887, 0.5063])
+candidate # tensor([[0.2981, 0.2401]], dtype=torch.float64)
 `;
     const papertitle = `BoTorch: A Framework for Efficient Monte-Carlo Bayesian Optimization`
     const paper_bibtex = `${pre}plaintext
@@ -158,7 +162,7 @@ candidate # tensor([0.4887, 0.5063])
       <div
         className="productShowcaseSection"
         id="quickstart"
-        style={{textAlign: 'center'}}>
+        style={{ textAlign: 'center' }}>
         <h2>Get Started</h2>
         <Container>
           <ol>
@@ -187,7 +191,7 @@ candidate # tensor([0.4887, 0.5063])
     );

     const Features = () => (
-      <div className="productShowcaseSection" style={{textAlign: 'center'}}>
+      <div className="productShowcaseSection" style={{ textAlign: 'center' }}>
        <h2>Key Features</h2>
        <Block layout="threeColumn">
          {[
@@ -221,7 +225,7 @@ candidate # tensor([0.4887, 0.5063])
       <div
         className="productShowcaseSection"
         id="reference"
-        style={{textAlign: 'center'}}>
+        style={{ textAlign: 'center' }}>
         <h2>References</h2>
         <Container>
           <a href={`https://arxiv.org/abs/1910.06403`}>{papertitle}</a>
````
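The substantive change repeated in every file is the switch from UpperConfidenceBound to LogExpectedImprovement. For readers comparing the two constructors, here is a small illustrative sketch (not part of the diff); it assumes `torch`, the fitted `gp`, and the observations `Y` from the model-fitting example are in scope, and the names `test_X`, `ucb_values`, and `logei_values` are illustrative only.

```python
from botorch.acquisition import LogExpectedImprovement, UpperConfidenceBound

# UCB trades off posterior mean and uncertainty through an explicit beta.
UCB = UpperConfidenceBound(gp, beta=0.1)

# LogEI instead conditions on the best objective value observed so far.
logEI = LogExpectedImprovement(model=gp, best_f=Y.max())

# Both are analytic acquisition functions evaluated on `batch x q=1 x d` inputs.
test_X = torch.rand(5, 1, 2, dtype=torch.double) * 2
ucb_values = UCB(test_X)      # shape: (5,)
logei_values = logEI(test_X)  # shape: (5,)
```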
