@@ -1127,9 +1127,34 @@ def tril(X, k=0):
1127
1127
if type (X ) is not dpt .usm_ndarray :
1128
1128
raise TypeError
1129
1129
1130
- res = dpt .empty (X .shape , dtype = X .dtype , sycl_queue = X .sycl_queue )
1131
- hev , _ = ti ._tril (sycl_queue = X .sycl_queue , src = X , dst = res , k = k )
1132
- hev .wait ()
1130
+ k = operator .index (k )
1131
+
1132
+ # F_CONTIGUOUS = 2
1133
+ order = "f" if (X .flags & 2 ) else "c"
1134
+
1135
+ shape = X .shape
1136
+ nd = X .ndim
1137
+ if nd < 2 :
1138
+ raise ValueError ("Array dimensions less than 2." )
1139
+
1140
+ if k >= shape [nd - 1 ] - 1 :
1141
+ res = dpt .empty (
1142
+ X .shape , dtype = X .dtype , order = order , sycl_queue = X .sycl_queue
1143
+ )
1144
+ hev , _ = ti ._copy_usm_ndarray_into_usm_ndarray (
1145
+ src = X , dst = res , sycl_queue = X .sycl_queue
1146
+ )
1147
+ hev .wait ()
1148
+ elif k < - shape [nd - 2 ]:
1149
+ res = dpt .zeros (
1150
+ X .shape , dtype = X .dtype , order = order , sycl_queue = X .sycl_queue
1151
+ )
1152
+ else :
1153
+ res = dpt .empty (
1154
+ X .shape , dtype = X .dtype , order = order , sycl_queue = X .sycl_queue
1155
+ )
1156
+ hev , _ = ti ._tril (src = X , dst = res , k = k , sycl_queue = X .sycl_queue )
1157
+ hev .wait ()
1133
1158
1134
1159
return res
1135
1160
@@ -1143,8 +1168,33 @@ def triu(X, k=0):
1143
1168
if type (X ) is not dpt .usm_ndarray :
1144
1169
raise TypeError
1145
1170
1146
- res = dpt .empty (X .shape , dtype = X .dtype , sycl_queue = X .sycl_queue )
1147
- hev , _ = ti ._triu (sycl_queue = X .sycl_queue , src = X , dst = res , k = k )
1148
- hev .wait ()
1171
+ k = operator .index (k )
1172
+
1173
+ # F_CONTIGUOUS = 2
1174
+ order = "f" if (X .flags & 2 ) else "c"
1175
+
1176
+ shape = X .shape
1177
+ nd = X .ndim
1178
+ if nd < 2 :
1179
+ raise ValueError ("Array dimensions less than 2." )
1180
+
1181
+ if k > shape [nd - 1 ]:
1182
+ res = dpt .zeros (
1183
+ X .shape , dtype = X .dtype , order = order , sycl_queue = X .sycl_queue
1184
+ )
1185
+ elif k <= - shape [nd - 2 ] + 1 :
1186
+ res = dpt .empty (
1187
+ X .shape , dtype = X .dtype , order = order , sycl_queue = X .sycl_queue
1188
+ )
1189
+ hev , _ = ti ._copy_usm_ndarray_into_usm_ndarray (
1190
+ src = X , dst = res , sycl_queue = X .sycl_queue
1191
+ )
1192
+ hev .wait ()
1193
+ else :
1194
+ res = dpt .empty (
1195
+ X .shape , dtype = X .dtype , order = order , sycl_queue = X .sycl_queue
1196
+ )
1197
+ hev , _ = ti ._triu (src = X , dst = res , k = k , sycl_queue = X .sycl_queue )
1198
+ hev .wait ()
1149
1199
1150
1200
return res
0 commit comments