@@ -1116,3 +1116,85 @@ def eye(
1116
1116
hev , _ = ti ._eye (k , dst = res , sycl_queue = sycl_queue )
1117
1117
hev .wait ()
1118
1118
return res
1119
+
1120
+
1121
+ def tril (X , k = 0 ):
1122
+ """
1123
+ tril(X: usm_ndarray, k: int) -> usm_ndarray
1124
+
1125
+ Returns the lower triangular part of a matrix (or a stack of matrices) X.
1126
+ """
1127
+ if type (X ) is not dpt .usm_ndarray :
1128
+ raise TypeError
1129
+
1130
+ k = operator .index (k )
1131
+
1132
+ # F_CONTIGUOUS = 2
1133
+ order = "F" if (X .flags & 2 ) else "C"
1134
+
1135
+ shape = X .shape
1136
+ nd = X .ndim
1137
+ if nd < 2 :
1138
+ raise ValueError ("Array dimensions less than 2." )
1139
+
1140
+ if k >= shape [nd - 1 ] - 1 :
1141
+ res = dpt .empty (
1142
+ X .shape , dtype = X .dtype , order = order , sycl_queue = X .sycl_queue
1143
+ )
1144
+ hev , _ = ti ._copy_usm_ndarray_into_usm_ndarray (
1145
+ src = X , dst = res , sycl_queue = X .sycl_queue
1146
+ )
1147
+ hev .wait ()
1148
+ elif k < - shape [nd - 2 ]:
1149
+ res = dpt .zeros (
1150
+ X .shape , dtype = X .dtype , order = order , sycl_queue = X .sycl_queue
1151
+ )
1152
+ else :
1153
+ res = dpt .empty (
1154
+ X .shape , dtype = X .dtype , order = order , sycl_queue = X .sycl_queue
1155
+ )
1156
+ hev , _ = ti ._tril (src = X , dst = res , k = k , sycl_queue = X .sycl_queue )
1157
+ hev .wait ()
1158
+
1159
+ return res
1160
+
1161
+
1162
+ def triu (X , k = 0 ):
1163
+ """
1164
+ triu(X: usm_ndarray, k: int) -> usm_ndarray
1165
+
1166
+ Returns the upper triangular part of a matrix (or a stack of matrices) X.
1167
+ """
1168
+ if type (X ) is not dpt .usm_ndarray :
1169
+ raise TypeError
1170
+
1171
+ k = operator .index (k )
1172
+
1173
+ # F_CONTIGUOUS = 2
1174
+ order = "F" if (X .flags & 2 ) else "C"
1175
+
1176
+ shape = X .shape
1177
+ nd = X .ndim
1178
+ if nd < 2 :
1179
+ raise ValueError ("Array dimensions less than 2." )
1180
+
1181
+ if k > shape [nd - 1 ]:
1182
+ res = dpt .zeros (
1183
+ X .shape , dtype = X .dtype , order = order , sycl_queue = X .sycl_queue
1184
+ )
1185
+ elif k <= - shape [nd - 2 ] + 1 :
1186
+ res = dpt .empty (
1187
+ X .shape , dtype = X .dtype , order = order , sycl_queue = X .sycl_queue
1188
+ )
1189
+ hev , _ = ti ._copy_usm_ndarray_into_usm_ndarray (
1190
+ src = X , dst = res , sycl_queue = X .sycl_queue
1191
+ )
1192
+ hev .wait ()
1193
+ else :
1194
+ res = dpt .empty (
1195
+ X .shape , dtype = X .dtype , order = order , sycl_queue = X .sycl_queue
1196
+ )
1197
+ hev , _ = ti ._triu (src = X , dst = res , k = k , sycl_queue = X .sycl_queue )
1198
+ hev .wait ()
1199
+
1200
+ return res
0 commit comments