Merge branch 'main' into features/#1117-array-copy-None
Showing 8 changed files with 1,005 additions and 85 deletions.
@@ -5,3 +5,4 @@
 from .basics import *
 from .solver import *
 from .qr import *
+from .svdtools import *
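With this re-export in the linalg package `__init__`, the new placeholder becomes reachable as `heat.linalg.svd`. A quick sketch of what calling it currently does, inferred from the svdtools diff below; the module path `ht.linalg.svd` is an assumption based on this import:

import heat as ht

A = ht.random.randn(10, 10, split=0)
try:
    # placeholder only: the full memory-distributed SVD is not implemented yet
    ht.linalg.svd(A)
except NotImplementedError as err:
    print(err)  # points to heat.linalg.hsvd as the truncated alternative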
@@ -1,5 +1,16 @@
 """
-Future file for SVD functions
+file for future "full" SVD implementation
 """
+from typing import Tuple
+from ..dndarray import DNDarray
 
-# __all__ = ["qr"]
+__all__ = ["svd"]
+
+
+def svd(A: DNDarray) -> Tuple[DNDarray, DNDarray, DNDarray]:
+    """
+    The intended functionality is similar to `numpy.linalg.svd`, but of course allowing for distributed-memory parallelization and GPU support.
+    """
+    raise NotImplementedError(
+        "Memory-distributed 'full' (i.e. non-truncated and non-approximate) SVD is not implemented yet. Consider using `heat.linalg.hsvd` for an approximate, truncated SVD instead."
+    )
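Since `svd` only raises for now, the error message points users to the hierarchical SVD routines. A minimal usage sketch, assuming the `hsvd_rank`/`hsvd_rtol` signatures exercised by the new tests below; the matrix shape, rank, and tolerance here are illustrative:

import heat as ht

# random test matrix, distributed along the columns (shape chosen for illustration)
A = ht.random.randn(200, 64, dtype=ht.float32, split=1)

# fixed-rank hierarchical SVD; compute_sv=True additionally returns the singular
# values, the right factor V, and an estimate of the relative reconstruction error
U, sigma, V, err_est = ht.linalg.hsvd_rank(A, 10, compute_sv=True)

# alternatively, let the rank be chosen adaptively from a relative tolerance
U, sigma, V, err_est = ht.linalg.hsvd_rtol(A, 1e-2, compute_sv=True)

# measured relative reconstruction error; the tests below verify that this
# stays below err_est (and below the tolerance in the hsvd_rtol case)
rel_err = ht.norm(U @ ht.diag(sigma) @ V.T - A) / ht.norm(A)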
Large diffs are not rendered by default.
@@ -0,0 +1,197 @@
import torch
import os
import unittest
import heat as ht
import numpy as np
from mpi4py import MPI

from ...tests.test_suites.basic_test import TestCase


class TestHSVD(TestCase):
    def test_hsvd_rank_part1(self):
        nprocs = MPI.COMM_WORLD.Get_size()
        test_matrices = [
            ht.random.randn(50, 15 * nprocs, dtype=ht.float32, split=1),
            ht.random.randn(50, 15 * nprocs, dtype=ht.float64, split=1),
            ht.random.randn(15 * nprocs, 50, dtype=ht.float32, split=0),
            ht.random.randn(15 * nprocs, 50, dtype=ht.float64, split=0),
            ht.random.randn(15 * nprocs, 50, dtype=ht.float32, split=None),
            ht.random.randn(50, 15 * nprocs, dtype=ht.float64, split=None),
            ht.zeros((50, 15 * nprocs), dtype=ht.float32, split=1),
        ]
        rtols = [1e-1, 1e-2, 1e-3]
        ranks = [5, 10, 15]

        # check if hsvd yields "reasonable" results for random matrices, i.e.
        #   U (resp. V) is orthogonal for split=1 (resp. split=0),
        #   hsvd_rank yields the correct rank,
        #   the true reconstruction error is <= the error estimate,
        #   for hsvd_rtol: true reconstruction error <= rtol (provided no further options)

        for A in test_matrices:
            if A.dtype == ht.float64:
                dtype_tol = 1e-8
            if A.dtype == ht.float32:
                dtype_tol = 1e-3

            for r in ranks:
                U, sigma, V, err_est = ht.linalg.hsvd_rank(A, r, compute_sv=True, silent=True)
                hsvd_rk = U.shape[1]

                if ht.norm(A) > 0:
                    self.assertEqual(hsvd_rk, r)
                    if A.split == 1:
                        U_orth_err = (
                            ht.norm(
                                U.T @ U
                                - ht.eye(hsvd_rk, dtype=U.dtype, split=U.T.split, device=U.device)
                            )
                            / hsvd_rk**0.5
                        )
                        self.assertTrue(U_orth_err <= dtype_tol)
                    if A.split == 0:
                        V_orth_err = (
                            ht.norm(
                                V.T @ V
                                - ht.eye(hsvd_rk, dtype=V.dtype, split=V.T.split, device=V.device)
                            )
                            / hsvd_rk**0.5
                        )
                        self.assertTrue(V_orth_err <= dtype_tol)
                    true_rel_err = ht.norm(U @ ht.diag(sigma) @ V.T - A) / ht.norm(A)
                    self.assertTrue(true_rel_err <= err_est)
                else:
                    self.assertEqual(hsvd_rk, 1)
                    self.assertEqual(ht.norm(U), 0)
                    self.assertEqual(ht.norm(sigma), 0)
                    self.assertEqual(ht.norm(V), 0)

                # check if a wrong parameter choice is caught
                with self.assertRaises(RuntimeError):
                    ht.linalg.hsvd_rank(A, r, maxmergedim=4)

            for tol in rtols:
                U, sigma, V, err_est = ht.linalg.hsvd_rtol(A, tol, compute_sv=True, silent=True)
                hsvd_rk = U.shape[1]

                if ht.norm(A) > 0:
                    if A.split == 1:
                        U_orth_err = (
                            ht.norm(
                                U.T @ U
                                - ht.eye(hsvd_rk, dtype=U.dtype, split=U.T.split, device=U.device)
                            )
                            / hsvd_rk**0.5
                        )
                        # print(U_orth_err)
                        self.assertTrue(U_orth_err <= dtype_tol)
                    if A.split == 0:
                        V_orth_err = (
                            ht.norm(
                                V.T @ V
                                - ht.eye(hsvd_rk, dtype=V.dtype, split=V.T.split, device=V.device)
                            )
                            / hsvd_rk**0.5
                        )
                        self.assertTrue(V_orth_err <= dtype_tol)
                    true_rel_err = ht.norm(U @ ht.diag(sigma) @ V.T - A) / ht.norm(A)
                    self.assertTrue(true_rel_err <= err_est)
                    self.assertTrue(true_rel_err <= tol)
                else:
                    self.assertEqual(hsvd_rk, 1)
                    self.assertEqual(ht.norm(U), 0)
                    self.assertEqual(ht.norm(sigma), 0)
                    self.assertEqual(ht.norm(V), 0)

                # check if wrong parameter choices are caught
                with self.assertRaises(ValueError):
                    ht.linalg.hsvd_rtol(A, tol, maxmergedim=4)
                with self.assertRaises(ValueError):
                    ht.linalg.hsvd_rtol(A, tol, maxmergedim=10, maxrank=11)
                with self.assertRaises(ValueError):
                    ht.linalg.hsvd_rtol(A, tol, no_of_merges=1)

        # check if wrong input arrays are caught
        wrong_test_matrices = [
            0,
            ht.ones((50, 15 * nprocs), dtype=ht.int8, split=1),
            ht.ones((50, 15 * nprocs), dtype=ht.int16, split=1),
            ht.ones((50, 15 * nprocs), dtype=ht.int32, split=1),
            ht.ones((50, 15 * nprocs), dtype=ht.int64, split=1),
            ht.ones((50, 15 * nprocs), dtype=ht.complex64, split=1),
            ht.ones((50, 15 * nprocs), dtype=ht.complex128, split=1),
        ]

        for A in wrong_test_matrices:
            with self.assertRaises(TypeError):
                ht.linalg.hsvd_rank(A, 5)
            with self.assertRaises(TypeError):
                ht.linalg.hsvd_rank(A, 1e-1)

        wrong_test_matrices = [
            ht.ones((15, 15 * nprocs, 15), split=1, dtype=ht.float64),
            ht.ones(15 * nprocs, split=0, dtype=ht.float64),
        ]
        for wrong_arr in wrong_test_matrices:
            with self.assertRaises(ValueError):
                ht.linalg.hsvd_rank(wrong_arr, 5)
            with self.assertRaises(ValueError):
                ht.linalg.hsvd_rtol(wrong_arr, 1e-1)

        # check if compute_sv=False yields the correct number of outputs (=2)
        self.assertEqual(len(ht.linalg.hsvd_rank(test_matrices[0], 5)), 2)
        self.assertEqual(len(ht.linalg.hsvd_rtol(test_matrices[0], 5e-1)), 2)

    @unittest.skipIf(torch.cuda.is_available() and torch.version.hip, "not supported for HIP")
    def test_hsvd_rank_part2(self):
        # check if hsvd_rank yields correct results when maxrank is at least the true rank
        # this needs to be skipped on AMD GPUs because generation of the test data relies on QR...
        nprocs = MPI.COMM_WORLD.Get_size()
        true_rk = max(10, nprocs)
        test_matrices_low_rank = [
            ht.utils.data.matrixgallery.random_known_rank(
                50, 15 * nprocs, true_rk, split=1, dtype=ht.float32
            ),
            ht.utils.data.matrixgallery.random_known_rank(
                50, 15 * nprocs, true_rk, split=1, dtype=ht.float32
            ),
            ht.utils.data.matrixgallery.random_known_rank(
                15 * nprocs, 50, true_rk, split=0, dtype=ht.float64
            ),
            ht.utils.data.matrixgallery.random_known_rank(
                15 * nprocs, 50, true_rk, split=0, dtype=ht.float64
            ),
        ]

        for mat in test_matrices_low_rank:
            A = mat[0]
            if A.dtype == ht.float64:
                dtype_tol = 1e-8
            if A.dtype == ht.float32:
                dtype_tol = 1e-3

            for r in [true_rk, true_rk + 2]:
                U, s, V, _ = ht.linalg.hsvd_rank(A, r, compute_sv=True)
                V = V[:, :true_rk].resplit(V.split)
                U = U[:, :true_rk].resplit(U.split)
                s = s[:true_rk]

                U_orth_err = (
                    ht.norm(
                        U.T @ U - ht.eye(true_rk, dtype=U.dtype, split=U.T.split, device=U.device)
                    )
                    / true_rk**0.5
                )
                V_orth_err = (
                    ht.norm(
                        V.T @ V - ht.eye(true_rk, dtype=V.dtype, split=V.T.split, device=V.device)
                    )
                    / true_rk**0.5
                )
                true_rel_err = ht.norm(U @ ht.diag(s) @ V.T - A) / ht.norm(A)

                self.assertTrue(ht.norm(s - mat[1][1]) / ht.norm(mat[1][1]) <= dtype_tol)
                self.assertTrue(U_orth_err <= dtype_tol)
                self.assertTrue(V_orth_err <= dtype_tol)
                self.assertTrue(true_rel_err <= dtype_tol)
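The low-rank check in `test_hsvd_rank_part2` can also be read as a small standalone recipe. The sketch below restates it outside the unittest harness, assuming (as the test's use of `mat[0]` and `mat[1][1]` suggests) that `random_known_rank` returns the matrix first and its exact factors second, with the singular values at index 1 of that factor tuple:

import heat as ht
from mpi4py import MPI

nprocs = MPI.COMM_WORLD.Get_size()
true_rk = max(10, nprocs)

# matrix with exactly true_rk nonzero singular values, plus the factors used to build it
A, factors = ht.utils.data.matrixgallery.random_known_rank(
    50, 15 * nprocs, true_rk, split=1, dtype=ht.float32
)
sigma_true = factors[1]  # assumed layout: singular values in the second slot, as in the test

# with the prescribed rank equal to the true rank, hsvd_rank should recover the spectrum
U, s, V, _ = ht.linalg.hsvd_rank(A, true_rk, compute_sv=True)
print("relative singular-value error:", ht.norm(s - sigma_true) / ht.norm(sigma_true))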