This repository was archived by the owner on Jun 14, 2024. It is now read-only.

Fix a bug for SVGP regression with minibatch training #148

Merged
merged 2 commits on Jan 25, 2019
5 changes: 4 additions & 1 deletion mxfusion/components/distributions/gp/kernels/__init__.py
@@ -24,14 +24,17 @@
     add_kernel
     kernel
     linear
+    matern
     rbf
     static
     stationary
 """

-__all__ = ['add_kernel', 'kernel', 'linear', 'rbf', 'static', 'stationary']
+__all__ = ['add_kernel', 'kernel', 'linear', 'matern', 'rbf', 'static',
+           'stationary']

 from .add_kernel import AddKernel
 from .rbf import RBF
 from .linear import Linear
 from .static import Bias, White
+from .matern import Matern52, Matern32, Matern12
13 changes: 6 additions & 7 deletions mxfusion/components/distributions/gp/kernels/linear.py
@@ -16,7 +16,6 @@
 from .kernel import NativeKernel
 from ....variables import Variable
 from ....variables import PositiveTransformation
-from .....util.customop import broadcast_to_w_samples


 class Linear(NativeKernel):
@@ -74,20 +73,20 @@ def _compute_K(self, F, X, variances, X2=None):
         :rtype: MXNet NDArray or MXNet Symbol
         """
         if self.ARD:
-            var_sqrt = F.sqrt(variances)
+            var_sqrt = F.expand_dims(F.sqrt(variances), axis=-2)
             if X2 is None:
-                xsc = X * broadcast_to_w_samples(F, var_sqrt, X.shape)
+                xsc = X * var_sqrt
                 return F.linalg.syrk(xsc)
             else:
-                xsc = X * broadcast_to_w_samples(F, var_sqrt, X.shape)
-                x2sc = X2 * broadcast_to_w_samples(F, var_sqrt, X2.shape)
+                xsc = X * var_sqrt
+                x2sc = X2 * var_sqrt
                 return F.linalg.gemm2(xsc, x2sc, False, True)
         else:
             if X2 is None:
                 A = F.linalg.syrk(X)
             else:
                 A = F.linalg.gemm2(X, X2, False, True)
-            return A * broadcast_to_w_samples(F, variances, A.shape)
+            return A * F.expand_dims(variances, axis=-1)

     def _compute_Kdiag(self, F, X, variances):
         """
@@ -102,7 +101,7 @@ def _compute_Kdiag(self, F, X, variances):
         :rtype: MXNet NDArray or MXNet Symbol
         """
         X2 = F.square(X)
-        return F.sum(X2 * broadcast_to_w_samples(F, variances, X2.shape),
+        return F.sum(X2 * F.expand_dims(variances, axis=-2),
                      axis=-1)

     def replicate_self(self, attribute_map=None):
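The removed `broadcast_to_w_samples` custom op is replaced by inserting a length-1 axis and relying on MXNet's built-in broadcasting. A minimal NDArray sketch of that pattern; the `(num_samples, N, input_dim)` layout for `X` and `(num_samples, input_dim)` layout for `variances` are assumptions inferred from the axis arguments above:

```python
import mxnet as mx

X = mx.nd.random.normal(shape=(3, 10, 5))   # assumed (num_samples, N, input_dim)
variances = mx.nd.ones((3, 5))              # assumed (num_samples, input_dim)

# expand_dims inserts a length-1 axis at position -2, giving (3, 1, 5),
# which then broadcasts against X without a full-size intermediate copy.
var_sqrt = mx.nd.expand_dims(mx.nd.sqrt(variances), axis=-2)
xsc = mx.nd.broadcast_mul(X, var_sqrt)      # shape (3, 10, 5)
K = mx.nd.linalg.syrk(xsc)                  # shape (3, 10, 10), as in _compute_K
```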
150 changes: 150 additions & 0 deletions mxfusion/components/distributions/gp/kernels/matern.py
@@ -0,0 +1,150 @@
+# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License").
+# You may not use this file except in compliance with the License.
+# A copy of the License is located at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# or in the "license" file accompanying this file. This file is distributed
+# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
+# express or implied. See the License for the specific language governing
+# permissions and limitations under the License.
+# ==============================================================================
+
+
+import numpy as np
+import mxnet as mx
+from .stationary import StationaryKernel
+
+
+class Matern(StationaryKernel):
+    """
+    The Matern family of kernels. For example, the Matern 5/2 kernel is
+
+    .. math::
+       k(r) = \\sigma^2 \\bigg(1 + \\sqrt{5} r + \\frac{5}{3} r^2 \\bigg) \\exp \\big(- \\sqrt{5} r \\big)
+
+    where :math:`r` is the lengthscale-scaled distance between the inputs.
+
+    :param input_dim: the number of dimensions of the kernel. (The total number of active dimensions)
+    :type input_dim: int
+    :param order: the order of the Matern kernel. The smoothness parameter is order + 1/2, so order=2, 1 and 0 give the Matern 5/2, 3/2 and 1/2 kernels respectively.
+    :type order: int
+    :param ARD: a binary switch for Automatic Relevance Determination (ARD). If true, the squared distance is divided by a lengthscale for individual
+        dimensions.
+    :type ARD: boolean
+    :param variance: the initial value for the variance parameter (scalar), which scales the whole covariance matrix.
+    :type variance: float or MXNet NDArray
+    :param lengthscale: the initial value for the lengthscale parameter.
+    :type lengthscale: float or MXNet NDArray
+    :param name: the name of the kernel. The name is used to access kernel parameters.
+    :type name: str
+    :param active_dims: The dimensions of the inputs that are taken for the covariance matrix computation. (default: None, taking all the dimensions).
+    :type active_dims: [int] or None
+    :param dtype: the data type for floating point numbers.
+    :type dtype: numpy.float32 or numpy.float64
+    :param ctx: the mxnet context (default: None/current context).
+    :type ctx: None or mxnet.cpu or mxnet.gpu
+    """
+    broadcastable = True
+
+    def __init__(self, input_dim, order, ARD=False, variance=1.,
+                 lengthscale=1., name='matern', active_dims=None, dtype=None,
+                 ctx=None):
+        super(Matern, self).__init__(
+            input_dim=input_dim, ARD=ARD, variance=variance,
+            lengthscale=lengthscale, name=name, active_dims=active_dims,
+            dtype=dtype, ctx=ctx)
+        self.order = order
+
+
+class Matern52(Matern):
+    def __init__(self, input_dim, ARD=False, variance=1., lengthscale=1.,
+                 name='matern52', active_dims=None, dtype=None, ctx=None):
+        super(Matern52, self).__init__(
+            input_dim=input_dim, order=2, ARD=ARD, variance=variance,
+            lengthscale=lengthscale, name=name, active_dims=active_dims,
+            dtype=dtype, ctx=ctx)
+
+    def _compute_K(self, F, X, lengthscale, variance, X2=None):
+        """
+        The internal interface for the actual covariance matrix computation.
+
+        :param F: MXNet computation type <mx.sym, mx.nd>.
+        :param X: the first set of inputs to the kernel.
+        :type X: MXNet NDArray or MXNet Symbol
+        :param X2: (optional) the second set of arguments to the kernel. If X2 is None, this computes a square covariance matrix of X. In other words,
+            X2 is internally treated as X.
+        :type X2: MXNet NDArray or MXNet Symbol
+        :param variance: the variance parameter (scalar), which scales the whole covariance matrix.
+        :type variance: MXNet NDArray or MXNet Symbol
+        :param lengthscale: the lengthscale parameter.
+        :type lengthscale: MXNet NDArray or MXNet Symbol
+        :return: The covariance matrix.
+        :rtype: MXNet NDArray or MXNet Symbol
+        """
+        R2 = self._compute_R2(F, X, lengthscale, variance, X2=X2)
+        R = F.sqrt(F.clip(R2, 1e-14, np.inf))
+        return F.broadcast_mul(
+            (1+np.sqrt(5)*R+5/3.*R2)*F.exp(-np.sqrt(5)*R),
+            F.expand_dims(variance, axis=-2))
+
+
+class Matern32(Matern):
+    def __init__(self, input_dim, ARD=False, variance=1., lengthscale=1.,
+                 name='matern32', active_dims=None, dtype=None, ctx=None):
+        super(Matern32, self).__init__(
+            input_dim=input_dim, order=1, ARD=ARD, variance=variance,
+            lengthscale=lengthscale, name=name, active_dims=active_dims,
+            dtype=dtype, ctx=ctx)
+
+    def _compute_K(self, F, X, lengthscale, variance, X2=None):
+        """
+        The internal interface for the actual covariance matrix computation.
+
+        :param F: MXNet computation type <mx.sym, mx.nd>.
+        :param X: the first set of inputs to the kernel.
+        :type X: MXNet NDArray or MXNet Symbol
+        :param X2: (optional) the second set of arguments to the kernel. If X2 is None, this computes a square covariance matrix of X. In other words,
+            X2 is internally treated as X.
+        :type X2: MXNet NDArray or MXNet Symbol
+        :param variance: the variance parameter (scalar), which scales the whole covariance matrix.
+        :type variance: MXNet NDArray or MXNet Symbol
+        :param lengthscale: the lengthscale parameter.
+        :type lengthscale: MXNet NDArray or MXNet Symbol
+        :return: The covariance matrix.
+        :rtype: MXNet NDArray or MXNet Symbol
+        """
+        R2 = self._compute_R2(F, X, lengthscale, variance, X2=X2)
+        R = F.sqrt(F.clip(R2, 1e-14, np.inf))
+        return F.broadcast_mul(
+            (1+np.sqrt(3)*R)*F.exp(-np.sqrt(3)*R),
+            F.expand_dims(variance, axis=-2))
+
+
+class Matern12(Matern):
+    def __init__(self, input_dim, ARD=False, variance=1., lengthscale=1.,
+                 name='matern12', active_dims=None, dtype=None, ctx=None):
+        super(Matern12, self).__init__(
+            input_dim=input_dim, order=0, ARD=ARD, variance=variance,
+            lengthscale=lengthscale, name=name, active_dims=active_dims,
+            dtype=dtype, ctx=ctx)
+
+    def _compute_K(self, F, X, lengthscale, variance, X2=None):
+        """
+        The internal interface for the actual covariance matrix computation.
+
+        :param F: MXNet computation type <mx.sym, mx.nd>.
+        :param X: the first set of inputs to the kernel.
+        :type X: MXNet NDArray or MXNet Symbol
+        :param X2: (optional) the second set of arguments to the kernel. If X2 is None, this computes a square covariance matrix of X. In other words,
+            X2 is internally treated as X.
+        :type X2: MXNet NDArray or MXNet Symbol
+        :param variance: the variance parameter (scalar), which scales the whole covariance matrix.
+        :type variance: MXNet NDArray or MXNet Symbol
+        :param lengthscale: the lengthscale parameter.
+        :type lengthscale: MXNet NDArray or MXNet Symbol
+        :return: The covariance matrix.
+        :rtype: MXNet NDArray or MXNet Symbol
+        """
+        R = F.sqrt(F.clip(self._compute_R2(F, X, lengthscale, variance, X2=X2),
+                          1e-14, np.inf))
+        return F.broadcast_mul(
+            F.exp(-R), F.expand_dims(variance, axis=-2))
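A short usage sketch of the new kernels, using only the constructor signatures introduced above:

```python
from mxfusion.components.distributions.gp.kernels import (
    Matern12, Matern32, Matern52)

# Sample-path smoothness grows with the order: Matern12 (order=0) is the
# exponential kernel, Matern52 (order=2) gives twice-differentiable paths.
k_rough = Matern12(input_dim=2)
k_mid = Matern32(input_dim=2, lengthscale=0.5)
k_smooth = Matern52(input_dim=2, ARD=True, variance=2.)
```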
3 changes: 1 addition & 2 deletions mxfusion/components/distributions/gp/kernels/rbf.py
@@ -14,7 +14,6 @@


 from .stationary import StationaryKernel
-from .....util.customop import broadcast_to_w_samples


 class RBF(StationaryKernel):
@@ -69,4 +68,4 @@ def _compute_K(self, F, X, lengthscale, variance, X2=None):
         :rtype: MXNet NDArray or MXNet Symbol
         """
         R2 = self._compute_R2(F, X, lengthscale, variance, X2=X2)
-        return F.exp(R2 / -2) * broadcast_to_w_samples(F, variance, R2.shape)
+        return F.exp(R2 / -2) * F.expand_dims(variance, axis=-1)
10 changes: 5 additions & 5 deletions mxfusion/components/distributions/gp/kernels/stationary.py
@@ -16,7 +16,6 @@
 from .kernel import NativeKernel
 from ....variables import Variable
 from ....variables import PositiveTransformation
-from .....util.customop import broadcast_to_w_samples


 class StationaryKernel(NativeKernel):
@@ -90,14 +89,14 @@ def _compute_R2(self, F, X, lengthscale, variance, X2=None):
         """
         lengthscale = F.expand_dims(lengthscale, axis=-2)
         if X2 is None:
-            xsc = X / broadcast_to_w_samples(F, lengthscale, X.shape)
+            xsc = X / lengthscale
             amat = F.linalg.syrk(xsc) * -2
             dg_a = F.sum(F.square(xsc), axis=-1)
             amat = F.broadcast_add(amat, F.expand_dims(dg_a, axis=-1))
             amat = F.broadcast_add(amat, F.expand_dims(dg_a, axis=-2))
         else:
-            x1sc = X / broadcast_to_w_samples(F, lengthscale, X.shape)
-            x2sc = X2 / broadcast_to_w_samples(F, lengthscale, X2.shape)
+            x1sc = X / lengthscale
+            x2sc = X2 / lengthscale
             amat = F.linalg.gemm2(x1sc, x2sc, False, True) * -2
             dg1 = F.sum(F.square(x1sc), axis=-1, keepdims=True)
             amat = F.broadcast_add(amat, dg1)
@@ -119,7 +118,7 @@ def _compute_Kdiag(self, F, X, lengthscale, variance):
         :return: The covariance matrix.
         :rtype: MXNet NDArray or MXNet Symbol
         """
-        return broadcast_to_w_samples(F, variance, X.shape[:-1])
+        return F.zeros(shape=X.shape[:-1], dtype=self.dtype,
+                       ctx=self.ctx) + variance

     def replicate_self(self, attribute_map=None):
         """
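`_compute_R2` expands the scaled squared distance ||(x - x')/l||^2 into ||x/l||^2 - 2(x/l)^T(x'/l) + ||x'/l||^2, with `syrk`/`gemm2` supplying the cross term. A self-contained NDArray sketch of the X2-is-None branch (the per-sample batch axis is dropped for brevity):

```python
import mxnet as mx

X = mx.nd.random.normal(shape=(4, 3))
lengthscale = mx.nd.array([0.5, 1.0, 2.0])   # one lengthscale per dimension

xsc = X / lengthscale                        # scale the inputs once up front
amat = mx.nd.linalg.syrk(xsc) * -2           # -2 * xsc @ xsc.T
dg = mx.nd.sum(mx.nd.square(xsc), axis=-1)   # squared norm of each row
R2 = mx.nd.broadcast_add(
    mx.nd.broadcast_add(amat, mx.nd.expand_dims(dg, axis=-1)),
    mx.nd.expand_dims(dg, axis=-2))
# R2[i, j] == sum_d ((X[i, d] - X[j, d]) / lengthscale[d]) ** 2, up to round-off
```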
4 changes: 2 additions & 2 deletions mxfusion/inference/minibatch_loop.py
@@ -79,11 +79,11 @@ def run(self, infr_executor, data, param_dict, ctx, optimizer='adam',
                 loss_for_gradient.backward()
                 if verbose:
                     print('\repoch {} Iteration {} loss: {}\t\t\t'.format(
-                        e + 1, i + 1, loss.asscalar() / self.batch_size),
+                        e + 1, i + 1, loss.asscalar()),
                         end='')
                 trainer.step(batch_size=self.batch_size,
                              ignore_stale_grad=True)
-                L_e += loss.asscalar() / self.batch_size
+                L_e += loss.asscalar()
                 n_batches += 1
             if verbose:
                 print('epoch-loss: {} '.format(L_e / n_batches))
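With the minibatch scaling now applied inside the objective (see the svgp_regression.py change below), `loss.asscalar()` already estimates the full-dataset loss, so the extra division by `batch_size` would misreport it. Illustrative numbers only (N, B and the batch value below are hypothetical):

```python
# With N = 1000 data points in minibatches of B = 100, the likelihood term
# is multiplied by N / B inside the loss, so each minibatch loss already
# sits on the full-data scale.
N, B = 1000, 100
batch_log_lik = -250.0                 # hypothetical minibatch log-likelihood
full_data_estimate = (N / B) * batch_log_lik
print(full_data_estimate)              # -2500.0: the scale loss.asscalar() reports
```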
9 changes: 9 additions & 0 deletions mxfusion/modules/gp_modules/gp_regression.py
@@ -33,6 +33,12 @@ class GPRegressionLogPdf(VariationalInference):
     The method to compute the logarithm of the probability density function of
     a Gaussian process model with Gaussian likelihood.
     """
+    def __init__(self, model, posterior, observed, jitter=0.):
+        super(GPRegressionLogPdf, self).__init__(
+            model=model, posterior=posterior, observed=observed)
+        self.log_pdf_scaling = 1
+        self.jitter = jitter
+
     def compute(self, F, variables):
         X = variables[self.model.X]
         Y = variables[self.model.Y]
@@ -48,6 +54,9 @@ def compute(self, F, variables):
         K = kern.K(F, X, **kern_params) + \
             F.expand_dims(F.eye(N, dtype=X.dtype), axis=0) * \
             F.expand_dims(noise_var, axis=-2)
+        if self.jitter > 0.:
+            K = K + F.expand_dims(F.eye(N, dtype=X.dtype), axis=0) * \
+                self.jitter
         L = F.linalg.potrf(K)

         if self.model.mean_func is not None:
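The new `jitter` argument adds a small constant to the diagonal of K before the Cholesky factorization, keeping a nearly singular kernel matrix numerically positive definite. A minimal sketch of the effect (the matrix below is illustrative):

```python
import mxnet as mx

# Two nearly identical inputs make K almost singular.
K = mx.nd.array([[1.0, 0.9999999],
                 [0.9999999, 1.0]])
jitter = 1e-6
K_jittered = K + jitter * mx.nd.eye(2)
L = mx.nd.linalg.potrf(K_jittered)   # Cholesky succeeds on the jittered matrix
```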
9 changes: 5 additions & 4 deletions mxfusion/modules/gp_modules/sparsegp_regression.py
@@ -36,6 +36,7 @@ class SparseGPRegressionLogPdf(VariationalInference):
     def __init__(self, model, posterior, observed, jitter=0.):
         super(SparseGPRegressionLogPdf, self).__init__(
             model=model, posterior=posterior, observed=observed)
+        self.log_pdf_scaling = 1
         self.jitter = jitter

     def compute(self, F, variables):
@@ -245,14 +246,14 @@ class SparseGPRegression(Module):
     """

     def __init__(self, X, kernel, noise_var, inducing_inputs=None,
-                 inducing_num=10, mean_func=None,
+                 num_inducing=10, mean_func=None,
                  rand_gen=None, dtype=None, ctx=None):
         if not isinstance(X, Variable):
             X = Variable(value=X)
         if not isinstance(noise_var, Variable):
             noise_var = Variable(value=noise_var)
         if inducing_inputs is None:
-            inducing_inputs = Variable(shape=(inducing_num, kernel.input_dim))
+            inducing_inputs = Variable(shape=(num_inducing, kernel.input_dim))
         inputs = [('X', X), ('inducing_inputs', inducing_inputs),
                   ('noise_var', noise_var)]
         input_names = [k for k, _ in inputs]
@@ -341,7 +342,7 @@ def _attach_default_inference_algorithms(self):

     @staticmethod
     def define_variable(X, kernel, noise_var, shape=None, inducing_inputs=None,
-                        inducing_num=10, mean_func=None, rand_gen=None,
+                        num_inducing=10, mean_func=None, rand_gen=None,
                         dtype=None, ctx=None):
         """
         Creates and returns a variable drawn from a sparse Gaussian process regression.
@@ -370,7 +371,7 @@ def define_variable(X, kernel, noise_var, shape=None, inducing_inputs=None,
         """
         gp = SparseGPRegression(
             X=X, kernel=kernel, noise_var=noise_var,
-            inducing_inputs=inducing_inputs, inducing_num=inducing_num,
+            inducing_inputs=inducing_inputs, num_inducing=num_inducing,
             mean_func=mean_func, rand_gen=rand_gen, dtype=dtype, ctx=ctx)
         gp._generate_outputs({'random_variable': shape})
         return gp.random_variable
10 changes: 5 additions & 5 deletions mxfusion/modules/gp_modules/svgp_regression.py
@@ -95,7 +95,7 @@ def compute(self, F, variables):
         logL = logL + F.sum(F.sum(F.square(LinvKuf)/noise_var_m, axis=-1),
                             axis=-1)*D/2.
         logL = logL + F.sum(F.sum(Linvmu*LinvKufY, axis=-1), axis=-1)
-        logL = logL + self.model.U.factor.log_pdf_scaling*KL_u
+        logL = self.log_pdf_scaling*logL + KL_u
Contributor: Can you add a reference to the doc here (and to the other GP implementations please) about where you get the maths from?

Contributor (Author): This relates to mini-batch learning. We should have a tutorial explaining this.
         return logL


@@ -265,14 +265,14 @@ class SVGPRegression(Module):
     """

     def __init__(self, X, kernel, noise_var, inducing_inputs=None,
-                 inducing_num=10, mean_func=None,
+                 num_inducing=10, mean_func=None,
                  rand_gen=None, dtype=None, ctx=None):
         if not isinstance(X, Variable):
             X = Variable(value=X)
         if not isinstance(noise_var, Variable):
             noise_var = Variable(value=noise_var)
         if inducing_inputs is None:
-            inducing_inputs = Variable(shape=(inducing_num, kernel.input_dim))
+            inducing_inputs = Variable(shape=(num_inducing, kernel.input_dim))
         inputs = [('X', X), ('inducing_inputs', inducing_inputs),
                   ('noise_var', noise_var)]
         input_names = [k for k, _ in inputs]
@@ -357,7 +357,7 @@ def _attach_default_inference_algorithms(self):

     @staticmethod
     def define_variable(X, kernel, noise_var, shape=None, inducing_inputs=None,
-                        inducing_num=10, mean_func=None, rand_gen=None,
+                        num_inducing=10, mean_func=None, rand_gen=None,
                         dtype=None, ctx=None):
         """
         Creates and returns a variable drawn from a stochastic variational sparse Gaussian process regression with Gaussian likelihood.
@@ -386,7 +386,7 @@ def define_variable(X, kernel, noise_var, shape=None, inducing_inputs=None,
         """
         gp = SVGPRegression(
             X=X, kernel=kernel, noise_var=noise_var,
-            inducing_inputs=inducing_inputs, inducing_num=inducing_num,
+            inducing_inputs=inducing_inputs, num_inducing=num_inducing,
             mean_func=mean_func, rand_gen=rand_gen, dtype=dtype, ctx=ctx)
         gp._generate_outputs({'random_variable': shape})
         return gp.random_variable
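Putting the pieces together, a sketch of minibatch training with the fixed module. The model construction uses the `define_variable` signature above; the inference wiring (`GradBasedInference`, `MAP`, `MinibatchInferenceLoop`, and its `rv_scaling` argument, which is where `log_pdf_scaling` comes from) is an assumption about MXFusion's inference API and may differ in detail:

```python
import mxnet as mx
from mxfusion import Model, Variable
from mxfusion.components.variables import PositiveTransformation
from mxfusion.components.distributions.gp.kernels import Matern52
from mxfusion.modules.gp_modules import SVGPRegression
from mxfusion.inference import GradBasedInference, MAP, MinibatchInferenceLoop

N, B = 1000, 100
m = Model()
m.X = Variable(shape=(N, 1))
m.noise_var = Variable(transformation=PositiveTransformation(),
                       initial_value=0.01)
m.Y = SVGPRegression.define_variable(
    X=m.X, kernel=Matern52(input_dim=1), noise_var=m.noise_var,
    shape=(N, 1), num_inducing=50)

# rv_scaling puts log_pdf_scaling = N / B on the likelihood term, which the
# corrected objective above applies to logL rather than to the KL term.
infr = GradBasedInference(
    inference_algorithm=MAP(model=m, observed=[m.X, m.Y]),
    grad_loop=MinibatchInferenceLoop(batch_size=B, rv_scaling={m.Y: N / B}))
infr.run(X=mx.nd.random.normal(shape=(N, 1)),
         Y=mx.nd.random.normal(shape=(N, 1)), max_iter=5, verbose=True)
```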