From da2b0d1a34c25b2b179141142e605c634f00b18f Mon Sep 17 00:00:00 2001 From: Viacheslav Borovitskiy Date: Wed, 10 Jul 2024 17:11:06 +0200 Subject: [PATCH] Add a missing type cast and fix a typo in kernels/karhunen_loeve.py --- geometric_kernels/kernels/karhunen_loeve.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/geometric_kernels/kernels/karhunen_loeve.py b/geometric_kernels/kernels/karhunen_loeve.py index 4fa6ddfd..5157f30a 100644 --- a/geometric_kernels/kernels/karhunen_loeve.py +++ b/geometric_kernels/kernels/karhunen_loeve.py @@ -198,7 +198,7 @@ def K( weights = B.cast(B.dtype(params["nu"]), self.eigenvalues(params)) # [L, 1] Phi = self.eigenfunctions - K = Phi.weighted_outerproduct(weights, X, X2, **params) # [N, N2] + K = Phi.weighted_outerproduct(weights, X, X2, **kwargs) # [N, N2] if is_complex(K): return B.real(K) else: @@ -210,9 +210,9 @@ def K_diag(self, params, X: B.Numeric, **kwargs) -> B.Numeric: assert "nu" in params assert params["nu"].shape == (1,) - weights = self.eigenvalues(params) # [L, 1] + weights = B.cast(B.dtype(params["nu"]), self.eigenvalues(params)) # [L, 1] Phi = self.eigenfunctions - K_diag = Phi.weighted_outerproduct_diag(weights, X, **params) # [N,] + K_diag = Phi.weighted_outerproduct_diag(weights, X, **kwargs) # [N,] if is_complex(K_diag): return B.real(K_diag) else: