-
Notifications
You must be signed in to change notification settings - Fork 54
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Enhancement/select device #488
Changes from 35 commits
042bea1
590e765
1ae04f4
65012ea
32a9ab5
e3c3ed7
5832958
4d0c63a
b6b9d5e
f75b80c
8eea734
8010878
f1ae63f
fc3b38e
1227183
bd45727
680ad9c
7ac3282
df11692
fc02990
a44f03b
58ec0b1
335d955
768246f
c453aab
816ff85
613e597
9535102
54f3a85
9e8ed07
0af54b1
cfa24c7
04836d7
cbdaa6b
17e7293
2374ae8
bac8ec6
ffd6e7b
b355451
36efc80
c7c596e
a4ca8db
79925b5
093ab69
499cc1e
1b178cd
1ffe420
d9eac6c
f269458
6eb50f1
4e98d25
ff195ee
2d94115
920b973
6fb10f8
9799411
31f2438
fbda362
01eccf1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,20 +3,14 @@ | |
|
||
import heat as ht | ||
|
||
if os.environ.get("DEVICE") == "gpu" and ht.torch.cuda.is_available(): | ||
ht.use_device("gpu") | ||
ht.torch.cuda.set_device(ht.torch.device(ht.get_device().torch_device)) | ||
else: | ||
ht.use_device("cpu") | ||
device = ht.get_device().torch_device | ||
ht_device = None | ||
if os.environ.get("DEVICE") == "lgpu" and ht.torch.cuda.is_available(): | ||
device = ht.gpu.torch_device | ||
ht_device = ht.gpu | ||
ht.torch.cuda.set_device(device) | ||
from heat.core.tests.test_suites.basic_test import BasicTest | ||
|
||
|
||
class TestKMeans(unittest.TestCase): | ||
class TestKMeans(BasicTest): | ||
@classmethod | ||
def setUpClass(cls): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does this need to be explicitly called here? Should be automatically done due to inheritance There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The ones without already existing setup functions can be removed. Good. |
||
super(TestKMeans, cls).setUpClass() | ||
|
||
def test_clusterer(self): | ||
kmeans = ht.cluster.KMeans() | ||
self.assertTrue(ht.is_estimator(kmeans)) | ||
|
@@ -38,7 +32,9 @@ def test_get_and_set_params(self): | |
def test_fit_iris_unsplit(self): | ||
for split in [None, 0]: | ||
# get some test data | ||
iris = ht.load("heat/datasets/data/iris.csv", sep=";", split=split) | ||
iris = ht.load( | ||
"heat/datasets/data/iris.csv", sep=";", split=split, device=self.ht_device | ||
) | ||
|
||
# fit the clusters | ||
k = 3 | ||
|
@@ -58,7 +54,7 @@ def test_fit_iris_unsplit(self): | |
|
||
def test_exceptions(self): | ||
# get some test data | ||
iris_split = ht.load("heat/datasets/data/iris.csv", sep=";", split=1) | ||
iris_split = ht.load("heat/datasets/data/iris.csv", sep=";", split=1, device=self.ht_device) | ||
|
||
# build a clusterer | ||
k = 3 | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should probably be called TestCase as well
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should it be imported as TestCase or should the name be changed in general?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The name should be changed in general in my opinion