Skip to content

Commit 9b04a9e

Browse files
Fix/gh 649 transpose (#653)
* Fixes #649 ``` In [1]: import dpctl.tensor as dpt, itertools In [2]: a = dpt.usm_ndarray((2,3)) In [3]: for i,j in itertools.product(range(2), range(3)): a[i, j] = i*3 + j In [4]: dpt.to_numpy(a)[1:].T Out[4]: array([[3.], [4.], [5.]]) In [5]: dpt.to_numpy(a[1:].T) Out[5]: array([[3.], [4.], [5.]]) ``` * Fixed issues with `.real` and `.imag` of usm_ndarray These were similar to those reported in #649 * Added tests for T, real, imag methods of the object inspired by #649
1 parent be5c0d3 commit 9b04a9e

File tree

2 files changed

+63
-9
lines changed

2 files changed

+63
-9
lines changed

dpctl/tensor/_usmarray.pyx

Lines changed: 41 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1014,31 +1014,62 @@ cdef usm_ndarray _real_view(usm_ndarray ary):
10141014
"""
10151015
View into real parts of a complex type array
10161016
"""
1017-
cdef usm_ndarray r = ary._clone()
1017+
cdef int r_typenum_ = -1
1018+
cdef usm_ndarray r = None
1019+
cdef Py_ssize_t offset_elems = 0
1020+
10181021
if (ary.typenum_ == UAR_CFLOAT):
1019-
r.typenum_ = UAR_FLOAT
1022+
r_typenum_ = UAR_FLOAT
10201023
elif (ary.typenum_ == UAR_CDOUBLE):
1021-
r.typenum_ = UAR_DOUBLE
1024+
r_typenum_ = UAR_DOUBLE
10221025
else:
10231026
raise InternalUSMArrayError(
10241027
"_real_view call on array of non-complex type.")
1028+
1029+
offset_elems = ary.get_offset() * 2
1030+
r = usm_ndarray.__new__(
1031+
usm_ndarray,
1032+
_make_int_tuple(ary.nd_, ary.shape_),
1033+
dtype=_make_typestr(r_typenum_),
1034+
strides=tuple(2 * si for si in ary.strides),
1035+
buffer=ary.base_,
1036+
offset=offset_elems,
1037+
order=('C' if (ary.flags_ & USM_ARRAY_C_CONTIGUOUS) else 'F')
1038+
)
1039+
r.flags_ = ary.flags_
1040+
r.array_namespace_ = ary.array_namespace_
10251041
return r
10261042

10271043

10281044
cdef usm_ndarray _imag_view(usm_ndarray ary):
10291045
"""
10301046
View into imaginary parts of a complex type array
10311047
"""
1032-
cdef usm_ndarray r = ary._clone()
1048+
cdef int r_typenum_ = -1
1049+
cdef usm_ndarray r = None
1050+
cdef Py_ssize_t offset_elems = 0
1051+
10331052
if (ary.typenum_ == UAR_CFLOAT):
1034-
r.typenum_ = UAR_FLOAT
1053+
r_typenum_ = UAR_FLOAT
10351054
elif (ary.typenum_ == UAR_CDOUBLE):
1036-
r.typenum_ = UAR_DOUBLE
1055+
r_typenum_ = UAR_DOUBLE
10371056
else:
10381057
raise InternalUSMArrayError(
1039-
"_real_view call on array of non-complex type.")
1058+
"_imag_view call on array of non-complex type.")
1059+
10401060
# displace pointer to imaginary part
1041-
r.data_ = r.data_ + type_bytesize(r.typenum_)
1061+
offset_elems = 2 * ary.get_offset() + 1
1062+
r = usm_ndarray.__new__(
1063+
usm_ndarray,
1064+
_make_int_tuple(ary.nd_, ary.shape_),
1065+
dtype=_make_typestr(r_typenum_),
1066+
strides=tuple(2 * si for si in ary.strides),
1067+
buffer=ary.base_,
1068+
offset=offset_elems,
1069+
order=('C' if (ary.flags_ & USM_ARRAY_C_CONTIGUOUS) else 'F')
1070+
)
1071+
r.flags_ = ary.flags_
1072+
r.array_namespace_ = ary.array_namespace_
10421073
return r
10431074

10441075

@@ -1054,7 +1085,8 @@ cdef usm_ndarray _transpose(usm_ndarray ary):
10541085
_make_reversed_int_tuple(ary.nd_, ary.strides_)
10551086
if (ary.strides_) else None),
10561087
buffer=ary.base_,
1057-
order=('F' if (ary.flags_ & USM_ARRAY_C_CONTIGUOUS) else 'C')
1088+
order=('F' if (ary.flags_ & USM_ARRAY_C_CONTIGUOUS) else 'C'),
1089+
offset=ary.get_offset()
10581090
)
10591091
r.flags_ |= (ary.flags_ & USM_ARRAY_WRITEABLE)
10601092
return r

dpctl/tests/test_usm_ndarray_ctor.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -841,3 +841,25 @@ def test_reshape():
841841
dpt.reshape(Z, Z.shape, order="invalid")
842842
W = dpt.reshape(Z, (-1,), order="C")
843843
assert W.shape == (Z.size,)
844+
845+
846+
def test_transpose():
847+
n, m = 2, 3
848+
X = dpt.usm_ndarray((n, m), "f4")
849+
Xnp = np.arange(n * m, dtype="f4").reshape((n, m))
850+
X[:] = Xnp
851+
assert np.array_equal(dpt.to_numpy(X.T), Xnp.T)
852+
assert np.array_equal(dpt.to_numpy(X[1:].T), Xnp[1:].T)
853+
854+
855+
def test_real_imag_views():
856+
n, m = 2, 3
857+
X = dpt.usm_ndarray((n, m), "c8")
858+
Xnp_r = np.arange(n * m, dtype="f4").reshape((n, m))
859+
Xnp_i = np.arange(n * m, 2 * n * m, dtype="f4").reshape((n, m))
860+
Xnp = Xnp_r + 1j * Xnp_i
861+
X[:] = Xnp
862+
assert np.array_equal(dpt.to_numpy(X.real), Xnp.real)
863+
assert np.array_equal(dpt.to_numpy(X.imag), Xnp.imag)
864+
assert np.array_equal(dpt.to_numpy(X[1:].real), Xnp[1:].real)
865+
assert np.array_equal(dpt.to_numpy(X[1:].imag), Xnp[1:].imag)

0 commit comments

Comments
 (0)