Skip to content

Commit b1a1b50

Browse files
Fixed issues with .real and .imag of usm_ndarray
These were similar to those reported in #649
1 parent e2185e9 commit b1a1b50

File tree

1 file changed

+39
-8
lines changed

1 file changed

+39
-8
lines changed

dpctl/tensor/_usmarray.pyx

Lines changed: 39 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1014,31 +1014,62 @@ cdef usm_ndarray _real_view(usm_ndarray ary):
10141014
"""
10151015
View into real parts of a complex type array
10161016
"""
1017-
cdef usm_ndarray r = ary._clone()
1017+
cdef int r_typenum_ = -1
1018+
cdef usm_ndarray r = None
1019+
cdef Py_ssize_t offset_elems = 0
1020+
10181021
if (ary.typenum_ == UAR_CFLOAT):
1019-
r.typenum_ = UAR_FLOAT
1022+
r_typenum_ = UAR_FLOAT
10201023
elif (ary.typenum_ == UAR_CDOUBLE):
1021-
r.typenum_ = UAR_DOUBLE
1024+
r_typenum_ = UAR_DOUBLE
10221025
else:
10231026
raise InternalUSMArrayError(
10241027
"_real_view call on array of non-complex type.")
1028+
1029+
offset_elems = ary.get_offset() * 2
1030+
r = usm_ndarray.__new__(
1031+
usm_ndarray,
1032+
_make_int_tuple(ary.nd_, ary.shape_),
1033+
dtype=_make_typestr(r_typenum_),
1034+
strides=tuple(2 * si for si in ary.strides),
1035+
buffer=ary.base_,
1036+
offset=offset_elems,
1037+
order=('C' if (ary.flags_ & USM_ARRAY_C_CONTIGUOUS) else 'F')
1038+
)
1039+
r.flags_ = ary.flags_
1040+
r.array_namespace_ = ary.array_namespace_
10251041
return r
10261042

10271043

10281044
cdef usm_ndarray _imag_view(usm_ndarray ary):
10291045
"""
10301046
View into imaginary parts of a complex type array
10311047
"""
1032-
cdef usm_ndarray r = ary._clone()
1048+
cdef int r_typenum_ = -1
1049+
cdef usm_ndarray r = None
1050+
cdef Py_ssize_t offset_elems = 0
1051+
10331052
if (ary.typenum_ == UAR_CFLOAT):
1034-
r.typenum_ = UAR_FLOAT
1053+
r_typenum_ = UAR_FLOAT
10351054
elif (ary.typenum_ == UAR_CDOUBLE):
1036-
r.typenum_ = UAR_DOUBLE
1055+
r_typenum_ = UAR_DOUBLE
10371056
else:
10381057
raise InternalUSMArrayError(
1039-
"_real_view call on array of non-complex type.")
1058+
"_imag_view call on array of non-complex type.")
1059+
10401060
# displace pointer to imaginary part
1041-
r.data_ = r.data_ + type_bytesize(r.typenum_)
1061+
offset_elems = 2 * ary.get_offset() + 1
1062+
r = usm_ndarray.__new__(
1063+
usm_ndarray,
1064+
_make_int_tuple(ary.nd_, ary.shape_),
1065+
dtype=_make_typestr(r_typenum_),
1066+
strides=tuple(2 * si for si in ary.strides),
1067+
buffer=ary.base_,
1068+
offset=offset_elems,
1069+
order=('C' if (ary.flags_ & USM_ARRAY_C_CONTIGUOUS) else 'F')
1070+
)
1071+
r.flags_ = ary.flags_
1072+
r.array_namespace_ = ary.array_namespace_
10421073
return r
10431074

10441075

0 commit comments

Comments
 (0)