Skip to content

Commit 1c67f18

Browse files
committed
Add tril and triu function
1 parent 4344c83 commit 1c67f18

File tree

3 files changed

+347
-0
lines changed

3 files changed

+347
-0
lines changed

dpctl/tensor/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@
3333
linspace,
3434
ones,
3535
ones_like,
36+
tril,
37+
triu,
3638
zeros,
3739
zeros_like,
3840
)
@@ -83,4 +85,6 @@
8385
"to_numpy",
8486
"asnumpy",
8587
"from_dlpack",
88+
"tril",
89+
"triu",
8690
]

dpctl/tensor/_ctors.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1116,3 +1116,35 @@ def eye(
11161116
hev, _ = ti._eye(k, dst=res, sycl_queue=sycl_queue)
11171117
hev.wait()
11181118
return res
1119+
1120+
1121+
def tril(X, k=0):
1122+
"""
1123+
tril(X: usm_ndarray, k: int) -> usm_ndarray
1124+
1125+
Returns the lower triangular part of a matrix (or a stack of matrices) X.
1126+
"""
1127+
if type(X) is not dpt.usm_ndarray:
1128+
raise TypeError
1129+
1130+
res = dpt.empty(X.shape, dtype=X.dtype, sycl_queue=X.sycl_queue)
1131+
hev, _ = ti._tril(sycl_queue=X.sycl_queue, src=X, dst=res, k=k)
1132+
hev.wait()
1133+
1134+
return res
1135+
1136+
1137+
def triu(X, k=0):
1138+
"""
1139+
triu(X: usm_ndarray, k: int) -> usm_ndarray
1140+
1141+
Returns the upper triangular part of a matrix (or a stack of matrices) X.
1142+
"""
1143+
if type(X) is not dpt.usm_ndarray:
1144+
raise TypeError
1145+
1146+
res = dpt.empty(X.shape, dtype=X.dtype, sycl_queue=X.sycl_queue)
1147+
hev, _ = ti._triu(sycl_queue=X.sycl_queue, src=X, dst=res, k=k)
1148+
hev.wait()
1149+
1150+
return res

dpctl/tensor/libtensor/source/tensor_py.cpp

Lines changed: 311 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1879,6 +1879,290 @@ eye(py::ssize_t k,
18791879
eye_event);
18801880
}
18811881

