Skip to content

Commit

Permalink
Merge branch 'main' into features/#1117-array-copy-None
Browse files Browse the repository at this point in the history
  • Loading branch information
ClaudiaComito authored Jun 6, 2023
2 parents ca868ea + 1bc1cca commit e882c56
Show file tree
Hide file tree
Showing 8 changed files with 1,005 additions and 85 deletions.
18 changes: 18 additions & 0 deletions benchmarks/cb/linalg.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# flake8: noqa
import heat as ht
import torchvision.datasets as datasets
from mpi4py import MPI
from perun.decorator import monitor


Expand Down Expand Up @@ -32,7 +34,23 @@ def lanczos_cpu(n: int = 50):
V, T = ht.lanczos(B, m=n)


@monitor()
def hierachical_svd_rank(data, r):
approx_svd = ht.linalg.hsvd_rank(data, maxrank=r, compute_sv=True, silent=True)


@monitor()
def hierachical_svd_tol(data, tol):
approx_svd = ht.linalg.hsvd_rtol(data, rtol=tol, compute_sv=True, silent=True)


matmul_cpu_split_0()
matmul_cpu_split_1()
qr_cpu()
lanczos_cpu()

data = ht.utils.data.matrixgallery.random_known_rank(
1000, 500 * MPI.COMM_WORLD.Get_size(), 10, split=1, dtype=ht.float32
)[0]
hierachical_svd_rank(data, 10)
hierachical_svd_tol(data, 1e-2)
1 change: 1 addition & 0 deletions heat/core/linalg/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@
from .basics import *
from .solver import *
from .qr import *
from .svdtools import *
15 changes: 13 additions & 2 deletions heat/core/linalg/svd.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,16 @@
"""
Future file for SVD functions
file for future "full" SVD implementation
"""
from typing import Tuple
from ..dndarray import DNDarray

# __all__ = ["qr"]
__all__ = ["svd"]


def svd(A: DNDarray) -> Tuple[DNDarray, DNDarray, DNDarray]:
"""
The intended functionality is similar to `numpy.linalg.svd`, but of-course allowing for distributed-memory parallelization and GPU-support.
"""
raise NotImplementedError(
" Memory-distributed 'full' (i.e. non-trucated and non-approximate) SVD not implemented yet. Consider using `heat.linalg.hsvd` for an approximate, truncated SVD instead."
)
530 changes: 530 additions & 0 deletions heat/core/linalg/svdtools.py

Large diffs are not rendered by default.

197 changes: 197 additions & 0 deletions heat/core/linalg/tests/test_svdtools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
import torch
import os
import unittest
import heat as ht
import numpy as np
from mpi4py import MPI

from ...tests.test_suites.basic_test import TestCase


class TestHSVD(TestCase):
def test_hsvd_rank_part1(self):
nprocs = MPI.COMM_WORLD.Get_size()
test_matrices = [
ht.random.randn(50, 15 * nprocs, dtype=ht.float32, split=1),
ht.random.randn(50, 15 * nprocs, dtype=ht.float64, split=1),
ht.random.randn(15 * nprocs, 50, dtype=ht.float32, split=0),
ht.random.randn(15 * nprocs, 50, dtype=ht.float64, split=0),
ht.random.randn(15 * nprocs, 50, dtype=ht.float32, split=None),
ht.random.randn(50, 15 * nprocs, dtype=ht.float64, split=None),
ht.zeros((50, 15 * nprocs), dtype=ht.float32, split=1),
]
rtols = [1e-1, 1e-2, 1e-3]
ranks = [5, 10, 15]

# check if hsvd yields "reasonable" results for random matrices, i.e.
# U (resp. V) is orthogonal for split=1 (resp. split=0)
# hsvd_rank yields the correct rank
# the true reconstruction error is <= error estimate
# for hsvd_rtol: true reconstruction error <= rtol (provided no further options)

for A in test_matrices:
if A.dtype == ht.float64:
dtype_tol = 1e-8
if A.dtype == ht.float32:
dtype_tol = 1e-3

for r in ranks:
U, sigma, V, err_est = ht.linalg.hsvd_rank(A, r, compute_sv=True, silent=True)
hsvd_rk = U.shape[1]

if ht.norm(A) > 0:
self.assertEqual(hsvd_rk, r)
if A.split == 1:
U_orth_err = (
ht.norm(
U.T @ U
- ht.eye(hsvd_rk, dtype=U.dtype, split=U.T.split, device=U.device)
)
/ hsvd_rk**0.5
)
self.assertTrue(U_orth_err <= dtype_tol)
if A.split == 0:
V_orth_err = (
ht.norm(
V.T @ V
- ht.eye(hsvd_rk, dtype=V.dtype, split=V.T.split, device=V.device)
)
/ hsvd_rk**0.5
)
self.assertTrue(V_orth_err <= dtype_tol)
true_rel_err = ht.norm(U @ ht.diag(sigma) @ V.T - A) / ht.norm(A)
self.assertTrue(true_rel_err <= err_est)
else:
self.assertEqual(hsvd_rk, 1)
self.assertEqual(ht.norm(U), 0)
self.assertEqual(ht.norm(sigma), 0)
self.assertEqual(ht.norm(V), 0)

# check if wrong parameter choice is caught
with self.assertRaises(RuntimeError):
ht.linalg.hsvd_rank(A, r, maxmergedim=4)

