Skip to content

Commit 20d418b

Browse files
test_dparray now uses pytest instead of unittest
1 parent cbe8ca6 commit 20d418b

File tree

1 file changed

+79
-60
lines changed

1 file changed

+79
-60
lines changed

dpctl/tests/test_dparray.py

Lines changed: 79 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -17,85 +17,104 @@
1717
"""Unit test cases for dpctl.tensor.numpy_usm_shared.
1818
"""
1919

20-
import unittest
21-
2220
import numpy
2321

2422
from dpctl.tensor import numpy_usm_shared as dparray
2523

2624

27-
class Test_dparray(unittest.TestCase):
28-
def setUp(self):
29-
self.X = dparray.ndarray((256, 4), dtype="d")
30-
self.X.fill(1.0)
25+
def get_arg():
26+
X = dparray.ndarray((256, 4), dtype="d")
27+
X.fill(1.0)
28+
return X
29+
30+
31+
def test_dparray_type():
32+
X = get_arg()
33+
assert isinstance(X, dparray.ndarray)
34+
35+
36+
def test_dparray_as_ndarray_self():
37+
X = get_arg()
38+
Y = X.as_ndarray()
39+
assert type(Y) == numpy.ndarray
40+
41+
42+
def test_dparray_as_ndarray():
43+
X = get_arg()
44+
Y = dparray.as_ndarray(X)
45+
assert type(Y) == numpy.ndarray
46+
47+
48+
def test_dparray_from_ndarray():
49+
X = get_arg()
50+
Y = dparray.as_ndarray(X)
51+
dp1 = dparray.from_ndarray(Y)
52+
assert isinstance(dp1, dparray.ndarray)
53+
54+
55+
def test_multiplication_dparray():
56+
C = get_arg() * 5
57+
assert isinstance(C, dparray.ndarray)
58+
59+
60+
def test_inplace_sub():
61+
X = get_arg()
62+
X -= 1
3163

32-
def test_dparray_type(self):
33-
self.assertIsInstance(self.X, dparray.ndarray)
3464

35-
def test_dparray_as_ndarray_self(self):
36-
Y = self.X.as_ndarray()
37-
self.assertEqual(type(Y), numpy.ndarray)
65+
def test_dparray_through_python_func():
66+
def func_operation_with_const(dpctl_array):
67+
return dpctl_array * 2.0 + 13
3868

39-
def test_dparray_as_ndarray(self):
40-
Y = dparray.as_ndarray(self.X)
41-
self.assertEqual(type(Y), numpy.ndarray)
69+
C = get_arg() * 5
70+
dp_func = func_operation_with_const(C)
71+
assert isinstance(dp_func, dparray.ndarray)
4272

43-
def test_dparray_from_ndarray(self):
44-
Y = dparray.as_ndarray(self.X)
45-
dp1 = dparray.from_ndarray(Y)
46-
self.assertIsInstance(dp1, dparray.ndarray)
4773

48-
def test_multiplication_dparray(self):
49-
C = self.X * 5
50-
self.assertIsInstance(C, dparray.ndarray)
74+
def test_dparray_mixing_dpctl_and_numpy():
75+
dp_numpy = numpy.ones((256, 4), dtype="d")
76+
X = get_arg()
77+
res = dp_numpy * X
78+
assert isinstance(X, dparray.ndarray)
79+
assert isinstance(res, dparray.ndarray)
5180

52-
def test_inplace_sub(self):
53-
self.X -= 1
5481

55-
def test_dparray_through_python_func(self):
56-
def func_operation_with_const(dpctl_array):
57-
return dpctl_array * 2.0 + 13
82+
def test_dparray_shape():
83+
X = get_arg()
84+
res = X.shape
85+
assert res == (256, 4)
5886

59-
C = self.X * 5
60-
dp_func = func_operation_with_const(C)
61-
self.assertIsInstance(dp_func, dparray.ndarray)
6287

63-
def test_dparray_mixing_dpctl_and_numpy(self):
64-
dp_numpy = numpy.ones((256, 4), dtype="d")
65-
res = dp_numpy * self.X
66-
self.assertIsInstance(self.X, dparray.ndarray)
67-
self.assertIsInstance(res, dparray.ndarray)
88+
def test_dparray_T():
89+
X = get_arg()
90+
res = X.T
91+
assert res.shape == (4, 256)
6892

69-
def test_dparray_shape(self):
70-
res = self.X.shape
71-
self.assertEqual(res, (256, 4))
7293

73-
def test_dparray_T(self):
74-
res = self.X.T
75-
self.assertEqual(res.shape, (4, 256))
94+
def test_numpy_ravel_with_dparray():
95+
X = get_arg()
96+
res = numpy.ravel(X)
97+
assert res.shape == (1024,)
7698

77-
def test_numpy_ravel_with_dparray(self):
78-
res = numpy.ravel(self.X)
79-
self.assertEqual(res.shape, (1024,))
8099

81-
def test_numpy_sum_with_dparray(self):
82-
res = numpy.sum(self.X)
83-
self.assertEqual(res, 1024.0)
100+
def test_numpy_sum_with_dparray():
101+
X = get_arg()
102+
res = numpy.sum(X)
103+
assert res == 1024.0
84104

85-
def test_numpy_sum_with_dparray_out(self):
86-
res = dparray.empty((self.X.shape[1],), dtype=self.X.dtype)
87-
res2 = numpy.sum(self.X, axis=0, out=res)
88-
self.assertTrue(res is res2)
89-
self.assertIsInstance(res2, dparray.ndarray)
90105

91-
def test_frexp_with_out(self):
92-
X = dparray.array([0.5, 4.7])
93-
mant = dparray.empty((2,), dtype="d")
94-
exp = dparray.empty((2,), dtype="i4")
95-
res = numpy.frexp(X, out=(mant, exp))
96-
self.assertTrue(res[0] is mant)
97-
self.assertTrue(res[1] is exp)
106+
def test_numpy_sum_with_dparray_out():
107+
X = get_arg()
108+
res = dparray.empty((X.shape[1],), dtype=X.dtype)
109+
res2 = numpy.sum(X, axis=0, out=res)
110+
assert res is res2
111+
assert isinstance(res2, dparray.ndarray)
98112

99113

100-
if __name__ == "__main__":
101-
unittest.main()
114+
def test_frexp_with_out():
115+
X = dparray.array([0.5, 4.7])
116+
mant = dparray.empty((2,), dtype="d")
117+
exp = dparray.empty((2,), dtype="i4")
118+
res = numpy.frexp(X, out=(mant, exp))
119+
assert res[0] is mant
120+
assert res[1] is exp

0 commit comments

Comments
 (0)