1882+
/* =========================== Tril and triu ============================== */
1883+
// define function type
1884+
typedef sycl::event (*tri_fn_ptr_t)(sycl::queue,
1885+
py::ssize_t, // inner_range //py::ssize_t
1886+
py::ssize_t, // outer_range
1887+
char *, // src_data_ptr
1888+
char *, // dst_data_ptr
1889+
py::ssize_t, // nd
1890+
py::ssize_t *, // shape_and_strides
1891+
int, // k
1892+
const std::vector<sycl::event> &,
1893+
const std::vector<sycl::event> &);
1894+
1895+
template <typename Ty, bool> class tri_kernel;
1896+
template <typename Ty, bool l>
1897+
sycl::event tri_impl(sycl::queue exec_q,
1898+
py::ssize_t inner_range,
1899+
py::ssize_t outer_range,
1900+
char *src_p,
1901+
char *dst_p,
1902+
py::ssize_t nd,
1903+
py::ssize_t *shape_and_strides,
1904+
int k,
1905+
const std::vector<sycl::event> &depends,
1906+
const std::vector<sycl::event> &additional_depends)
1907+
{
1908+
constexpr int d2 = 2;
1909+
py::ssize_t src_s = nd;
1910+
py::ssize_t dst_s = 2 * nd;
1911+
py::ssize_t nd_1 = nd - 1;
1912+
py::ssize_t nd_2 = nd - 2;
1913+
Ty *src = reinterpret_cast<Ty *>(src_p);
1914+
Ty *dst = reinterpret_cast<Ty *>(dst_p);
1915+
1916+
sycl::event tri_ev = exec_q.submit([&](sycl::handler &cgh) {
1917+
cgh.depends_on(depends);
1918+
cgh.depends_on(additional_depends);
1919+
cgh.parallel_for<tri_kernel<Ty, l>>(
1920+
sycl::range<2>(inner_range, outer_range), [=](sycl::item<2> idx) {
1921+
py::ssize_t src_inner_offset, dst_inner_offset;
1922+
bool to_copy;
1923+
1924+
{
1925+
py::ssize_t inner_gid = idx.get_id(0);
1926+
CIndexer_array<d2, py::ssize_t> indexer_i(
1927+
{shape_and_strides[nd_2], shape_and_strides[nd_1]});
1928+
indexer_i.set(inner_gid);
1929+
const std::array<py::ssize_t, d2> &inner = indexer_i.get();
1930+
src_inner_offset =
1931+
inner[0] * shape_and_strides[src_s + nd_2] +
1932+
inner[1] * shape_and_strides[src_s + nd_1];
1933+
dst_inner_offset =
1934+
inner[0] * shape_and_strides[dst_s + nd_2] +
1935+
inner[1] * shape_and_strides[dst_s + nd_1];
1936+
1937+
if (l)
1938+
to_copy = (inner[0] + k >= inner[1]);
1939+
else
1940+
to_copy = (inner[0] + k <= inner[1]);
1941+
}
1942+
1943+
py::ssize_t src_offset = 0;
1944+
py::ssize_t dst_offset = 0;
1945+
{
1946+
py::ssize_t outer_gid = idx.get_id(1);
1947+
CIndexer_vector<py::ssize_t> outer(nd - d2);
1948+
outer.get_displacement(
1949+
outer_gid, shape_and_strides, shape_and_strides + src_s,
1950+
shape_and_strides + dst_s, src_offset, dst_offset);
1951+
}
1952+
1953+
src_offset += src_inner_offset;
1954+
dst_offset += dst_inner_offset;
1955+
1956+
dst[dst_offset] = (to_copy) ? src[src_offset] : Ty(0);
1957+
});
1958+
});
1959+
return tri_ev;
1960+
}
1961+
1962+
static tri_fn_ptr_t tril_generic_dispatch_vector[_ns::num_types];
1963+
1964+
template <typename fnT, typename Ty> struct TrilGenericFactory
1965+
{
1966+
fnT get()
1967+
{
1968+
fnT f = tri_impl<Ty, /*tril*/ true>;
1969+
return f;
1970+
}
1971+
};
1972+
1973+
static tri_fn_ptr_t triu_generic_dispatch_vector[_ns::num_types];
1974+
1975+
template <typename fnT, typename Ty> struct TriuGenericFactory
1976+
{
1977+
fnT get()
1978+
{
1979+
fnT f = tri_impl<Ty, /*triu*/ false>;
1980+
return f;
1981+
}
1982+
};
1983+
1984+
std::pair<sycl::event, sycl::event>
1985+
tri(sycl::queue &exec_q,
1986+
dpctl::tensor::usm_ndarray src,
1987+
dpctl::tensor::usm_ndarray dst,
1988+
char part,
1989+
int k = 0,
1990+
const std::vector<sycl::event> &depends = {})
1991+
{
1992+
// array dimensions must be the same
1993+
int src_nd = src.get_ndim();
1994+
int dst_nd = dst.get_ndim();
1995+
if (src_nd != dst_nd) {
1996+
throw py::value_error("Array dimensions are not the same.");
1997+
}
1998+
1999+
if (src_nd < 2) {
2000+
throw py::value_error("Array dimensions less than 2.");
2001+
}
2002+
2003+
// shapes must be the same
2004+
const py::ssize_t *src_shape = src.get_shape_raw();
2005+
const py::ssize_t *dst_shape = dst.get_shape_raw();
2006+
2007+
bool shapes_equal(true);
2008+
size_t src_nelems(1);
2009+
2010+
for (int i = 0; i < src_nd; ++i) {
2011+
src_nelems *= static_cast<size_t>(src_shape[i]);
2012+
shapes_equal = shapes_equal && (src_shape[i] == dst_shape[i]);
2013+
}
2014+
if (!shapes_equal) {
2015+
throw py::value_error("Array shapes are not the same.");
2016+
}
2017+
2018+
if (src_nelems == 0) {
2019+
// nothing to do
2020+
return std::make_pair(sycl::event(), sycl::event());
2021+
}
2022+
2023+
int src_typenum = src.get_typenum();
2024+
int dst_typenum = dst.get_typenum();
2025+
int src_typeid = array_types.typenum_to_lookup_id(src_typenum);
2026+
int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum);
2027+
if (dst_typeid != src_typeid) {
2028+
throw py::value_error("Array dtype are not the same.");
2029+
}
2030+
2031+
// check same contexts
2032+
sycl::queue src_q = src.get_queue();
2033+
sycl::queue dst_q = dst.get_queue();
2034+
2035+
sycl::context exec_ctx = exec_q.get_context();
2036+
sycl::device exec_d = exec_q.get_device();
2037+
if (src_q.get_context() != exec_ctx || dst_q.get_context() != exec_ctx ||
2038+
src_q.get_device() != exec_d || dst_q.get_device() != exec_d)
2039+
{
2040+
throw py::value_error(
2041+
"Execution queue context is not the same as allocation contexts");
2042+
}
2043+
2044+
using shT = std::vector<py::ssize_t>;
2045+
int src_flags = src.get_flags();
2046+
const py::ssize_t *src_strides_raw = src.get_strides_raw();
2047+
shT src_strides(src_nd);
2048+
bool is_src_c_contig = ((src_flags & USM_ARRAY_C_CONTIGUOUS) != 0);
2049+
bool is_src_f_contig = ((src_flags & USM_ARRAY_F_CONTIGUOUS) != 0);
2050+
if (src_strides_raw == nullptr) {
2051+
if (is_src_c_contig) {
2052+
src_strides = c_contiguous_strides(src_nd, src_shape);
2053+
}
2054+
else if (is_src_f_contig) {
2055+
src_strides = f_contiguous_strides(src_nd, src_shape);
2056+
}
2057+
else {
2058+
throw std::runtime_error("Source array has null strides but has "
2059+
"neither C- nor F- contiguous flag set");
2060+
}
2061+
}
2062+
else {
2063+
for (ssize_t i = 0; i < src_nd; i++) {
2064+
src_strides[i] = src_strides_raw[i];
2065+
}
2066+
}
2067+
2068+
int dst_flags = dst.get_flags();
2069+
const py::ssize_t *dst_strides_raw = dst.get_strides_raw();
2070+
shT dst_strides(src_nd);
2071+
bool is_dst_c_contig = ((dst_flags & USM_ARRAY_C_CONTIGUOUS) != 0);
2072+
bool is_dst_f_contig = ((dst_flags & USM_ARRAY_F_CONTIGUOUS) != 0);
2073+
if (dst_strides_raw == nullptr) {
2074+
if (is_dst_c_contig) {
2075+
dst_strides = c_contiguous_strides(src_nd, src_shape);
2076+
}
2077+
else if (is_dst_f_contig) {
2078+
dst_strides = f_contiguous_strides(src_nd, src_shape);
2079+
}
2080+
else {
2081+
throw std::runtime_error("Source array has null strides but has "
2082+
"neither C- nor F- contiguous flag set");
2083+
}
2084+
}
2085+
else {
2086+
for (ssize_t i = 0; i < src_nd; i++) {
2087+
dst_strides[i] = dst_strides_raw[i];
2088+
}
2089+
}
2090+
2091+
shT simplified_shape;
2092+
shT simplified_src_strides;
2093+
shT simplified_dst_strides;
2094+
py::ssize_t src_offset(0);
2095+
py::ssize_t dst_offset(0);
2096+
2097+
constexpr py::ssize_t src_itemsize = 1; // item size in elements
2098+
constexpr py::ssize_t dst_itemsize = 1; // item size in elements
2099+
2100+
int nd = src_nd - 2;
2101+
const py::ssize_t *shape = src_shape;
2102+
const py::ssize_t *p_src_strides = &src_strides[0];
2103+
const py::ssize_t *p_dst_strides = &dst_strides[0];
2104+
simplify_iteration_space(nd, shape, p_src_strides, src_itemsize,
2105+
is_src_c_contig, is_src_f_contig, p_dst_strides,
2106+
dst_itemsize, is_dst_c_contig, is_dst_f_contig,
2107+
simplified_shape, simplified_src_strides,
2108+
simplified_dst_strides, src_offset, dst_offset);
2109+
2110+
nd += 2;
2111+
std::vector<py::ssize_t> shape_and_strides(3 * nd);
2112+
2113+
std::copy(simplified_shape.begin(), simplified_shape.end(),
2114+
shape_and_strides.begin());
2115+
shape_and_strides[nd - 2] = src_shape[src_nd - 2];
2116+
shape_and_strides[nd - 1] = src_shape[src_nd - 1];
2117+
std::copy(simplified_src_strides.begin(), simplified_src_strides.end(),
2118+
shape_and_strides.begin() + nd);
2119+
shape_and_strides[2 * nd - 2] = src_strides[src_nd - 2];
2120+
shape_and_strides[2 * nd - 1] = src_strides[src_nd - 1];
2121+
std::copy(simplified_dst_strides.begin(), simplified_dst_strides.end(),
2122+
shape_and_strides.begin() + 2 * nd);
2123+
shape_and_strides[3 * nd - 2] = dst_strides[src_nd - 2];
2124+
shape_and_strides[3 * nd - 1] = dst_strides[src_nd - 1];
2125+
2126+
std::shared_ptr<shT> shp_shape_and_strides =
2127+
std::make_shared<shT>(shape_and_strides);
2128+
2129+
py::ssize_t *dev_shape_and_strides =
2130+
sycl::malloc_device<ssize_t>(3 * nd, exec_q);
2131+
if (dev_shape_and_strides == nullptr) {
2132+
throw std::runtime_error("Unabled to allocate device memory");
2133+
}
2134+
sycl::event copy_shape_and_strides = exec_q.copy<ssize_t>(
2135+
shp_shape_and_strides->data(), dev_shape_and_strides, 3 * nd);
2136+
2137+
py::ssize_t inner_range =
2138+
shape_and_strides[nd - 1] * shape_and_strides[nd - 2];
2139+
py::ssize_t outer_range = src_nelems / inner_range;
2140+
2141+
sycl::event tri_ev;
2142+
if (part == 'l') {
2143+
auto fn = tril_generic_dispatch_vector[src_typeid];
2144+
tri_ev =
2145+
fn(exec_q, inner_range, outer_range, src.get_data(), dst.get_data(),
2146+
nd, dev_shape_and_strides, k, depends, {copy_shape_and_strides});
2147+
}
2148+
else {
2149+
auto fn = triu_generic_dispatch_vector[src_typeid];
2150+
tri_ev =
2151+
fn(exec_q, inner_range, outer_range, src.get_data(), dst.get_data(),
2152+
nd, dev_shape_and_strides, k, depends, {copy_shape_and_strides});
2153+
}
2154+
2155+
exec_q.submit([&](sycl::handler &cgh) {
2156+
cgh.depends_on({tri_ev});
2157+
auto ctx = exec_q.get_context();
2158+
cgh.host_task([shp_shape_and_strides, dev_shape_and_strides, ctx]() {
2159+
sycl::free(dev_shape_and_strides, ctx);
2160+
});
2161+
});
2162+
return std::make_pair(keep_args_alive(exec_q, {src, dst}, {tri_ev}),
2163+
tri_ev);
2164+
}
2165+
18822166
// populate dispatch tables
18832167
void init_copy_and_cast_dispatch_tables(void)
18842168
{
@@ -1936,6 +2220,12 @@ void init_copy_for_reshape_dispatch_vector(void)
19362220
DispatchVectorBuilder<eye_fn_ptr_t, EyeFactory, num_types> dvb4;
19372221
dvb4.populate_dispatch_vector(eye_dispatch_vector);
19382222

2223+
DispatchVectorBuilder<tri_fn_ptr_t, TrilGenericFactory, num_types> dvb5;
2224+
dvb5.populate_dispatch_vector(tril_generic_dispatch_vector);
2225+
2226+
DispatchVectorBuilder<tri_fn_ptr_t, TriuGenericFactory, num_types> dvb6;
2227+
dvb6.populate_dispatch_vector(triu_generic_dispatch_vector);
2228+
19392229
return;
19402230
}
19412231

