Skip to content

Commit

Permalink
Merge branch 'devel' into peter/normalize_kernels
Browse files Browse the repository at this point in the history
  • Loading branch information
stoprightthere committed Aug 23, 2023
2 parents 87d0e07 + e5503d2 commit d92f741
Show file tree
Hide file tree
Showing 11 changed files with 549 additions and 17 deletions.
20 changes: 10 additions & 10 deletions geometric_kernels/kernels/feature_maps.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import lab as B

from geometric_kernels.kernels import MaternKarhunenLoeveKernel
from geometric_kernels.lab_extras import from_numpy
from geometric_kernels.lab_extras import float_like, from_numpy
from geometric_kernels.sampling.probability_densities import (
base_density_sample,
hyperbolic_density_sample,
Expand Down Expand Up @@ -64,8 +64,8 @@ def _map(X: B.Numeric, params, state, **kwargs) -> B.Numeric:
eigenfunctions = Phi.__call__(X, **params) # [N, M]

_context: Dict[str, str] = {} # no context
features = B.cast(B.dtype(X), eigenfunctions) * B.cast(
B.dtype(X), weights
features = B.cast(float_like(X), eigenfunctions) * B.cast(
float_like(X), weights
) # [N, M]
return features, _context

Expand Down Expand Up @@ -135,11 +135,11 @@ def _map(X: B.Numeric, params, state, key, **kwargs) -> B.Numeric:
Phi = state["eigenfunctions"]

# X [N, D]
random_phases_b = B.cast(B.dtype(X), from_numpy(X, random_phases))
random_phases_b = B.cast(float_like(X), from_numpy(X, random_phases))
embedding = B.cast(
B.dtype(X), Phi.phi_product(X, random_phases_b, **params)
float_like(X), Phi.phi_product(X, random_phases_b, **params)
) # [N, O, L]
weights_t = B.cast(B.dtype(X), B.transpose(weights))
weights_t = B.cast(float_like(X), B.transpose(weights))

features = B.reshape(embedding * weights_t, B.shape(X)[0], -1) # [N, O*L]
_context: Dict[str, str] = {"key": key}
Expand Down Expand Up @@ -208,10 +208,10 @@ def _map(X: B.Numeric, params, state, key, **kwargs) -> B.Numeric:

# X [N, D]
random_phases_b = B.expand_dims(
B.cast(B.dtype(X), from_numpy(X, random_phases))
B.cast(float_like(X), from_numpy(X, random_phases))
) # [1, O, D]
random_lambda_b = B.expand_dims(
B.cast(B.dtype(X), from_numpy(X, random_lambda))
B.cast(float_like(X), from_numpy(X, random_lambda))
) # [1, O, P]
X_b = B.expand_dims(X, axis=-2) # [N, 1, D]

Expand Down Expand Up @@ -267,10 +267,10 @@ def _map(X: B.Numeric, params, state, key, **kwargs) -> B.Numeric:

# X [N, D]
random_phases_b = B.expand_dims(
B.cast(B.dtype(X), from_numpy(X, random_phases))
B.cast(float_like(X), from_numpy(X, random_phases))
) # [1, O, D]
random_lambda_b = B.expand_dims(
B.cast(B.dtype(X), from_numpy(X, random_lambda))
B.cast(float_like(X), from_numpy(X, random_lambda))
) # [1, O]
X_b = B.expand_dims(X, axis=-2) # [N, 1, D]

Expand Down
6 changes: 3 additions & 3 deletions geometric_kernels/kernels/geometric_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def init_params_and_state(self):
:return: tuple(params, state)
"""
params = dict(lengthscale=np.array(1.0), nu=np.array(0.5))
params = dict(lengthscale=np.array(1.0), nu=np.array(np.inf))

eigenvalues_laplacian = self.space.get_eigenvalues(self.num_eigenfunctions)
eigenfunctions = self.space.get_eigenfunctions(self.num_eigenfunctions)
Expand Down Expand Up @@ -164,7 +164,7 @@ def __init__(self, space: Space, feature_map, key):
self.feature_map = make_deterministic(feature_map, key)

def init_params_and_state(self):
params = dict(nu=np.array(0.5), lengthscale=np.array(1.0))
params = dict(nu=np.array(np.inf), lengthscale=np.array(1.0))
state = dict()
return params, state

Expand Down Expand Up @@ -220,7 +220,7 @@ def init_params_and_state(self):
:return: tuple(params, state)
"""
params = dict(lengthscale=1.0, nu=0.5)
params = dict(lengthscale=1.0, nu=np.inf)
state = dict()

return params, state
Expand Down
18 changes: 18 additions & 0 deletions geometric_kernels/lab_extras/extras.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from lab import dispatch
from lab.util import abstract
from plum import Union
from scipy.sparse import spmatrix


@dispatch
Expand Down Expand Up @@ -94,6 +95,15 @@ def dtype_double(reference):
"""


@dispatch
@abstract()
def float_like(reference: B.Numeric):
"""
Return the type of the reference if it is a floating point type.
Otherwise return `double` dtype of a backend based on the reference.
"""


@dispatch
@abstract()
def dtype_integer(reference):
Expand Down Expand Up @@ -169,3 +179,11 @@ def cumsum(a: B.Numeric, axis=None):
"""
Return cumulative sum (optionally along axis)
"""


@dispatch
@abstract()
def reciprocal_no_nan(x: Union[B.Numeric, spmatrix]):
"""
Return element-wise reciprocal (1/x). Whenever x = 0 puts 1/x = 0.
"""
23 changes: 23 additions & 0 deletions geometric_kernels/lab_extras/jax/extras.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,19 @@ def dtype_double(reference: B.JAXRandomState): # type: ignore
return jnp.float64


@dispatch
def float_like(reference: B.JAXNumeric):
"""
Return the type of the reference if it is a floating point type.
Otherwise return `double` dtype of a backend based on the reference.
"""
reference_dtype = reference.dtype
if jnp.issubdtype(reference_dtype, jnp.floating):
return B.dtype(reference)
else:
return jnp.float64


@dispatch
def dtype_integer(reference: B.JAXRandomState): # type: ignore
"""
Expand Down Expand Up @@ -155,3 +168,13 @@ def cumsum(x: B.JAXNumeric, axis=None):
Return cumulative sum (optionally along axis)
"""
return jnp.cumsum(x, axis=axis)


@dispatch
def reciprocal_no_nan(x: B.JAXNumeric):
"""
Return element-wise reciprocal (1/x). Whenever x = 0 puts 1/x = 0.
"""
x_is_zero = jnp.equal(x, 0.0)
safe_x = jnp.where(x_is_zero, 1.0, x)
return jnp.where(x_is_zero, 0.0, jnp.reciprocal(safe_x))
32 changes: 32 additions & 0 deletions geometric_kernels/lab_extras/numpy/extras.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numpy as np
from lab import dispatch
from plum import Union
from scipy.sparse import spmatrix

_Numeric = Union[B.Number, B.NPNumeric]

Expand Down Expand Up @@ -68,6 +69,19 @@ def dtype_double(reference: B.NPRandomState): # type: ignore
return np.float64


@dispatch
def float_like(reference: B.NPNumeric):
"""
Return the type of the reference if it is a floating point type.
Otherwise return `double` dtype of a backend based on the reference.
"""
reference_dtype = reference.dtype
if np.issubdtype(reference_dtype, np.floating):
return reference_dtype
else:
return np.float64


@dispatch
def dtype_integer(reference: B.NPRandomState): # type: ignore
"""
Expand Down Expand Up @@ -144,3 +158,21 @@ def cumsum(a: _Numeric, axis=None):
Return cumulative sum (optionally along axis)
"""
return np.cumsum(a, axis=axis)


@dispatch
def reciprocal_no_nan(x: B.NPNumeric):
"""
Return element-wise reciprocal (1/x). Whenever x = 0 puts 1/x = 0.
"""
x_is_zero = np.equal(x, 0.0)
safe_x = np.where(x_is_zero, 1.0, x)
return np.where(x_is_zero, 0.0, np.reciprocal(safe_x))


@dispatch
def reciprocal_no_nan(x: spmatrix):
"""
Return element-wise reciprocal (1/x). Whenever x = 0 puts 1/x = 0.
"""
return x._with_data(reciprocal_no_nan(x._deduped_data().copy()), copy=True)
21 changes: 21 additions & 0 deletions geometric_kernels/lab_extras/tensorflow/extras.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,19 @@ def dtype_double(reference: B.TFRandomState): # type: ignore
return tf.float64


@dispatch
def float_like(reference: B.TFNumeric):
"""
Return the type of the reference if it is a floating point type.
Otherwise return `double` dtype of a backend based on the reference.
"""
reference_dtype = reference.dtype
if reference_dtype.is_floating:
return reference_dtype
else:
return tf.float64


@dispatch
def dtype_integer(reference: B.TFRandomState): # type: ignore
"""
Expand Down Expand Up @@ -169,3 +182,11 @@ def cumsum(x: B.TFNumeric, axis=None):
Return cumulative sum (optionally along axis)
"""
return tf.math.cumsum(x, axis=axis)


@dispatch
def reciprocal_no_nan(x: B.TFNumeric):
"""
Return element-wise reciprocal (1/x). Whenever x = 0 puts 1/x = 0.
"""
return tf.math.reciprocal_no_nan(x)
21 changes: 21 additions & 0 deletions geometric_kernels/lab_extras/torch/extras.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,18 @@ def dtype_double(reference: B.TorchRandomState): # type: ignore
return torch.double


@dispatch
def float_like(reference: B.TorchNumeric):
"""
Return the type of the reference if it is a floating point type.
Otherwise return `double` dtype of a backend based on the reference.
"""
if torch.is_floating_point(reference):
return B.dtype(reference)
else:
return torch.float64


@dispatch
def dtype_integer(reference: B.TorchRandomState): # type: ignore
"""
Expand Down Expand Up @@ -176,3 +188,12 @@ def cumsum(x: B.TorchNumeric, axis=None):
Return cumulative sum (optionally along axis)
"""
return torch.cumsum(x, dim=axis)


@dispatch
def reciprocal_no_nan(x: B.TorchNumeric):
"""
Return element-wise reciprocal (1/x). Whenever x = 0 puts 1/x = 0.
"""
safe_x = torch.where(x == 0.0, 1.0, x)
return torch.where(x == 0.0, 0.0, torch.reciprocal(safe_x))
3 changes: 2 additions & 1 deletion geometric_kernels/sampling/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import lab as B

from geometric_kernels.lab_extras import float_like
from geometric_kernels.types import FeatureMap


Expand All @@ -24,7 +25,7 @@ def sample_at(feature_map, s, X: B.Numeric, params, state, key=None) -> Tuple[An

num_features = B.shape(features)[-1]

key, random_weights = B.randn(key, B.dtype(features), num_features, s) # [M, S]
key, random_weights = B.randn(key, float_like(X), num_features, s) # [M, S]

random_sample = B.matmul(features, random_weights) # [N, S]

Expand Down
10 changes: 8 additions & 2 deletions geometric_kernels/spaces/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,13 @@
import lab as B
import numpy as np

from geometric_kernels.lab_extras import degree, dtype_integer, eigenpairs, set_value
from geometric_kernels.lab_extras import (
degree,
dtype_integer,
eigenpairs,
reciprocal_no_nan,
set_value,
)
from geometric_kernels.spaces.base import (
ConvertEigenvectorsToEigenfunctions,
DiscreteSpectrumSpace,
Expand Down Expand Up @@ -53,7 +59,7 @@ def set_laplacian(self, adjacency, normalize_laplacian=False):
degree_matrix = degree(adjacency)
self._laplacian = degree_matrix - adjacency
if normalize_laplacian:
degree_inv_sqrt = B.linear_algebra.pinv(B.sqrt(degree_matrix))
degree_inv_sqrt = reciprocal_no_nan(B.sqrt(degree_matrix))
self._laplacian = degree_inv_sqrt @ self._laplacian @ degree_inv_sqrt

def get_eigensystem(self, num):
Expand Down
Loading

0 comments on commit d92f741

Please sign in to comment.