Skip to content

Commit

Permalink
Hypersphere space small improvements (#142)
Browse files Browse the repository at this point in the history
  • Loading branch information
vabor112 authored Aug 13, 2024
1 parent 1bd9730 commit 8f918cc
Showing 1 changed file with 12 additions and 16 deletions.
28 changes: 12 additions & 16 deletions geometric_kernels/spaces/hypersphere.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@

class SphericalHarmonics(EigenfunctionsWithAdditionTheorem):
r"""
Eigenfunctions Laplace-Beltrami operator on the hypersphere correspond
to the spherical harmonics.
Eigenfunctions of the Laplace-Beltrami operator on the hypersphere
correspond to the spherical harmonics.
Levels are the whole eigenspaces.
Expand Down Expand Up @@ -56,7 +56,7 @@ def num_computed_levels(self) -> int:
break
return num

def __call__(self, X: B.Numeric, **parameters) -> B.Numeric:
def __call__(self, X: B.Numeric, **kwargs) -> B.Numeric:
return self._spherical_harmonics(X)

def _addition_theorem(
Expand Down Expand Up @@ -88,7 +88,7 @@ def _addition_theorem(
An array of shape [N, N2, L].
"""
values = [
level.addition(X, X2)[..., None] # [N1, N2, 1]
level.addition(X, X2)[..., None] # [N, N2, 1]
for level in self._spherical_harmonics.harmonic_levels
]
return B.concat(*values, axis=-1) # [N, N2, L]
Expand All @@ -106,10 +106,10 @@ def _addition_theorem_diag(self, X: B.Numeric, **kwargs) -> B.Numeric:
@property
def num_eigenfunctions(self) -> int:
if self._num_eigenfunctions is None:
n = 0
for d in range(self._num_levels):
n += num_harmonics(self.dim + 1, d)
self._num_eigenfunctions = n
J = 0
for level in range(self.num_levels):
J += num_harmonics(self.dim + 1, level)
self._num_eigenfunctions = J
return self._num_eigenfunctions

@property
Expand Down Expand Up @@ -182,14 +182,10 @@ def get_eigenvalues(self, num: int) -> B.Numeric:

def get_repeated_eigenvalues(self, num: int) -> B.Numeric:
eigenfunctions = SphericalHarmonics(self.dim, num)
eigenvalues_per_level = np.array(
[
level.eigenvalue()
for level in eigenfunctions._spherical_harmonics.harmonic_levels
]
)
eigenvalues_per_level = self.get_eigenvalues(num)

eigenvalues = chain(
eigenvalues_per_level,
B.squeeze(eigenvalues_per_level),
eigenfunctions.num_eigenfunctions_per_level,
) # [J,]
return B.reshape(eigenvalues, -1, 1) # [J, 1]
Expand Down Expand Up @@ -241,7 +237,7 @@ def random(self, key: B.RandomState, number: int) -> B.Numeric:
Either `np.random.RandomState`, `tf.random.Generator`,
`torch.Generator` or `jax.tensor` (representing random state).
:param number:
Number of samples to draw.
Number N of samples to draw.
:return:
An array of `number` uniformly random samples on the space.
Expand Down

0 comments on commit 8f918cc

Please sign in to comment.