@@ -1879,6 +1879,290 @@ eye(py::ssize_t k,
1879
1879
eye_event);
1880
1880
}
1881
1881
1882
+ /* =========================== Tril and triu ============================== */
1883
+ // define function type
1884
+ typedef sycl::event (*tri_fn_ptr_t )(sycl::queue,
1885
+ py::ssize_t , // inner_range //py::ssize_t
1886
+ py::ssize_t , // outer_range
1887
+ char *, // src_data_ptr
1888
+ char *, // dst_data_ptr
1889
+ py::ssize_t , // nd
1890
+ py::ssize_t *, // shape_and_strides
1891
+ int , // k
1892
+ const std::vector<sycl::event> &,
1893
+ const std::vector<sycl::event> &);
1894
+
1895
+ template <typename Ty, bool > class tri_kernel ;
1896
+ template <typename Ty, bool l>
1897
+ sycl::event tri_impl (sycl::queue exec_q,
1898
+ py::ssize_t inner_range,
1899
+ py::ssize_t outer_range,
1900
+ char *src_p,
1901
+ char *dst_p,
1902
+ py::ssize_t nd,
1903
+ py::ssize_t *shape_and_strides,
1904
+ int k,
1905
+ const std::vector<sycl::event> &depends,
1906
+ const std::vector<sycl::event> &additional_depends)
1907
+ {
1908
+ constexpr int d2 = 2 ;
1909
+ py::ssize_t src_s = nd;
1910
+ py::ssize_t dst_s = 2 * nd;
1911
+ py::ssize_t nd_1 = nd - 1 ;
1912
+ py::ssize_t nd_2 = nd - 2 ;
1913
+ Ty *src = reinterpret_cast <Ty *>(src_p);
1914
+ Ty *dst = reinterpret_cast <Ty *>(dst_p);
1915
+
1916
+ sycl::event tri_ev = exec_q.submit ([&](sycl::handler &cgh) {
1917
+ cgh.depends_on (depends);
1918
+ cgh.depends_on (additional_depends);
1919
+ cgh.parallel_for <tri_kernel<Ty, l>>(
1920
+ sycl::range<2 >(inner_range, outer_range), [=](sycl::item<2 > idx) {
1921
+ py::ssize_t src_inner_offset, dst_inner_offset;
1922
+ bool to_copy;
1923
+
1924
+ {
1925
+ py::ssize_t inner_gid = idx.get_id (0 );
1926
+ CIndexer_array<d2, py::ssize_t > indexer_i (
1927
+ {shape_and_strides[nd_2], shape_and_strides[nd_1]});
1928
+ indexer_i.set (inner_gid);
1929
+ const std::array<py::ssize_t , d2> &inner = indexer_i.get ();
1930
+ src_inner_offset =
1931
+ inner[0 ] * shape_and_strides[src_s + nd_2] +
1932
+ inner[1 ] * shape_and_strides[src_s + nd_1];
1933
+ dst_inner_offset =
1934
+ inner[0 ] * shape_and_strides[dst_s + nd_2] +
1935
+ inner[1 ] * shape_and_strides[dst_s + nd_1];
1936
+
1937
+ if (l)
1938
+ to_copy = (inner[0 ] + k >= inner[1 ]);
1939
+ else
1940
+ to_copy = (inner[0 ] + k <= inner[1 ]);
1941
+ }
1942
+
1943
+ py::ssize_t src_offset = 0 ;
1944
+ py::ssize_t dst_offset = 0 ;
1945
+ {
1946
+ py::ssize_t outer_gid = idx.get_id (1 );
1947
+ CIndexer_vector<py::ssize_t > outer (nd - d2);
1948
+ outer.get_displacement (
1949
+ outer_gid, shape_and_strides, shape_and_strides + src_s,
1950
+ shape_and_strides + dst_s, src_offset, dst_offset);
1951
+ }
1952
+
1953
+ src_offset += src_inner_offset;
1954
+ dst_offset += dst_inner_offset;
1955
+
1956
+ dst[dst_offset] = (to_copy) ? src[src_offset] : Ty (0 );
1957
+ });
1958
+ });
1959
+ return tri_ev;
1960
+ }
1961
+
1962
+ static tri_fn_ptr_t tril_generic_dispatch_vector[_ns::num_types];
1963
+
1964
+ template <typename fnT, typename Ty> struct TrilGenericFactory
1965
+ {
1966
+ fnT get ()
1967
+ {
1968
+ fnT f = tri_impl<Ty, /* tril*/ true >;
1969
+ return f;
1970
+ }
1971
+ };
1972
+
1973
+ static tri_fn_ptr_t triu_generic_dispatch_vector[_ns::num_types];
1974
+
1975
+ template <typename fnT, typename Ty> struct TriuGenericFactory
1976
+ {
1977
+ fnT get ()
1978
+ {
1979
+ fnT f = tri_impl<Ty, /* triu*/ false >;
1980
+ return f;
1981
+ }
1982
+ };
1983
+
1984
+ std::pair<sycl::event, sycl::event>
1985
+ tri (sycl::queue &exec_q,
1986
+ dpctl::tensor::usm_ndarray src,
1987
+ dpctl::tensor::usm_ndarray dst,
1988
+ char part,
1989
+ int k = 0 ,
1990
+ const std::vector<sycl::event> &depends = {})
1991
+ {
1992
+ // array dimensions must be the same
1993
+ int src_nd = src.get_ndim ();
1994
+ int dst_nd = dst.get_ndim ();
1995
+ if (src_nd != dst_nd) {
1996
+ throw py::value_error (" Array dimensions are not the same." );
1997
+ }
1998
+
1999
+ if (src_nd < 2 ) {
2000
+ throw py::value_error (" Array dimensions less than 2." );
2001
+ }
2002
+
2003
+ // shapes must be the same
2004
+ const py::ssize_t *src_shape = src.get_shape_raw ();
2005
+ const py::ssize_t *dst_shape = dst.get_shape_raw ();
2006
+
2007
+ bool shapes_equal (true );
2008
+ size_t src_nelems (1 );
2009
+
2010
+ for (int i = 0 ; i < src_nd; ++i) {
2011
+ src_nelems *= static_cast <size_t >(src_shape[i]);
2012
+ shapes_equal = shapes_equal && (src_shape[i] == dst_shape[i]);
2013
+ }
2014
+ if (!shapes_equal) {
2015
+ throw py::value_error (" Array shapes are not the same." );
2016
+ }
2017
+
2018
+ if (src_nelems == 0 ) {
2019
+ // nothing to do
2020
+ return std::make_pair (sycl::event (), sycl::event ());
2021
+ }
2022
+
2023
+ int src_typenum = src.get_typenum ();
2024
+ int dst_typenum = dst.get_typenum ();
2025
+ int src_typeid = array_types.typenum_to_lookup_id (src_typenum);
2026
+ int dst_typeid = array_types.typenum_to_lookup_id (dst_typenum);
2027
+ if (dst_typeid != src_typeid) {
2028
+ throw py::value_error (" Array dtype are not the same." );
2029
+ }
2030
+
2031
+ // check same contexts
2032
+ sycl::queue src_q = src.get_queue ();
2033
+ sycl::queue dst_q = dst.get_queue ();
2034
+
2035
+ sycl::context exec_ctx = exec_q.get_context ();
2036
+ sycl::device exec_d = exec_q.get_device ();
2037
+ if (src_q.get_context () != exec_ctx || dst_q.get_context () != exec_ctx ||
2038
+ src_q.get_device () != exec_d || dst_q.get_device () != exec_d)
2039
+ {
2040
+ throw py::value_error (
2041
+ " Execution queue context is not the same as allocation contexts" );
2042
+ }
2043
+
2044
+ using shT = std::vector<py::ssize_t >;
2045
+ int src_flags = src.get_flags ();
2046
+ const py::ssize_t *src_strides_raw = src.get_strides_raw ();
2047
+ shT src_strides (src_nd);
2048
+ bool is_src_c_contig = ((src_flags & USM_ARRAY_C_CONTIGUOUS) != 0 );
2049
+ bool is_src_f_contig = ((src_flags & USM_ARRAY_F_CONTIGUOUS) != 0 );
2050
+ if (src_strides_raw == nullptr ) {
2051
+ if (is_src_c_contig) {
2052
+ src_strides = c_contiguous_strides (src_nd, src_shape);
2053
+ }
2054
+ else if (is_src_f_contig) {
2055
+ src_strides = f_contiguous_strides (src_nd, src_shape);
2056
+ }
2057
+ else {
2058
+ throw std::runtime_error (" Source array has null strides but has "
2059
+ " neither C- nor F- contiguous flag set" );
2060
+ }
2061
+ }
2062
+ else {
2063
+ for (ssize_t i = 0 ; i < src_nd; i++) {
2064
+ src_strides[i] = src_strides_raw[i];
2065
+ }
2066
+ }
2067
+
2068
+ int dst_flags = dst.get_flags ();
2069
+ const py::ssize_t *dst_strides_raw = dst.get_strides_raw ();
2070
+ shT dst_strides (src_nd);
2071
+ bool is_dst_c_contig = ((dst_flags & USM_ARRAY_C_CONTIGUOUS) != 0 );
2072
+ bool is_dst_f_contig = ((dst_flags & USM_ARRAY_F_CONTIGUOUS) != 0 );
2073
+ if (dst_strides_raw == nullptr ) {
2074
+ if (is_dst_c_contig) {
2075
+ dst_strides = c_contiguous_strides (src_nd, src_shape);
2076
+ }
2077
+ else if (is_dst_f_contig) {
2078
+ dst_strides = f_contiguous_strides (src_nd, src_shape);
2079
+ }
2080
+ else {
2081
+ throw std::runtime_error (" Source array has null strides but has "
2082
+ " neither C- nor F- contiguous flag set" );
2083
+ }
2084
+ }
2085
+ else {
2086
+ for (ssize_t i = 0 ; i < src_nd; i++) {
2087
+ dst_strides[i] = dst_strides_raw[i];
2088
+ }
2089
+ }
2090
+
2091
+ shT simplified_shape;
2092
+ shT simplified_src_strides;
2093
+ shT simplified_dst_strides;
2094
+ py::ssize_t src_offset (0 );
2095
+ py::ssize_t dst_offset (0 );
2096
+
2097
+ constexpr py::ssize_t src_itemsize = 1 ; // item size in elements
2098
+ constexpr py::ssize_t dst_itemsize = 1 ; // item size in elements
2099
+
2100
+ int nd = src_nd - 2 ;
2101
+ const py::ssize_t *shape = src_shape;
2102
+ const py::ssize_t *p_src_strides = &src_strides[0 ];
2103
+ const py::ssize_t *p_dst_strides = &dst_strides[0 ];
2104
+ simplify_iteration_space (nd, shape, p_src_strides, src_itemsize,
2105
+ is_src_c_contig, is_src_f_contig, p_dst_strides,
2106
+ dst_itemsize, is_dst_c_contig, is_dst_f_contig,
2107
+ simplified_shape, simplified_src_strides,
2108
+ simplified_dst_strides, src_offset, dst_offset);
2109
+
2110
+ nd += 2 ;
2111
+ std::vector<py::ssize_t > shape_and_strides (3 * nd);
2112
+
2113
+ std::copy (simplified_shape.begin (), simplified_shape.end (),
2114
+ shape_and_strides.begin ());
2115
+ shape_and_strides[nd - 2 ] = src_shape[src_nd - 2 ];
2116
+ shape_and_strides[nd - 1 ] = src_shape[src_nd - 1 ];
2117
+ std::copy (simplified_src_strides.begin (), simplified_src_strides.end (),
2118
+ shape_and_strides.begin () + nd);
2119
+ shape_and_strides[2 * nd - 2 ] = src_strides[src_nd - 2 ];
2120
+ shape_and_strides[2 * nd - 1 ] = src_strides[src_nd - 1 ];
2121
+ std::copy (simplified_dst_strides.begin (), simplified_dst_strides.end (),
2122
+ shape_and_strides.begin () + 2 * nd);
2123
+ shape_and_strides[3 * nd - 2 ] = dst_strides[src_nd - 2 ];
2124
+ shape_and_strides[3 * nd - 1 ] = dst_strides[src_nd - 1 ];
2125
+
2126
+ std::shared_ptr<shT> shp_shape_and_strides =
2127
+ std::make_shared<shT>(shape_and_strides);
2128
+
2129
+ py::ssize_t *dev_shape_and_strides =
2130
+ sycl::malloc_device<ssize_t >(3 * nd, exec_q);
2131
+ if (dev_shape_and_strides == nullptr ) {
2132
+ throw std::runtime_error (" Unabled to allocate device memory" );
2133
+ }
2134
+ sycl::event copy_shape_and_strides = exec_q.copy <ssize_t >(
2135
+ shp_shape_and_strides->data (), dev_shape_and_strides, 3 * nd);
2136
+
2137
+ py::ssize_t inner_range =
2138
+ shape_and_strides[nd - 1 ] * shape_and_strides[nd - 2 ];
2139
+ py::ssize_t outer_range = src_nelems / inner_range;
2140
+
2141
+ sycl::event tri_ev;
2142
+ if (part == ' l' ) {
2143
+ auto fn = tril_generic_dispatch_vector[src_typeid];
2144
+ tri_ev =
2145
+ fn (exec_q, inner_range, outer_range, src.get_data (), dst.get_data (),
2146
+ nd, dev_shape_and_strides, k, depends, {copy_shape_and_strides});
2147
+ }
2148
+ else {
2149
+ auto fn = triu_generic_dispatch_vector[src_typeid];
2150
+ tri_ev =
2151
+ fn (exec_q, inner_range, outer_range, src.get_data (), dst.get_data (),
2152
+ nd, dev_shape_and_strides, k, depends, {copy_shape_and_strides});
2153
+ }
2154
+
2155
+ exec_q.submit ([&](sycl::handler &cgh) {
2156
+ cgh.depends_on ({tri_ev});
2157
+ auto ctx = exec_q.get_context ();
2158
+ cgh.host_task ([shp_shape_and_strides, dev_shape_and_strides, ctx]() {
2159
+ sycl::free (dev_shape_and_strides, ctx);
2160
+ });
2161
+ });
2162
+ return std::make_pair (keep_args_alive (exec_q, {src, dst}, {tri_ev}),
2163
+ tri_ev);
2164
+ }
2165
+
1882
2166
// populate dispatch tables
1883
2167
void init_copy_and_cast_dispatch_tables (void )
1884
2168
{
@@ -1936,6 +2220,12 @@ void init_copy_for_reshape_dispatch_vector(void)
1936
2220
DispatchVectorBuilder<eye_fn_ptr_t , EyeFactory, num_types> dvb4;
1937
2221
dvb4.populate_dispatch_vector (eye_dispatch_vector);
1938
2222
2223
+ DispatchVectorBuilder<tri_fn_ptr_t , TrilGenericFactory, num_types> dvb5;
2224
+ dvb5.populate_dispatch_vector (tril_generic_dispatch_vector);
2225
+
2226
+ DispatchVectorBuilder<tri_fn_ptr_t , TriuGenericFactory, num_types> dvb6;
2227
+ dvb6.populate_dispatch_vector (triu_generic_dispatch_vector);
2228
+
1939
2229
return ;
1940
2230
}
1941
2231
@@ -2081,4 +2371,25 @@ PYBIND11_MODULE(_tensor_impl, m)
2081
2371
[](sycl::device dev) -> std::string {
2082
2372
return get_default_device_complex_type (dev);
2083
2373
});
2374
+ m.def (
2375
+ " _tril" ,
2376
+ [](sycl::queue exec_q, dpctl::tensor::usm_ndarray src,
2377
+ dpctl::tensor::usm_ndarray dst, int k,
2378
+ const std::vector<sycl::event> depends)
2379
+ -> std::pair<sycl::event, sycl::event> {
2380
+ return tri (exec_q, src, dst, ' l' , k, depends);
2381
+ },
2382
+ " Tril helper function." , py::arg (" sycl_queue" ), py::arg (" src" ),
2383
+ py::arg (" dst" ), py::arg (" k" ) = 0 , py::arg (" depends" ) = py::list ());
2384
+
2385
+ m.def (
2386
+ " _triu" ,
2387
+ [](sycl::queue exec_q, dpctl::tensor::usm_ndarray src,
2388
+ dpctl::tensor::usm_ndarray dst, int k,
2389
+ const std::vector<sycl::event> depends)
2390
+ -> std::pair<sycl::event, sycl::event> {
2391
+ return tri (exec_q, src, dst, ' u' , k, depends);
2392
+ },
2393
+ " Triu helper function." , py::arg (" sycl_queue" ), py::arg (" src" ),
2394
+ py::arg (" dst" ), py::arg (" k" ) = 0 , py::arg (" depends" ) = py::list ());
2084
2395
}
0 commit comments