for tol in rtols:
U, sigma, V, err_est = ht.linalg.hsvd_rtol(A, tol, compute_sv=True, silent=True)
hsvd_rk = U.shape[1]

if ht.norm(A) > 0:
if A.split == 1:
U_orth_err = (
ht.norm(
U.T @ U
- ht.eye(hsvd_rk, dtype=U.dtype, split=U.T.split, device=U.device)
)
/ hsvd_rk**0.5
)
# print(U_orth_err)
self.assertTrue(U_orth_err <= dtype_tol)
if A.split == 0:
V_orth_err = (
ht.norm(
V.T @ V
- ht.eye(hsvd_rk, dtype=V.dtype, split=V.T.split, device=V.device)
)
/ hsvd_rk**0.5
)
self.assertTrue(V_orth_err <= dtype_tol)
true_rel_err = ht.norm(U @ ht.diag(sigma) @ V.T - A) / ht.norm(A)
self.assertTrue(true_rel_err <= err_est)
self.assertTrue(true_rel_err <= tol)
else:
self.assertEqual(hsvd_rk, 1)
self.assertEqual(ht.norm(U), 0)
self.assertEqual(ht.norm(sigma), 0)
self.assertEqual(ht.norm(V), 0)

# check if wrong parameter choices are catched
with self.assertRaises(ValueError):
ht.linalg.hsvd_rtol(A, tol, maxmergedim=4)
with self.assertRaises(ValueError):
ht.linalg.hsvd_rtol(A, tol, maxmergedim=10, maxrank=11)
with self.assertRaises(ValueError):
ht.linalg.hsvd_rtol(A, tol, no_of_merges=1)

# check if wrong input arrays are catched
wrong_test_matrices = [
0,
ht.ones((50, 15 * nprocs), dtype=ht.int8, split=1),
ht.ones((50, 15 * nprocs), dtype=ht.int16, split=1),
ht.ones((50, 15 * nprocs), dtype=ht.int32, split=1),
ht.ones((50, 15 * nprocs), dtype=ht.int64, split=1),
ht.ones((50, 15 * nprocs), dtype=ht.complex64, split=1),
ht.ones((50, 15 * nprocs), dtype=ht.complex128, split=1),
]

for A in wrong_test_matrices:
with self.assertRaises(TypeError):
ht.linalg.hsvd_rank(A, 5)
with self.assertRaises(TypeError):
ht.linalg.hsvd_rank(A, 1e-1)

wrong_test_matrices = [
ht.ones((15, 15 * nprocs, 15), split=1, dtype=ht.float64),
ht.ones(15 * nprocs, split=0, dtype=ht.float64),
]
for wrong_arr in wrong_test_matrices:
with self.assertRaises(ValueError):
ht.linalg.hsvd_rank(wrong_arr, 5)
with self.assertRaises(ValueError):
ht.linalg.hsvd_rtol(wrong_arr, 1e-1)

# check if compute_sv=False yields the correct number of outputs (=1)
self.assertEqual(len(ht.linalg.hsvd_rank(test_matrices[0], 5)), 2)
self.assertEqual(len(ht.linalg.hsvd_rtol(test_matrices[0], 5e-1)), 2)

@unittest.skipIf(torch.cuda.is_available() and torch.version.hip, "not supported for HIP")
def test_hsvd_rank_part2(self):
# check if hsvd_rank yields correct results for maxrank <= truerank
# this needs to be skipped on AMD because generation of test data relies on QR...
nprocs = MPI.COMM_WORLD.Get_size()
true_rk = max(10, nprocs)
test_matrices_low_rank = [
ht.utils.data.matrixgallery.random_known_rank(
50, 15 * nprocs, true_rk, split=1, dtype=ht.float32
),
ht.utils.data.matrixgallery.random_known_rank(
50, 15 * nprocs, true_rk, split=1, dtype=ht.float32
),
ht.utils.data.matrixgallery.random_known_rank(
15 * nprocs, 50, true_rk, split=0, dtype=ht.float64
),
ht.utils.data.matrixgallery.random_known_rank(
15 * nprocs, 50, true_rk, split=0, dtype=ht.float64
),
]

for mat in test_matrices_low_rank:
A = mat[0]
if A.dtype == ht.float64:
dtype_tol = 1e-8
if A.dtype == ht.float32:
dtype_tol = 1e-3

for r in [true_rk, true_rk + 2]:
U, s, V, _ = ht.linalg.hsvd_rank(A, r, compute_sv=True)
V = V[:, :true_rk].resplit(V.split)
U = U[:, :true_rk].resplit(U.split)
s = s[:true_rk]

U_orth_err = (
ht.norm(
U.T @ U - ht.eye(true_rk, dtype=U.dtype, split=U.T.split, device=U.device)
)
/ true_rk**0.5
)
V_orth_err = (
ht.norm(
V.T @ V - ht.eye(true_rk, dtype=V.dtype, split=V.T.split, device=V.device)
)
/ true_rk**0.5
)
true_rel_err = ht.norm(U @ ht.diag(s) @ V.T - A) / ht.norm(A)

self.assertTrue(ht.norm(s - mat[1][1]) / ht.norm(mat[1][1]) <= dtype_tol)
self.assertTrue(U_orth_err <= dtype_tol)
self.assertTrue(V_orth_err <= dtype_tol)
self.assertTrue(true_rel_err <= dtype_tol)
Loading

0 comments on commit e882c56

Please sign in to comment.