From 8ff38862d2909985229358e488d5bed0cfd3acd1 Mon Sep 17 00:00:00 2001 From: Viacheslav Borovitskiy Date: Sat, 10 Aug 2024 22:08:55 +0200 Subject: [PATCH] Replace opt_einsum's contract with lab's einsum, the latter being a more backend-independent version than the former --- geometric_kernels/feature_maps/probability_densities.py | 3 +-- geometric_kernels/kernels/feature_map.py | 3 +-- geometric_kernels/spaces/eigenfunctions.py | 7 +++---- geometric_kernels/spaces/hyperbolic.py | 3 +-- geometric_kernels/spaces/so.py | 5 ++--- geometric_kernels/spaces/su.py | 5 ++--- tests/kernels/test_product.py | 3 +-- tests/spaces/test_circle.py | 5 ++--- tests/spaces/test_hypersphere.py | 5 ++--- tests/spaces/test_lie_groups.py | 3 +-- 10 files changed, 16 insertions(+), 26 deletions(-) diff --git a/geometric_kernels/feature_maps/probability_densities.py b/geometric_kernels/feature_maps/probability_densities.py index 4f1a73f7..51e524cd 100644 --- a/geometric_kernels/feature_maps/probability_densities.py +++ b/geometric_kernels/feature_maps/probability_densities.py @@ -13,7 +13,6 @@ import lab as B import numpy as np from beartype.typing import Dict, List, Optional, Tuple -from opt_einsum import contract as einsum from sympy import Poly, Product, symbols from geometric_kernels.lab_extras import ( @@ -85,7 +84,7 @@ def student_t_sample( shape_sqrt = B.chol(shape) dtype = dtype or dtype_double(key) key, z = B.randn(key, dtype, *size, n) - z = einsum("...i,ji->...j", z, shape_sqrt) + z = B.einsum("...i,ji->...j", z, shape_sqrt) key, g = B.randgamma( key, diff --git a/geometric_kernels/kernels/feature_map.py b/geometric_kernels/kernels/feature_map.py index c2e531bf..b224be43 100644 --- a/geometric_kernels/kernels/feature_map.py +++ b/geometric_kernels/kernels/feature_map.py @@ -7,7 +7,6 @@ import lab as B import numpy as np from beartype.typing import Dict, Optional -from opt_einsum import contract as einsum from geometric_kernels.feature_maps import FeatureMap from geometric_kernels.kernels.base import BaseGeometricKernel @@ -124,7 +123,7 @@ def K( else: features_X2 = features_X - feature_product = einsum("...no,...mo->...nm", features_X, features_X2) + feature_product = B.einsum("...no,...mo->...nm", features_X, features_X2) return feature_product def K_diag(self, params: Dict[str, B.Numeric], X: B.Numeric, **kwargs): diff --git a/geometric_kernels/spaces/eigenfunctions.py b/geometric_kernels/spaces/eigenfunctions.py index 91439972..8bdb9055 100644 --- a/geometric_kernels/spaces/eigenfunctions.py +++ b/geometric_kernels/spaces/eigenfunctions.py @@ -22,7 +22,6 @@ import lab as B from beartype.typing import List, Optional -from opt_einsum import contract as einsum from geometric_kernels.lab_extras import complex_like, is_complex, take_along_axis @@ -106,7 +105,7 @@ def weighted_outerproduct( else: sum_phi_phi_for_level = B.cast(B.dtype(weights), sum_phi_phi_for_level) - return einsum("id,...nki->...nk", weights, sum_phi_phi_for_level) # [N, N2] + return B.einsum("id,...nki->...nk", weights, sum_phi_phi_for_level) # [N, N2] def weighted_outerproduct_diag( self, weights: B.Numeric, X: B.Numeric, **kwargs @@ -132,7 +131,7 @@ def weighted_outerproduct_diag( else: phi_product_diag = B.cast(B.dtype(weights), phi_product_diag) - return einsum("id,ni->n", weights, phi_product_diag) # [N,] + return B.einsum("id,ni->n", weights, phi_product_diag) # [N,] @abc.abstractmethod def phi_product( @@ -301,7 +300,7 @@ def phi_product( X2 = X Phi_X = self.__call__(X, **kwargs) # [N, J] Phi_X2 = self.__call__(X2, **kwargs) # [N2, J] - return einsum("nl,ml->nml", Phi_X, Phi_X2) # [N, N2, J] + return B.einsum("nl,ml->nml", Phi_X, Phi_X2) # [N, N2, J] def phi_product_diag(self, X: B.Numeric, **kwargs): Phi_X = self.__call__(X, **kwargs) # [N, J] diff --git a/geometric_kernels/spaces/hyperbolic.py b/geometric_kernels/spaces/hyperbolic.py index c07593c5..61b57ebc 100644 --- a/geometric_kernels/spaces/hyperbolic.py +++ b/geometric_kernels/spaces/hyperbolic.py @@ -5,7 +5,6 @@ import geomstats as gs import lab as B from beartype.typing import Optional -from opt_einsum import contract as einsum from geometric_kernels.lab_extras import ( complex_like, @@ -134,7 +133,7 @@ def inner_product(self, vector_a, vector_b): p = 1 diagonal = from_numpy(vector_a, [-1.0] * p + [1.0] * q) # (n+1) diagonal = B.cast(B.dtype(vector_a), diagonal) - return einsum("...i,...i->...", diagonal * vector_a, vector_b) + return B.einsum("...i,...i->...", diagonal * vector_a, vector_b) def inv_harish_chandra(self, lam: B.Numeric) -> B.Numeric: lam = B.squeeze(lam, -1) diff --git a/geometric_kernels/spaces/so.py b/geometric_kernels/spaces/so.py index 3ce0d43c..779d275a 100644 --- a/geometric_kernels/spaces/so.py +++ b/geometric_kernels/spaces/so.py @@ -13,7 +13,6 @@ import lab as B import numpy as np from beartype.typing import List, Tuple -from opt_einsum import contract as einsum from geometric_kernels.lab_extras import dtype_double, from_numpy, qr, take_along_axis from geometric_kernels.spaces.eigenfunctions import Eigenfunctions @@ -300,7 +299,7 @@ def random(self, key: B.RandomState, number: int): # explicit parametrization via the double cover SU(2) = S^3 key, sphere_point = B.random.randn(key, dtype_double(key), number, 4) sphere_point /= B.reshape( - B.sqrt(einsum("ij,ij->i", sphere_point, sphere_point)), -1, 1 + B.sqrt(B.einsum("ij,ij->i", sphere_point, sphere_point)), -1, 1 ) x, y, z, w = (B.reshape(sphere_point[..., i], -1, 1) for i in range(4)) @@ -317,7 +316,7 @@ def random(self, key: B.RandomState, number: int): # qr decomposition is not in the lab package, so numpy is used. key, h = B.random.randn(key, dtype_double(key), number, self.n, self.n) q, r = qr(h, mode="complete") - r_diag_sign = B.sign(einsum("...ii->...i", r)) + r_diag_sign = B.sign(B.einsum("...ii->...i", r)) q *= r_diag_sign[:, None] q_det_sign = B.sign(B.det(q)) q[:, :, 0] *= q_det_sign[:, None] diff --git a/geometric_kernels/spaces/su.py b/geometric_kernels/spaces/su.py index c1b0e5b3..a6c89d23 100644 --- a/geometric_kernels/spaces/su.py +++ b/geometric_kernels/spaces/su.py @@ -12,7 +12,6 @@ import lab as B import numpy as np from beartype.typing import List, Tuple -from opt_einsum import contract as einsum from geometric_kernels.lab_extras import ( complex_conj, @@ -222,7 +221,7 @@ def random(self, key: B.RandomState, number: int): # explicit parametrization via the double cover SU(2) = S_3 key, sphere_point = B.random.randn(key, dtype_double(key), number, 4) sphere_point /= B.reshape( - B.sqrt(einsum("ij,ij->i", sphere_point, sphere_point)), -1, 1 + B.sqrt(B.einsum("ij,ij->i", sphere_point, sphere_point)), -1, 1 ) a = create_complex(sphere_point[..., 0], sphere_point[..., 1]) b = create_complex(sphere_point[..., 2], sphere_point[..., 3]) @@ -235,7 +234,7 @@ def random(self, key: B.RandomState, number: int): key, imag = B.random.randn(key, dtype_double(key), number, self.n, self.n) h = create_complex(real, imag) / B.sqrt(2) q, r = qr(h, mode="complete") - r_diag = einsum("...ii->...i", r) + r_diag = B.einsum("...ii->...i", r) r_diag_inv_phase = complex_conj(r_diag / B.abs(r_diag)) q *= r_diag_inv_phase[:, None] q_det = B.det(q) diff --git a/tests/kernels/test_product.py b/tests/kernels/test_product.py index 7eaa692f..bb345c37 100644 --- a/tests/kernels/test_product.py +++ b/tests/kernels/test_product.py @@ -1,6 +1,5 @@ import lab as B import numpy as np -from opt_einsum import contract as einsum from geometric_kernels.kernels import MaternKarhunenLoeveKernel, ProductGeometricKernel from geometric_kernels.lab_extras.extras import from_numpy @@ -43,7 +42,7 @@ def test_circle_product_eigenfunctions(): weights = B.expand_dims(weights, -1) actual = B.to_numpy(eigenfunctions.weighted_outerproduct(weights, X, X)) - expected = einsum("ni,mi,i->nm", Phi_X, Phi_X2, chained_weights) + expected = B.einsum("ni,mi,i->nm", Phi_X, Phi_X2, chained_weights) np.testing.assert_array_almost_equal(actual, expected) diff --git a/tests/spaces/test_circle.py b/tests/spaces/test_circle.py index fc204936..426ff2a5 100644 --- a/tests/spaces/test_circle.py +++ b/tests/spaces/test_circle.py @@ -3,7 +3,6 @@ import pytest import tensorflow as tf import torch -from opt_einsum import contract as einsum from plum import Tuple from geometric_kernels.kernels import MaternKarhunenLoeveKernel @@ -89,7 +88,7 @@ def test_weighted_outerproduct_with_addition_theorem( Phi_X = eigenfunctions(inputs) Phi_X2 = eigenfunctions(inputs2) - expected = einsum("ni,ki,i->nk", Phi_X, Phi_X2, chained_weights) + expected = B.einsum("ni,ki,i->nk", Phi_X, Phi_X2, chained_weights) np.testing.assert_array_almost_equal(actual, expected) @@ -124,7 +123,7 @@ def test_weighted_outerproduct_diag_with_addition_theorem( actual = eigenfunctions.weighted_outerproduct_diag(weights, inputs) Phi_X = eigenfunctions(inputs) - expected = einsum("ni,i->n", Phi_X**2, chained_weights) + expected = B.einsum("ni,i->n", Phi_X**2, chained_weights) np.testing.assert_array_almost_equal(B.to_numpy(actual), B.to_numpy(expected)) diff --git a/tests/spaces/test_hypersphere.py b/tests/spaces/test_hypersphere.py index 2aa52b58..860dee00 100644 --- a/tests/spaces/test_hypersphere.py +++ b/tests/spaces/test_hypersphere.py @@ -1,7 +1,6 @@ import lab as B import numpy as np import pytest -from opt_einsum import contract as einsum from plum import Tuple from geometric_kernels.spaces.hypersphere import SphericalHarmonics @@ -77,7 +76,7 @@ def test_weighted_outerproduct_with_addition_theorem( Phi_X = eigenfunctions(inputs) Phi_X2 = eigenfunctions(inputs2) - expected = einsum("ni,ki,i->nk", Phi_X, Phi_X2, chained_weights) + expected = B.einsum("ni,ki,i->nk", Phi_X, Phi_X2, chained_weights) np.testing.assert_array_almost_equal(actual, expected) @@ -112,5 +111,5 @@ def test_weighted_outerproduct_diag_with_addition_theorem( actual = eigenfunctions.weighted_outerproduct_diag(weights, inputs) Phi_X = eigenfunctions(inputs) - expected = einsum("ni,i->n", Phi_X**2, chained_weights) + expected = B.einsum("ni,i->n", Phi_X**2, chained_weights) np.testing.assert_array_almost_equal(B.to_numpy(actual), B.to_numpy(expected)) diff --git a/tests/spaces/test_lie_groups.py b/tests/spaces/test_lie_groups.py index bda5a2b6..863580d8 100644 --- a/tests/spaces/test_lie_groups.py +++ b/tests/spaces/test_lie_groups.py @@ -4,7 +4,6 @@ import numpy as np import pytest from numpy.testing import assert_allclose -from opt_einsum import contract as einsum from geometric_kernels.feature_maps import RandomPhaseFeatureMapCompact from geometric_kernels.kernels import MaternKarhunenLoeveKernel @@ -127,6 +126,6 @@ def test_feature_map(group_and_eigf): K_xx = (kernel.K(param, x, x)).real key, embed_x = feature_map(x, param, key=key, normalize=True) - F_xx = (einsum("ni,mi-> nm", embed_x, embed_x.conj())).real + F_xx = (B.einsum("ni,mi-> nm", embed_x, embed_x.conj())).real assert_allclose(K_xx, F_xx, atol=5e-2)