Skip to content

Implemented dpctl.tensor.meshgrid and tests #920

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Oct 10, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions dpctl/tensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
full,
full_like,
linspace,
meshgrid,
ones,
ones_like,
tril,
Expand Down Expand Up @@ -87,4 +88,5 @@
"from_dlpack",
"tril",
"triu",
"meshgrid",
]
58 changes: 58 additions & 0 deletions dpctl/tensor/_ctors.py
Original file line number Diff line number Diff line change
Expand Up @@ -1198,3 +1198,61 @@ def triu(X, k=0):
hev.wait()

return res


def meshgrid(*arrays, indexing="xy"):

"""
meshgrid(*arrays, indexing="xy") -> list[usm_ndarray]

Creates list of `usm_ndarray` coordinate matrices from vectors.

Args:
arrays: arbitrary number of one-dimensional `USM_ndarray` objects.
If vectors are not of the same data type,
or are not one-dimensional, raises `ValueError.`
indexing: Cartesian (`xy`) or matrix (`ij`) indexing of output.
For a set of `n` vectors with lengths N0, N1, N2, ...
Cartesian indexing results in arrays of shape
(N1, N0, N2, ...)
matrix indexing results in arrays of shape
(n0, N1, N2, ...)
Default: `xy`.
"""
ref_dt = None
ref_unset = True
for array in arrays:
if not isinstance(array, dpt.usm_ndarray):
raise TypeError(
f"Expected instance of dpt.usm_ndarray, got {type(array)}."
)
if array.ndim != 1:
raise ValueError("All arrays must be one-dimensional.")
if ref_unset:
ref_unset = False
ref_dt = array.dtype
else:
if not ref_dt == array.dtype:
raise ValueError(
"All arrays must be of the same numeric data type."
)
if indexing not in ["xy", "ij"]:
raise ValueError(
"Unrecognized indexing keyword value, expecting 'xy' or 'ij.'"
)
n = len(arrays)
sh = (-1,) + (1,) * (n - 1)

res = []
if n > 1 and indexing == "xy":
res.append(dpt.reshape(arrays[0], (1, -1) + sh[2:], copy=True))
res.append(dpt.reshape(arrays[1], sh, copy=True))
arrays, sh = arrays[2:], sh[-2:] + sh[:-2]

for array in arrays:
res.append(dpt.reshape(array, sh, copy=True))
sh = sh[-1:] + sh[:-1]

output = dpt.broadcast_arrays(*res)

return output
53 changes: 53 additions & 0 deletions dpctl/tests/test_usm_ndarray_ctor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1427,6 +1427,57 @@ def test_tril_order_k(order, k):
assert np.array_equal(Ynp, dpt.asnumpy(Y))


def test_meshgrid():
try:
q = dpctl.SyclQueue()
except dpctl.SyclQueueCreationError:
pytest.skip("Queue could not be created")
X = dpt.arange(5, sycl_queue=q)
Y = dpt.arange(3, sycl_queue=q)
Z = dpt.meshgrid(X, Y)
Znp = np.meshgrid(dpt.asnumpy(X), dpt.asnumpy(Y))
n = len(Z)
assert n == len(Znp)
for i in range(n):
assert np.array_equal(dpt.asnumpy(Z[i]), Znp[i])
# dimension > 1 must raise ValueError
with pytest.raises(ValueError):
dpt.meshgrid(dpt.usm_ndarray((4, 4)))
# unknown indexing kwarg must raise ValueError
with pytest.raises(ValueError):
dpt.meshgrid(X, indexing="ji")
# input arrays with different data types must raise ValueError
with pytest.raises(ValueError):
dpt.meshgrid(X, dpt.asarray(Y, dtype="b1"))


def test_meshgrid2():
try:
q1 = dpctl.SyclQueue()
q2 = dpctl.SyclQueue()
q3 = dpctl.SyclQueue()
except dpctl.SyclQueueCreationError:
pytest.skip("Queue could not be created")
x1 = dpt.arange(0, 2, dtype="int16", sycl_queue=q1)
x2 = dpt.arange(3, 6, dtype="int16", sycl_queue=q2)
x3 = dpt.arange(6, 10, dtype="int16", sycl_queue=q3)
y1, y2, y3 = dpt.meshgrid(x1, x2, x3, indexing="xy")
z1, z2, z3 = dpt.meshgrid(x1, x2, x3, indexing="ij")
assert all(
x.sycl_queue == y.sycl_queue for x, y in zip((x1, x2, x3), (y1, y2, y3))
)
assert all(
x.sycl_queue == z.sycl_queue for x, z in zip((x1, x2, x3), (z1, z2, z3))
)
assert y1.shape == y2.shape and y2.shape == y3.shape
assert z1.shape == z2.shape and z2.shape == z3.shape
assert y1.shape == (len(x2), len(x1), len(x3))
assert z1.shape == (len(x1), len(x2), len(x3))
# FIXME: uncomment out once gh-921 is merged
# assert all(z.flags["C"] for z in (z1, z2, z3))
# assert all(y.flags["C"] for y in (y1, y2, y3))


def test_common_arg_validation():
order = "I"
# invalid order must raise ValueError
Expand Down Expand Up @@ -1463,3 +1514,5 @@ def test_common_arg_validation():
dpt.tril(X)
with pytest.raises(TypeError):
dpt.triu(X)
with pytest.raises(TypeError):
dpt.meshgrid(X)