Skip to content

Commit

Permalink
Merge pull request #488 from helmholtz-analytics/enhancement/select-d…
Browse files Browse the repository at this point in the history
…evice

Enhancement/select device
  • Loading branch information
Markus-Goetz authored Jun 16, 2020
2 parents fc77d90 + 01eccf1 commit 42bf442
Show file tree
Hide file tree
Showing 35 changed files with 2,194 additions and 2,389 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Pending Additions

- [#488](https://github.com/helmholtz-analytics/heat/pull/488) Enhancement: Rework of the test device selection.
- [#573](https://github.com/helmholtz-analytics/heat/pull/573) Bugfix: matmul fixes: early out for 2 vectors, remainders not added if inner block is 1 for split 10 case
- [#575](https://github.com/helmholtz-analytics/heat/pull/558) Bugfix: Binary operations use proper type casting
- [#575](https://github.com/helmholtz-analytics/heat/pull/558) Bugfix: `where` and `cov` convert ints to floats when given as parameters
Expand Down
14 changes: 2 additions & 12 deletions heat/cluster/tests/test_kmeans.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,10 @@

import heat as ht

if os.environ.get("DEVICE") == "gpu" and ht.torch.cuda.is_available():
ht.use_device("gpu")
ht.torch.cuda.set_device(ht.torch.device(ht.get_device().torch_device))
else:
ht.use_device("cpu")
device = ht.get_device().torch_device
ht_device = None
if os.environ.get("DEVICE") == "lgpu" and ht.torch.cuda.is_available():
device = ht.gpu.torch_device
ht_device = ht.gpu
ht.torch.cuda.set_device(device)
from ...core.tests.test_suites.basic_test import TestCase


class TestKMeans(unittest.TestCase):
class TestKMeans(TestCase):
def test_clusterer(self):
kmeans = ht.cluster.KMeans()
self.assertTrue(ht.is_estimator(kmeans))
Expand Down
14 changes: 2 additions & 12 deletions heat/cluster/tests/test_spectral.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,10 @@

import heat as ht

if os.environ.get("DEVICE") == "gpu" and ht.torch.cuda.is_available():
ht.use_device("gpu")
ht.torch.cuda.set_device(ht.torch.device(ht.get_device().torch_device))
else:
ht.use_device("cpu")
device = ht.get_device().torch_device
ht_device = None
if os.environ.get("DEVICE") == "lgpu" and ht.torch.cuda.is_available():
device = ht.gpu.torch_device
ht_device = ht.gpu
ht.torch.cuda.set_device(device)
from ...core.tests.test_suites.basic_test import TestCase


class TestSpectral(unittest.TestCase):
class TestSpectral(TestCase):
def test_clusterer(self):
spectral = ht.cluster.Spectral()
self.assertTrue(ht.is_estimator(spectral))
Expand Down
460 changes: 225 additions & 235 deletions heat/core/linalg/tests/test_basics.py

Large diffs are not rendered by default.

112 changes: 32 additions & 80 deletions heat/core/linalg/tests/test_qr.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,7 @@
import unittest
import warnings

if os.environ.get("DEVICE") == "gpu" and torch.cuda.is_available():
ht.use_device("gpu")
torch.cuda.set_device(torch.device(ht.get_device().torch_device))
else:
ht.use_device("cpu")
device = ht.get_device().torch_device
ht_device = None
if os.environ.get("DEVICE") == "lgpu" and torch.cuda.is_available():
device = ht.gpu.torch_device
ht_device = ht.gpu
torch.cuda.set_device(device)
from ...tests.test_suites.basic_test import TestCase

if os.environ.get("EXTENDED_TESTS"):
extended_tests = True
Expand All @@ -24,122 +14,84 @@
extended_tests = False


class TestQR(unittest.TestCase):
class TestQR(TestCase):
@unittest.skipIf(not extended_tests, "extended tests")
def test_qr_sp0_ext(self):
st_whole = torch.randn(70, 70, device=device)
st_whole = torch.randn(70, 70, device=self.device.torch_device)
sp = 0
for m in range(50, st_whole.shape[0] + 1, 1):
for n in range(50, st_whole.shape[1] + 1, 1):
for t in range(1, 3):
st = st_whole[:m, :n].clone()
a_comp = ht.array(st, split=0, device=ht_device)
a = ht.array(st, split=sp, device=ht_device)
a_comp = ht.array(st, split=0)
a = ht.array(st, split=sp)
qr = a.qr(tiles_per_proc=t)
self.assertTrue(ht.allclose(a_comp, qr.Q @ qr.R, rtol=1e-5, atol=1e-5))
self.assertTrue(
ht.allclose(
qr.Q.T @ qr.Q, ht.eye(m, device=ht_device), rtol=1e-5, atol=1e-5
)
)
self.assertTrue(
ht.allclose(
ht.eye(m, device=ht_device), qr.Q @ qr.Q.T, rtol=1e-5, atol=1e-5
)
)
self.assertTrue(ht.allclose(qr.Q.T @ qr.Q, ht.eye(m), rtol=1e-5, atol=1e-5))
self.assertTrue(ht.allclose(ht.eye(m), qr.Q @ qr.Q.T, rtol=1e-5, atol=1e-5))

@unittest.skipIf(not extended_tests, "extended tests")
def test_qr_sp1_ext(self):
st_whole = torch.randn(70, 70, device=device)
st_whole = torch.randn(70, 70, device=self.device.torch_device)
sp = 1
for m in range(50, st_whole.shape[0] + 1, 1):
for n in range(50, st_whole.shape[1] + 1, 1):
for t in range(1, 3):
st = st_whole[:m, :n].clone()
a_comp = ht.array(st, split=0, device=ht_device)
a = ht.array(st, split=sp, device=ht_device)
a_comp = ht.array(st, split=0)
a = ht.array(st, split=sp)
qr = a.qr(tiles_per_proc=t)
self.assertTrue(ht.allclose(a_comp, qr.Q @ qr.R, rtol=1e-5, atol=1e-5))
self.assertTrue(
ht.allclose(
qr.Q.T @ qr.Q, ht.eye(m, device=ht_device), rtol=1e-5, atol=1e-5
)
)
self.assertTrue(
ht.allclose(
ht.eye(m, device=ht_device), qr.Q @ qr.Q.T, rtol=1e-5, atol=1e-5
)
)
self.assertTrue(ht.allclose(qr.Q.T @ qr.Q, ht.eye(m), rtol=1e-5, atol=1e-5))
self.assertTrue(ht.allclose(ht.eye(m), qr.Q @ qr.Q.T, rtol=1e-5, atol=1e-5))

def test_qr(self):
m, n = 20, 40
st = torch.randn(m, n, device=device, dtype=torch.float)
a_comp = ht.array(st, split=0, device=ht_device)
st = torch.randn(m, n, device=self.device.torch_device, dtype=torch.float)
a_comp = ht.array(st, split=0)
for t in range(1, 3):
for sp in range(2):
a = ht.array(st, split=sp, device=ht_device, dtype=torch.float)
a = ht.array(st, split=sp, dtype=torch.float)
qr = a.qr(tiles_per_proc=t)
self.assertTrue(ht.allclose((a_comp - (qr.Q @ qr.R)), 0, rtol=1e-5, atol=1e-5))
self.assertTrue(
ht.allclose(qr.Q.T @ qr.Q, ht.eye(m, device=ht_device), rtol=1e-5, atol=1e-5)
)
self.assertTrue(
ht.allclose(ht.eye(m, device=ht_device), qr.Q @ qr.Q.T, rtol=1e-5, atol=1e-5)
)
self.assertTrue(ht.allclose(qr.Q.T @ qr.Q, ht.eye(m), rtol=1e-5, atol=1e-5))
self.assertTrue(ht.allclose(ht.eye(m), qr.Q @ qr.Q.T, rtol=1e-5, atol=1e-5))
m, n = 40, 40
st1 = torch.randn(m, n, device=device)
a_comp1 = ht.array(st1, split=0, device=ht_device)
st1 = torch.randn(m, n, device=self.device.torch_device)
a_comp1 = ht.array(st1, split=0)
for t in range(1, 3):
for sp in range(2):
a1 = ht.array(st1, split=sp, device=ht_device)
a1 = ht.array(st1, split=sp)
qr1 = a1.qr(tiles_per_proc=t)
self.assertTrue(ht.allclose((a_comp1 - (qr1.Q @ qr1.R)), 0, rtol=1e-5, atol=1e-5))
self.assertTrue(
ht.allclose(qr1.Q.T @ qr1.Q, ht.eye(m, device=ht_device), rtol=1e-5, atol=1e-5)
)
self.assertTrue(
ht.allclose(ht.eye(m, device=ht_device), qr1.Q @ qr1.Q.T, rtol=1e-5, atol=1e-5)
)
self.assertTrue(ht.allclose(qr1.Q.T @ qr1.Q, ht.eye(m), rtol=1e-5, atol=1e-5))
self.assertTrue(ht.allclose(ht.eye(m), qr1.Q @ qr1.Q.T, rtol=1e-5, atol=1e-5))
m, n = 40, 20
st2 = torch.randn(m, n, dtype=torch.double, device=device)
a_comp2 = ht.array(st2, split=0, dtype=ht.double, device=ht_device)
st2 = torch.randn(m, n, dtype=torch.double, device=self.device.torch_device)
a_comp2 = ht.array(st2, split=0, dtype=ht.double)
for t in range(1, 3):
for sp in range(2):
a2 = ht.array(st2, split=sp, device=ht_device)
a2 = ht.array(st2, split=sp)
qr2 = a2.qr(tiles_per_proc=t)
self.assertTrue(ht.allclose(a_comp2, qr2.Q @ qr2.R, rtol=1e-5, atol=1e-5))
self.assertTrue(
ht.allclose(
qr2.Q.T @ qr2.Q,
ht.eye(m, dtype=ht.double, device=ht_device),
rtol=1e-5,
atol=1e-5,
)
ht.allclose(qr2.Q.T @ qr2.Q, ht.eye(m, dtype=ht.double), rtol=1e-5, atol=1e-5)
)
self.assertTrue(
ht.allclose(
ht.eye(m, dtype=ht.double, device=ht_device),
qr2.Q @ qr2.Q.T,
rtol=1e-5,
atol=1e-5,
)
ht.allclose(ht.eye(m, dtype=ht.double), qr2.Q @ qr2.Q.T, rtol=1e-5, atol=1e-5)
)
# test if calc R alone works
qr = ht.qr(a2, calc_q=False, overwrite_a=True)
self.assertTrue(qr.Q is None)

m, n = 40, 20
st = torch.randn(m, n, device=device)
a_comp = ht.array(st, split=None, device=ht_device)
a = ht.array(st, split=None, device=ht_device)
st = torch.randn(m, n, device=self.device.torch_device)
a_comp = ht.array(st, split=None)
a = ht.array(st, split=None)
qr = a.qr()
self.assertTrue(ht.allclose(a_comp, qr.Q @ qr.R, rtol=1e-5, atol=1e-5))
self.assertTrue(
ht.allclose(qr.Q.T @ qr.Q, ht.eye(m, device=ht_device), rtol=1e-5, atol=1e-5)
)
self.assertTrue(
ht.allclose(ht.eye(m, device=ht_device), qr.Q @ qr.Q.T, rtol=1e-5, atol=1e-5)
)
self.assertTrue(ht.allclose(qr.Q.T @ qr.Q, ht.eye(m), rtol=1e-5, atol=1e-5))
self.assertTrue(ht.allclose(ht.eye(m), qr.Q @ qr.Q.T, rtol=1e-5, atol=1e-5))

# raises
with self.assertRaises(TypeError):
Expand Down
14 changes: 2 additions & 12 deletions heat/core/linalg/tests/test_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,10 @@
import heat as ht
import numpy as np

if os.environ.get("DEVICE") == "gpu" and torch.cuda.is_available():
ht.use_device("gpu")
torch.cuda.set_device(torch.device(ht.get_device().torch_device))
else:
ht.use_device("cpu")
device = ht.get_device().torch_device
ht_device = None
if os.environ.get("DEVICE") == "lgpu" and torch.cuda.is_available():
device = ht.gpu.torch_device
ht_device = ht.gpu
torch.cuda.set_device(device)
from ...tests.test_suites.basic_test import TestCase


class TestSolver(unittest.TestCase):
class TestSolver(TestCase):
def test_cg(self):
size = ht.communication.MPI_WORLD.size * 3
b = ht.arange(1, size + 1, dtype=ht.float32, split=0)
Expand Down
2 changes: 1 addition & 1 deletion heat/core/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def __binary_op(operation, t1, t2):
output_device = None
output_comm = MPI_WORLD
elif isinstance(t2, dndarray.DNDarray):
t1.gpu() if t2.device.device_type == "gpu" else t1.cpu()
t1 = t1.gpu() if t2.device.device_type == "gpu" else t1.cpu()

output_shape = t2.shape
output_split = t2.split
Expand Down
Loading

0 comments on commit 42bf442

Please sign in to comment.