Skip to content

Commit a490899

Browse files
committed
modifying dpnp.linalg.det function
1 parent 1fac6e6 commit a490899

File tree

3 files changed

+49
-10
lines changed

3 files changed

+49
-10
lines changed

dpnp/linalg/dpnp_algo_linalg.pyx

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -142,15 +142,9 @@ cpdef object dpnp_cond(object input, object p):
142142
cpdef utils.dpnp_descriptor dpnp_det(utils.dpnp_descriptor input):
143143
cdef shape_type_c input_shape = input.shape
144144
cdef size_t n = input.shape[-1]
145-
cdef size_t size_out = 1
145+
cdef shape_type_c result_shape = (1,)
146146
if input.ndim != 2:
147-
output_shape = tuple((list(input.shape))[:-2])
148-
for i in range(len(output_shape)):
149-
size_out *= output_shape[i]
150-
151-
cdef shape_type_c result_shape = (size_out,)
152-
if size_out > 1:
153-
result_shape = output_shape
147+
result_shape = tuple((list(input.shape))[:-2])
154148

155149
cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(input.dtype)
156150

dpnp/linalg/dpnp_iface_linalg.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,9 @@ def det(input):
159159

160160
x1_desc = dpnp.get_dpnp_descriptor(input, copy_when_nondefault_queue=False)
161161
if x1_desc:
162-
if x1_desc.shape[-1] == x1_desc.shape[-2]:
162+
if x1_desc.ndim < 2:
163+
pass
164+
elif x1_desc.shape[-1] == x1_desc.shape[-2]:
163165
result_obj = dpnp_det(x1_desc).get_pyobj()
164166
result = dpnp.convert_single_elem_array_to_scalar(result_obj)
165167

tests/test_linalg.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,20 @@ def test_det(array):
128128
assert_allclose(expected, result)
129129

130130

131+
@pytest.mark.usefixtures("allow_fall_back_on_numpy")
132+
def test_det_empty():
133+
a = numpy.empty((0,0,2,2), dtype=numpy.float32)
134+
ia = inp.array(a)
135+
136+
np_det = numpy.linalg.det(a)
137+
dpnp_det = inp.linalg.det(ia)
138+
139+
assert dpnp_det.dtype == np_det.dtype
140+
assert dpnp_det.shape == np_det.shape
141+
142+
assert_allclose(np_det,dpnp_det)
143+
144+
131145
@pytest.mark.parametrize("type", get_all_dtypes(no_bool=True, no_complex=True))
132146
@pytest.mark.parametrize("size", [2, 4, 8, 16, 300])
133147
def test_eig_arange(type, size):
@@ -388,7 +402,7 @@ def test_qr(type, shape, mode):
388402
# check decomposition
389403
assert_allclose(
390404
ia,
391-
numpy.dot(inp.asnumpy(dpnp_q), inp.asnumpy(dpnp_r)),
405+
inp.dot(dpnp_q, dpnp_r),
392406
rtol=tol,
393407
atol=tol,
394408
)
@@ -409,6 +423,35 @@ def test_qr(type, shape, mode):
409423
assert_allclose(dpnp_r, np_r, rtol=tol, atol=tol)
410424

411425

426+
@pytest.mark.usefixtures("allow_fall_back_on_numpy")
427+
def test_qr_not_2D():
428+
a = numpy.arange(12, dtype=numpy.float32).reshape((3,2,2))
429+
ia = inp.array(a)
430+
431+
np_q, np_r = numpy.linalg.qr(a)
432+
dpnp_q, dpnp_r = inp.linalg.qr(ia)
433+
434+
assert dpnp_q.dtype == np_q.dtype
435+
assert dpnp_r.dtype == np_r.dtype
436+
assert dpnp_q.shape == np_q.shape
437+
assert dpnp_r.shape == np_r.shape
438+
439+
assert_allclose(ia,inp.matmul(dpnp_q, dpnp_r))
440+
441+
a = numpy.empty((0,3,2), dtype=numpy.float32)
442+
ia = inp.array(a)
443+
444+
np_q, np_r = numpy.linalg.qr(a)
445+
dpnp_q, dpnp_r = inp.linalg.qr(ia)
446+
447+
assert dpnp_q.dtype == np_q.dtype
448+
assert dpnp_r.dtype == np_r.dtype
449+
assert dpnp_q.shape == np_q.shape
450+
assert dpnp_r.shape == np_r.shape
451+
452+
assert_allclose(ia,inp.matmul(dpnp_q, dpnp_r))
453+
454+
412455
@pytest.mark.parametrize("type", get_all_dtypes(no_bool=True, no_complex=True))
413456
@pytest.mark.parametrize(
414457
"shape",

0 commit comments

Comments
 (0)