Skip to content

usm_ndarray.real and usm_ndarray.imag now set flags correctly #1355

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions dpctl/tensor/_usmarray.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -662,7 +662,7 @@ cdef class usm_ndarray:
"""
if self.nd_ < 2:
raise ValueError(
"array.mT requires array to have at least 2-dimensons."
"array.mT requires array to have at least 2 dimensions."
)
return _m_transpose(self)

Expand Down Expand Up @@ -1216,14 +1216,14 @@ cdef usm_ndarray _real_view(usm_ndarray ary):
offset_elems = ary.get_offset() * 2
r = usm_ndarray.__new__(
usm_ndarray,
_make_int_tuple(ary.nd_, ary.shape_),
_make_int_tuple(ary.nd_, ary.shape_) if ary.nd_ > 0 else tuple(),
dtype=_make_typestr(r_typenum_),
strides=tuple(2 * si for si in ary.strides),
buffer=ary.base_,
offset=offset_elems,
order=('C' if (ary.flags_ & USM_ARRAY_C_CONTIGUOUS) else 'F')
)
r.flags_ = ary.flags_
r.flags_ |= (ary.flags_ & USM_ARRAY_WRITABLE)
r.array_namespace_ = ary.array_namespace_
return r

Expand All @@ -1248,14 +1248,14 @@ cdef usm_ndarray _imag_view(usm_ndarray ary):
offset_elems = 2 * ary.get_offset() + 1
r = usm_ndarray.__new__(
usm_ndarray,
_make_int_tuple(ary.nd_, ary.shape_),
_make_int_tuple(ary.nd_, ary.shape_) if ary.nd_ > 0 else tuple(),
dtype=_make_typestr(r_typenum_),
strides=tuple(2 * si for si in ary.strides),
buffer=ary.base_,
offset=offset_elems,
order=('C' if (ary.flags_ & USM_ARRAY_C_CONTIGUOUS) else 'F')
)
r.flags_ = ary.flags_
r.flags_ |= (ary.flags_ & USM_ARRAY_WRITABLE)
r.array_namespace_ = ary.array_namespace_
return r

Expand Down
11 changes: 10 additions & 1 deletion dpctl/tests/test_usm_ndarray_ctor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1438,17 +1438,26 @@ def test_real_imag_views():
n, m = 2, 3
try:
X = dpt.usm_ndarray((n, m), "c8")
X_scalar = dpt.usm_ndarray((), dtype="c8")
except dpctl.SyclDeviceCreationError:
pytest.skip("No SYCL devices available")
Xnp_r = np.arange(n * m, dtype="f4").reshape((n, m))
Xnp_i = np.arange(n * m, 2 * n * m, dtype="f4").reshape((n, m))
Xnp = Xnp_r + 1j * Xnp_i
X[:] = Xnp
assert np.array_equal(dpt.to_numpy(X.real), Xnp.real)
X_real = X.real
X_imag = X.imag
assert np.array_equal(dpt.to_numpy(X_real), Xnp.real)
assert np.array_equal(dpt.to_numpy(X.imag), Xnp.imag)
assert not X_real.flags["C"] and not X_real.flags["F"]
assert not X_imag.flags["C"] and not X_imag.flags["F"]
assert X_real.strides == X_imag.strides
assert np.array_equal(dpt.to_numpy(X[1:].real), Xnp[1:].real)
assert np.array_equal(dpt.to_numpy(X[1:].imag), Xnp[1:].imag)

X_scalar[...] = complex(n * m, 2 * n * m)
assert X_scalar.real and X_scalar.imag


@pytest.mark.parametrize(
"dtype",
Expand Down