Skip to content

Commit ef5f615

Browse files
committed
Added special case for F-contiguous input (order="f") in tril and triu
1 parent 82b4186 commit ef5f615

File tree

1 file changed

+56
-6
lines changed

1 file changed

+56
-6
lines changed

dpctl/tensor/_ctors.py

Lines changed: 56 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1127,9 +1127,34 @@ def tril(X, k=0):
11271127
if type(X) is not dpt.usm_ndarray:
11281128
raise TypeError
11291129

1130-
res = dpt.empty(X.shape, dtype=X.dtype, sycl_queue=X.sycl_queue)
1131-
hev, _ = ti._tril(sycl_queue=X.sycl_queue, src=X, dst=res, k=k)
1132-
hev.wait()
1130+
k = operator.index(k)
1131+
1132+
# F_CONTIGUOUS = 2
1133+
order = "f" if (X.flags & 2) else "c"
1134+
1135+
shape = X.shape
1136+
nd = X.ndim
1137+
if nd < 2:
1138+
raise ValueError("Array dimensions less than 2.")
1139+
1140+
if k >= shape[nd - 1] - 1:
1141+
res = dpt.empty(
1142+
X.shape, dtype=X.dtype, order=order, sycl_queue=X.sycl_queue
1143+
)
1144+
hev, _ = ti._copy_usm_ndarray_into_usm_ndarray(
1145+
src=X, dst=res, sycl_queue=X.sycl_queue
1146+
)
1147+
hev.wait()
1148+
elif k < -shape[nd - 2]:
1149+
res = dpt.zeros(
1150+
X.shape, dtype=X.dtype, order=order, sycl_queue=X.sycl_queue
1151+
)
1152+
else:
1153+
res = dpt.empty(
1154+
X.shape, dtype=X.dtype, order=order, sycl_queue=X.sycl_queue
1155+
)
1156+
hev, _ = ti._tril(src=X, dst=res, k=k, sycl_queue=X.sycl_queue)
1157+
hev.wait()
11331158

11341159
return res
11351160

@@ -1143,8 +1168,33 @@ def triu(X, k=0):
11431168
if type(X) is not dpt.usm_ndarray:
11441169
raise TypeError
11451170

1146-
res = dpt.empty(X.shape, dtype=X.dtype, sycl_queue=X.sycl_queue)
1147-
hev, _ = ti._triu(sycl_queue=X.sycl_queue, src=X, dst=res, k=k)
1148-
hev.wait()
1171+
k = operator.index(k)
1172+
1173+
# F_CONTIGUOUS = 2
1174+
order = "f" if (X.flags & 2) else "c"
1175+
1176+
shape = X.shape
1177+
nd = X.ndim
1178+
if nd < 2:
1179+
raise ValueError("Array dimensions less than 2.")
1180+
1181+
if k > shape[nd - 1]:
1182+
res = dpt.zeros(
1183+
X.shape, dtype=X.dtype, order=order, sycl_queue=X.sycl_queue
1184+
)
1185+
elif k <= -shape[nd - 2] + 1:
1186+
res = dpt.empty(
1187+
X.shape, dtype=X.dtype, order=order, sycl_queue=X.sycl_queue
1188+
)
1189+
hev, _ = ti._copy_usm_ndarray_into_usm_ndarray(
1190+
src=X, dst=res, sycl_queue=X.sycl_queue
1191+
)
1192+
hev.wait()
1193+
else:
1194+
res = dpt.empty(
1195+
X.shape, dtype=X.dtype, order=order, sycl_queue=X.sycl_queue
1196+
)
1197+
hev, _ = ti._triu(src=X, dst=res, k=k, sycl_queue=X.sycl_queue)
1198+
hev.wait()
11491199

11501200
return res

0 commit comments

Comments
 (0)