@@ -1014,31 +1014,62 @@ cdef usm_ndarray _real_view(usm_ndarray ary):
1014
1014
"""
1015
1015
View into real parts of a complex type array
1016
1016
"""
1017
- cdef usm_ndarray r = ary._clone()
1017
+ cdef int r_typenum_ = - 1
1018
+ cdef usm_ndarray r = None
1019
+ cdef Py_ssize_t offset_elems = 0
1020
+
1018
1021
if (ary.typenum_ == UAR_CFLOAT):
1019
- r.typenum_ = UAR_FLOAT
1022
+ r_typenum_ = UAR_FLOAT
1020
1023
elif (ary.typenum_ == UAR_CDOUBLE):
1021
- r.typenum_ = UAR_DOUBLE
1024
+ r_typenum_ = UAR_DOUBLE
1022
1025
else :
1023
1026
raise InternalUSMArrayError(
1024
1027
" _real_view call on array of non-complex type." )
1028
+
1029
+ offset_elems = ary.get_offset() * 2
1030
+ r = usm_ndarray.__new__ (
1031
+ usm_ndarray,
1032
+ _make_int_tuple(ary.nd_, ary.shape_),
1033
+ dtype = _make_typestr(r_typenum_),
1034
+ strides = tuple (2 * si for si in ary.strides),
1035
+ buffer = ary.base_,
1036
+ offset = offset_elems,
1037
+ order = (' C' if (ary.flags_ & USM_ARRAY_C_CONTIGUOUS) else ' F' )
1038
+ )
1039
+ r.flags_ = ary.flags_
1040
+ r.array_namespace_ = ary.array_namespace_
1025
1041
return r
1026
1042
1027
1043
1028
1044
cdef usm_ndarray _imag_view(usm_ndarray ary):
1029
1045
"""
1030
1046
View into imaginary parts of a complex type array
1031
1047
"""
1032
- cdef usm_ndarray r = ary._clone()
1048
+ cdef int r_typenum_ = - 1
1049
+ cdef usm_ndarray r = None
1050
+ cdef Py_ssize_t offset_elems = 0
1051
+
1033
1052
if (ary.typenum_ == UAR_CFLOAT):
1034
- r.typenum_ = UAR_FLOAT
1053
+ r_typenum_ = UAR_FLOAT
1035
1054
elif (ary.typenum_ == UAR_CDOUBLE):
1036
- r.typenum_ = UAR_DOUBLE
1055
+ r_typenum_ = UAR_DOUBLE
1037
1056
else :
1038
1057
raise InternalUSMArrayError(
1039
- " _real_view call on array of non-complex type." )
1058
+ " _imag_view call on array of non-complex type." )
1059
+
1040
1060
# displace pointer to imaginary part
1041
- r.data_ = r.data_ + type_bytesize(r.typenum_)
1061
+ offset_elems = 2 * ary.get_offset() + 1
1062
+ r = usm_ndarray.__new__ (
1063
+ usm_ndarray,
1064
+ _make_int_tuple(ary.nd_, ary.shape_),
1065
+ dtype = _make_typestr(r_typenum_),
1066
+ strides = tuple (2 * si for si in ary.strides),
1067
+ buffer = ary.base_,
1068
+ offset = offset_elems,
1069
+ order = (' C' if (ary.flags_ & USM_ARRAY_C_CONTIGUOUS) else ' F' )
1070
+ )
1071
+ r.flags_ = ary.flags_
1072
+ r.array_namespace_ = ary.array_namespace_
1042
1073
return r
1043
1074
1044
1075
0 commit comments