infinite width bnn kernel #2366

Closed
wants to merge 1 commit into from

2 changes: 2 additions & 0 deletions botorch/models/kernels/__init__.py
@@ -7,6 +7,7 @@
from botorch.models.kernels.categorical import CategoricalKernel
from botorch.models.kernels.downsampling import DownsamplingKernel
from botorch.models.kernels.exponential_decay import ExponentialDecayKernel
from botorch.models.kernels.infinite_width_bnn import InfiniteWidthBNNKernel
from botorch.models.kernels.linear_truncated_fidelity import (
LinearTruncatedFidelityKernel,
)
@@ -16,5 +17,6 @@
"CategoricalKernel",
"DownsamplingKernel",
"ExponentialDecayKernel",
"InfiniteWidthBNNKernel",
"LinearTruncatedFidelityKernel",
]
180 changes: 180 additions & 0 deletions botorch/models/kernels/infinite_width_bnn.py
@@ -0,0 +1,180 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from typing import Optional, Tuple

import torch
from gpytorch.constraints import Positive
from gpytorch.kernels import Kernel
from torch import Tensor


class InfiniteWidthBNNKernel(Kernel):
r"""Infinite-width BNN kernel.

Defines the GP kernel which is equivalent to performing exact Bayesian
inference on a fully-connected deep neural network with ReLU activations
and i.i.d. priors in the infinite-width limit.
See [Cho2009kernel]_ and [Lee2018deep]_ for details.

.. [Cho2009kernel]
Y. Cho, and L. Saul. Kernel methods for deep learning.
Advances in Neural Information Processing Systems 22. 2009.
.. [Lee2018deep]
J. Lee, Y. Bahri, R. Novak, S. Schoenholz, J. Pennington, and J. Sohl-Dickstein.
Deep Neural Networks as Gaussian Processes.
International Conference on Learning Representations. 2018.
"""

has_lengthscale = False

def __init__(
self,
depth: int = 3,
batch_shape: Optional[torch.Size] = None,
active_dims: Optional[Tuple[int, ...]] = None,
acos_eps: float = 1e-7,
device: Optional[torch.device] = None,
) -> None:
r"""
Args:
depth: Depth of neural network.
batch_shape: This will set a separate weight/bias var for each batch.
It should be :math:`B_1 \times \ldots \times B_k` if :math:`\mathbf{x}` is
a :math:`B_1 \times \ldots \times B_k \times N \times D` tensor.
active_dims: Compute the covariance of only a few input dimensions.
The ints correspond to the indices of the dimensions.
acos_eps: A small positive value to restrict acos inputs to
:math:`[-1 + \epsilon, 1 - \epsilon]`.
device: Device for parameters.
"""
super().__init__(batch_shape=batch_shape, active_dims=active_dims)
self.depth = depth
self.acos_eps = acos_eps

self.register_parameter(
"raw_weight_var",
torch.nn.Parameter(torch.zeros(*self.batch_shape, 1, 1, device=device)),
)
self.register_constraint("raw_weight_var", Positive())

self.register_parameter(
"raw_bias_var",
torch.nn.Parameter(torch.zeros(*self.batch_shape, 1, 1, device=device)),
)
self.register_constraint("raw_bias_var", Positive())

@property
def weight_var(self) -> Tensor:
return self.raw_weight_var_constraint.transform(self.raw_weight_var)

@weight_var.setter
def weight_var(self, value) -> None:
if not torch.is_tensor(value):
value = torch.as_tensor(value).to(self.raw_weight_var)
self.initialize(
raw_weight_var=self.raw_weight_var_constraint.inverse_transform(value)
)

@property
def bias_var(self) -> Tensor:
return self.raw_bias_var_constraint.transform(self.raw_bias_var)

@bias_var.setter
def bias_var(self, value) -> None:
if not torch.is_tensor(value):
value = torch.as_tensor(value).to(self.raw_bias_var)
self.initialize(
raw_bias_var=self.raw_bias_var_constraint.inverse_transform(value)
)

def _initialize_var(self, x: Tensor) -> Tensor:
"""Computes the initial variance of x for layer 0"""
return (
self.weight_var * torch.sum(x * x, dim=-1, keepdim=True) / x.shape[-1]
+ self.bias_var
)

def _update_var(self, K: Tensor, x: Tensor) -> Tensor:
"""Computes the updated variance of x for next layer"""
return self.weight_var * K / 2 + self.bias_var

def k(self, x1: Tensor, x2: Tensor) -> Tensor:
r"""
For single-layer infinite-width neural networks with i.i.d. priors,
the covariance between outputs can be computed by
:math:`K^0(x, x')=\sigma_b^2+\sigma_w^2\frac{x \cdot x'}{d_\text{input}}`.

For deeper networks, we can recursively define the covariance as
:math:`K^l(x, x')=\sigma_b^2+\sigma_w^2
F_\phi(K^{l-1}(x, x'), K^{l-1}(x, x), K^{l-1}(x', x'))`
where :math:`F_\phi` is a deterministic function based on the
activation function :math:`\phi`.

For ReLU activations, this yields the arc-cosine kernel, which can be computed
analytically.

Args:
x1: `batch_shape x n1 x d`-dim Tensor
x2: `batch_shape x n2 x d`-dim Tensor
"""
K_12 = (
self.weight_var * (x1.matmul(x2.transpose(-2, -1)) / x1.shape[-1])
+ self.bias_var
)

for layer in range(self.depth):
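# K_11 / K_22 track the per-input variances and K_12 the cross-covariance;
# each iteration applies the ReLU (arc-cosine) layer update.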
if layer == 0:
K_11 = self._initialize_var(x1)
K_22 = self._initialize_var(x2)
else:
K_11 = self._update_var(K_11, x1)
K_22 = self._update_var(K_22, x2)

sqrt_term = torch.sqrt(K_11.matmul(K_22.transpose(-2, -1)))

fraction = K_12 / sqrt_term
fraction = torch.clamp(
fraction, min=-1 + self.acos_eps, max=1 - self.acos_eps
)

theta = torch.acos(fraction)
theta_term = torch.sin(theta) + (torch.pi - theta) * fraction

K_12 = (
self.weight_var / (2 * torch.pi) * sqrt_term * theta_term
+ self.bias_var
)

return K_12

def forward(
self,
x1: Tensor,
x2: Tensor,
diag: bool = False,
last_dim_is_batch: bool = False,
**params,
) -> Tensor:
"""
Args:
x1: `batch_shape x n1 x d`-dim Tensor
x2: `batch_shape x n2 x d`-dim Tensor
diag: If True, only returns the diagonal of the kernel matrix.
last_dim_is_batch: Not supported by this kernel.
"""
if last_dim_is_batch:
raise RuntimeError("last_dim_is_batch not supported by this kernel.")

if diag:
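# Only per-point variances are needed, so propagate the diagonal through each layer.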
K = self._initialize_var(x1)
for _ in range(self.depth):
K = self._update_var(K, x1)
return K.squeeze(-1)
else:
return self.k(x1, x2)
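
For reference, a minimal usage sketch (not part of this diff; it only assumes the standard GPyTorch kernel calling convention and the initialize/to_dense calls exercised in the tests below):

import torch
from botorch.models.kernels.infinite_width_bnn import InfiniteWidthBNNKernel

# Depth-3 infinite-width BNN kernel with explicit prior variances.
kernel = InfiniteWidthBNNKernel(depth=3)
kernel.initialize(weight_var=1.0, bias_var=0.1)
kernel.eval()

x = torch.rand(5, 2)
K = kernel(x, x).to_dense()        # full 5 x 5 covariance matrix
K_diag = kernel(x, x, diag=True)   # 5-dim vector of per-point variances
# The diag code path should agree with the dense diagonal up to acos_eps clamping.
assert torch.allclose(torch.diagonal(K), K_diag, atol=1e-4)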
5 changes: 4 additions & 1 deletion sphinx/source/models.rst
@@ -114,6 +114,9 @@ Kernels
.. automodule:: botorch.models.kernels.exponential_decay
.. autoclass:: ExponentialDecayKernel

.. automodule:: botorch.models.kernels.infinite_width_bnn
.. autoclass:: InfiniteWidthBNNKernel

.. automodule:: botorch.models.kernels.linear_truncated_fidelity
.. autoclass:: LinearTruncatedFidelityKernel

@@ -177,4 +180,4 @@ Inducing Point Allocators
Other Utilities
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: botorch.models.utils.assorted
:members:
:members:
171 changes: 171 additions & 0 deletions test/models/kernels/test_infinite_width_bnn.py
@@ -0,0 +1,171 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
from botorch.models.kernels.infinite_width_bnn import InfiniteWidthBNNKernel
from botorch.utils.testing import BotorchTestCase
from gpytorch.test.base_kernel_test_case import BaseKernelTestCase


class TestInfiniteWidthBNNKernel(BotorchTestCase, BaseKernelTestCase):
def create_kernel_no_ard(self, **kwargs):
return InfiniteWidthBNNKernel(**kwargs)

def test_properties(self):
with self.subTest():
kernel = InfiniteWidthBNNKernel(3)
bias_var_init = torch.tensor(0.2)
kernel.initialize(bias_var=bias_var_init)
actual_value = bias_var_init.view_as(kernel.bias_var)
self.assertLess(torch.linalg.norm(kernel.bias_var - actual_value), 1e-5)
with self.subTest():
kernel = InfiniteWidthBNNKernel(3)
weight_var_init = torch.tensor(0.2)
kernel.initialize(weight_var=weight_var_init)
actual_value = weight_var_init.view_as(kernel.weight_var)
self.assertLess(torch.linalg.norm(kernel.weight_var - actual_value), 1e-5)
with self.subTest():
kernel = InfiniteWidthBNNKernel(5, batch_shape=torch.Size([2]))
bias_var_init = torch.tensor([0.2, 0.01])
kernel.initialize(bias_var=bias_var_init)
actual_value = bias_var_init.view_as(kernel.bias_var)
self.assertLess(torch.linalg.norm(kernel.bias_var - actual_value), 1e-5)
with self.subTest():
kernel = InfiniteWidthBNNKernel(3, batch_shape=torch.Size([2]))
weight_var_init = torch.tensor([1.0, 2.0])
kernel.initialize(weight_var=weight_var_init)
actual_value = weight_var_init.view_as(kernel.weight_var)
self.assertLess(torch.linalg.norm(kernel.weight_var - actual_value), 1e-5)
with self.subTest():
kernel = InfiniteWidthBNNKernel(3, batch_shape=torch.Size([2]))
x = torch.randn(3, 2)
with self.assertRaises(RuntimeError):
kernel(x, x, last_dim_is_batch=True).to_dense()

def test_forward_0(self):
for dtype in (torch.float, torch.double):
tkwargs = {"device": self.device, "dtype": dtype}
x1 = torch.tensor([[0.1, 0.2], [1.2, 0.4], [2.4, 0.3]]).to(**tkwargs)
x2 = torch.tensor([[4.1, 2.3], [3.9, 0.0]]).to(**tkwargs)
weight_var = 1.0
bias_var = 0.1
kernel = InfiniteWidthBNNKernel(0, device=self.device).initialize(
weight_var=weight_var, bias_var=bias_var
)
kernel.eval()
expected = (
weight_var * (x1.matmul(x2.transpose(-2, -1)) / x1.shape[-1]) + bias_var
).to(**tkwargs)
res = kernel(x1, x2).to_dense()
self.assertAllClose(res, expected)

def test_forward_0_batch(self):
for dtype in (torch.float, torch.double):
tkwargs = {"device": self.device, "dtype": dtype}
x1 = torch.tensor(
[
[
[0.4960, 0.7680, 0.0880],
[0.1320, 0.3070, 0.6340],
[0.4900, 0.8960, 0.4550],
[0.6320, 0.3480, 0.4010],
[0.0220, 0.1680, 0.2930],
],
[
[0.5180, 0.6970, 0.8000],
[0.1610, 0.2820, 0.6810],
[0.9150, 0.3970, 0.8740],
[0.4190, 0.5520, 0.9520],
[0.0360, 0.1850, 0.3730],
],
]
).to(**tkwargs)
x2 = torch.tensor(
[
[[0.3050, 0.9320, 0.1750], [0.2690, 0.1500, 0.0310]],
[[0.2080, 0.9290, 0.7230], [0.7420, 0.5260, 0.2430]],
]
).to(**tkwargs)
weight_var = torch.tensor([1.0, 2.0]).to(**tkwargs)
bias_var = torch.tensor([0.1, 0.5]).to(**tkwargs)
kernel = InfiniteWidthBNNKernel(
0, batch_shape=[2], device=self.device
).initialize(weight_var=weight_var, bias_var=bias_var)
kernel.eval()
expected = torch.tensor(
[
[
[0.3942, 0.1838],
[0.2458, 0.1337],
[0.4547, 0.1934],
[0.2958, 0.1782],
[0.1715, 0.1134],
],
[
[1.3891, 1.1303],
[1.0252, 0.7889],
[1.2940, 1.2334],
[1.3588, 1.0551],
[0.7994, 0.6431],
],
]
).to(**tkwargs)
res = kernel(x1, x2).to_dense()
self.assertAllClose(res, expected, 0.0001, 0.0001)

def test_forward_2(self):
for dtype in (torch.float, torch.double):
tkwargs = {"device": self.device, "dtype": dtype}
x1 = torch.tensor(
[
[
[0.4960, 0.7680, 0.0880],
[0.1320, 0.3070, 0.6340],
[0.4900, 0.8960, 0.4550],
[0.6320, 0.3480, 0.4010],
[0.0220, 0.1680, 0.2930],
],
[
[0.5180, 0.6970, 0.8000],
[0.1610, 0.2820, 0.6810],
[0.9150, 0.3970, 0.8740],
[0.4190, 0.5520, 0.9520],
[0.0360, 0.1850, 0.3730],
],
]
).to(**tkwargs)
x2 = torch.tensor(
[
[[0.3050, 0.9320, 0.1750], [0.2690, 0.1500, 0.0310]],
[[0.2080, 0.9290, 0.7230], [0.7420, 0.5260, 0.2430]],
]
).to(**tkwargs)
weight_var = 1.0
bias_var = 0.1
kernel = InfiniteWidthBNNKernel(2, device=self.device).initialize(
weight_var=weight_var, bias_var=bias_var
)
kernel.eval()
expected = torch.tensor(
[
[
[0.2488, 0.1985],
[0.2178, 0.1872],
[0.2641, 0.2036],
[0.2286, 0.1962],
[0.1983, 0.1793],
],
[
[0.2869, 0.2564],
[0.2429, 0.2172],
[0.2820, 0.2691],
[0.2837, 0.2498],
[0.2160, 0.1986],
],
]
).to(**tkwargs)
res = kernel(x1, x2).to_dense()
self.assertAllClose(res, expected, 0.0001, 0.0001)