Skip to content

Commit a197a1d

Browse files
Merge pull request #910 from IntelPython/added_tril_and_triu
Added dpctl.tensor.tril and dpctl.tensor.triu feature
2 parents 4344c83 + b9c560a commit a197a1d

File tree

4 files changed

+579
-0
lines changed

4 files changed

+579
-0
lines changed

dpctl/tensor/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@
3333
linspace,
3434
ones,
3535
ones_like,
36+
tril,
37+
triu,
3638
zeros,
3739
zeros_like,
3840
)
@@ -83,4 +85,6 @@
8385
"to_numpy",
8486
"asnumpy",
8587
"from_dlpack",
88+
"tril",
89+
"triu",
8690
]

dpctl/tensor/_ctors.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1116,3 +1116,85 @@ def eye(
11161116
hev, _ = ti._eye(k, dst=res, sycl_queue=sycl_queue)
11171117
hev.wait()
11181118
return res
1119+
1120+
1121+
def tril(X, k=0):
    """
    tril(X: usm_ndarray, k: int) -> usm_ndarray

    Returns the lower triangular part of a matrix (or a stack of matrices) X.

    Args:
        X (usm_ndarray): input array with at least two dimensions. The
            last two axes are treated as the matrix rows and columns;
            any leading axes index a stack of matrices.
        k (int): diagonal offset, converted with ``operator.index``.
            Elements ``(i, j)`` with ``j <= i + k`` are kept, so ``k = 0``
            is the main diagonal, ``k > 0`` is above it and ``k < 0`` is
            below it. Default: ``0``.

    Returns:
        usm_ndarray: a newly allocated array with the same shape and dtype
        as ``X``, created on ``X.sycl_queue``. It is F-ordered when ``X``
        has the F-contiguous flag set and C-ordered otherwise.

    Raises:
        TypeError: if ``X`` is not a ``usm_ndarray``, or ``k`` cannot be
            interpreted as an integer index.
        ValueError: if ``X`` has fewer than two dimensions.
    """
    if not isinstance(X, dpt.usm_ndarray):
        raise TypeError(
            "Expected argument of type dpctl.tensor.usm_ndarray, "
            f"got {type(X)}."
        )

    k = operator.index(k)

    # F_CONTIGUOUS = 2
    order = "F" if (X.flags & 2) else "C"

    shape = X.shape
    nd = X.ndim
    if nd < 2:
        raise ValueError("Array dimensions less than 2.")

    if k >= shape[nd - 1] - 1:
        # The diagonal lies at or beyond the last column: every element
        # satisfies j <= i + k, so the result is a plain copy of X.
        res = dpt.empty(
            X.shape, dtype=X.dtype, order=order, sycl_queue=X.sycl_queue
        )
        hev, _ = ti._copy_usm_ndarray_into_usm_ndarray(
            src=X, dst=res, sycl_queue=X.sycl_queue
        )
        # Wait on the returned host event before handing res to the caller.
        hev.wait()
    elif k <= -shape[nd - 2]:
        # With k <= -nrows, i + k < 0 for every row i, so no element is
        # kept and the result is all zeros; no kernel launch is needed.
        res = dpt.zeros(
            X.shape, dtype=X.dtype, order=order, sycl_queue=X.sycl_queue
        )
    else:
        # General case: delegate to the tril kernel.
        res = dpt.empty(
            X.shape, dtype=X.dtype, order=order, sycl_queue=X.sycl_queue
        )
        hev, _ = ti._tril(src=X, dst=res, k=k, sycl_queue=X.sycl_queue)
        hev.wait()

    return res
1160+
1161+
1162+
def triu(X, k=0):
    """
    triu(X: usm_ndarray, k: int) -> usm_ndarray

    Returns the upper triangular part of a matrix (or a stack of matrices) X.

    Args:
        X (usm_ndarray): input array with at least two dimensions. The
            last two axes are treated as the matrix rows and columns;
            any leading axes index a stack of matrices.
        k (int): diagonal offset, converted with ``operator.index``.
            Elements ``(i, j)`` with ``j >= i + k`` are kept, so ``k = 0``
            is the main diagonal, ``k > 0`` is above it and ``k < 0`` is
            below it. Default: ``0``.

    Returns:
        usm_ndarray: a newly allocated array with the same shape and dtype
        as ``X``, created on ``X.sycl_queue``. It is F-ordered when ``X``
        has the F-contiguous flag set and C-ordered otherwise.

    Raises:
        TypeError: if ``X`` is not a ``usm_ndarray``, or ``k`` cannot be
            interpreted as an integer index.
        ValueError: if ``X`` has fewer than two dimensions.
    """
    if not isinstance(X, dpt.usm_ndarray):
        raise TypeError(
            "Expected argument of type dpctl.tensor.usm_ndarray, "
            f"got {type(X)}."
        )

    k = operator.index(k)

    # F_CONTIGUOUS = 2
    order = "F" if (X.flags & 2) else "C"

    shape = X.shape
    nd = X.ndim
    if nd < 2:
        raise ValueError("Array dimensions less than 2.")

    if k >= shape[nd - 1]:
        # With k >= ncols, i + k exceeds the last column index for every
        # row i, so no element is kept and the result is all zeros.
        res = dpt.zeros(
            X.shape, dtype=X.dtype, order=order, sycl_queue=X.sycl_queue
        )
    elif k <= -shape[nd - 2] + 1:
        # The diagonal lies at or below the last row: every element
        # satisfies j >= i + k, so the result is a plain copy of X.
        res = dpt.empty(
            X.shape, dtype=X.dtype, order=order, sycl_queue=X.sycl_queue
        )
        hev, _ = ti._copy_usm_ndarray_into_usm_ndarray(
            src=X, dst=res, sycl_queue=X.sycl_queue
        )
        # Wait on the returned host event before handing res to the caller.
        hev.wait()
    else:
        # General case: delegate to the triu kernel.
        res = dpt.empty(
            X.shape, dtype=X.dtype, order=order, sycl_queue=X.sycl_queue
        )
        hev, _ = ti._triu(src=X, dst=res, k=k, sycl_queue=X.sycl_queue)
        hev.wait()

    return res

0 commit comments

Comments
 (0)