@@ -2081,4 +2371,25 @@ PYBIND11_MODULE(_tensor_impl, m)
20812371
[](sycl::device dev) -> std::string {
20822372
return get_default_device_complex_type(dev);
20832373
});
2374+
m.def(
2375+
"_tril",
2376+
[](sycl::queue exec_q, dpctl::tensor::usm_ndarray src,
2377+
dpctl::tensor::usm_ndarray dst, int k,
2378+
const std::vector<sycl::event> depends)
2379+
-> std::pair<sycl::event, sycl::event> {
2380+
return tri(exec_q, src, dst, 'l', k, depends);
2381+
},
2382+
"Tril helper function.", py::arg("sycl_queue"), py::arg("src"),
2383+
py::arg("dst"), py::arg("k") = 0, py::arg("depends") = py::list());
2384+
2385+
m.def(
2386+
"_triu",
2387+
[](sycl::queue exec_q, dpctl::tensor::usm_ndarray src,
2388+
dpctl::tensor::usm_ndarray dst, int k,
2389+
const std::vector<sycl::event> depends)
2390+
-> std::pair<sycl::event, sycl::event> {
2391+
return tri(exec_q, src, dst, 'u', k, depends);
2392+
},
2393+
"Triu helper function.", py::arg("sycl_queue"), py::arg("src"),
2394+
py::arg("dst"), py::arg("k") = 0, py::arg("depends") = py::list());
20842395
}

0 commit comments

Comments
 (0)