@@ -1888,7 +1888,7 @@ typedef sycl::event (*tri_fn_ptr_t)(sycl::queue,
     char *,        // dst_data_ptr
     py::ssize_t,   // nd
     py::ssize_t *, // shape_and_strides
-    int,           // k
+    py::ssize_t,   // k
     const std::vector<sycl::event> &,
     const std::vector<sycl::event> &);

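Both this function-pointer typedef and the implementations below widen the diagonal offset `k` from `int` to `py::ssize_t`, so it shares the signed index type used for shapes and strides. A standalone sketch of the truncation a 32-bit `k` would risk on very tall arrays (not dpctl code; the matrix size is made up):

#include <cstdint>
#include <cstdio>

int main()
{
    // For an m x n matrix, valid diagonals run from -(m - 1) to n - 1.
    std::int64_t m = (1LL << 31) + 8;  // taller than INT_MAX
    std::int64_t k = -(m - 1);         // lowest diagonal
    int k32 = static_cast<int>(k);     // narrows: value no longer representable
    std::printf("k = %lld, as int = %d\n", static_cast<long long>(k), k32);
    return 0;
}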
@@ -1901,7 +1901,7 @@ sycl::event tri_impl(sycl::queue exec_q,
                     char *dst_p,
                     py::ssize_t nd,
                     py::ssize_t *shape_and_strides,
-                    int k,
+                    py::ssize_t k,
                     const std::vector<sycl::event> &depends,
                     const std::vector<sycl::event> &additional_depends)
{
@@ -1917,12 +1917,15 @@ sycl::event tri_impl(sycl::queue exec_q,
        cgh.depends_on(depends);
        cgh.depends_on(additional_depends);
        cgh.parallel_for<tri_kernel<Ty, l>>(
-           sycl::range<2>(inner_range, outer_range), [=](sycl::item<2> idx) {
+           sycl::range<1>(inner_range * outer_range), [=](sycl::id<1> idx) {
+               py::ssize_t outer_gid = idx[0] / inner_range;
+               py::ssize_t inner_gid = idx[0] - inner_range * outer_gid;
+
                py::ssize_t src_inner_offset, dst_inner_offset;
                bool to_copy;

                {
-                   py::ssize_t inner_gid = idx.get_id(0);
+                   // py::ssize_t inner_gid = idx.get_id(0);
                    CIndexer_array<d2, py::ssize_t> indexer_i(
                        {shape_and_strides[nd_2], shape_and_strides[nd_1]});
                    indexer_i.set(inner_gid);
@@ -1943,7 +1946,7 @@ sycl::event tri_impl(sycl::queue exec_q,
                py::ssize_t src_offset = 0;
                py::ssize_t dst_offset = 0;
                {
-                   py::ssize_t outer_gid = idx.get_id(1);
+                   // py::ssize_t outer_gid = idx.get_id(1);
                    CIndexer_vector<py::ssize_t> outer(nd - d2);
                    outer.get_displacement(
                        outer_gid, shape_and_strides, shape_and_strides + src_s,
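The kernel now launches over a flat sycl::range<1> and decomposes the linear id itself, instead of reading two coordinates from a sycl::item<2>. A host-side sketch (plain C++, no SYCL required) checking that the division/remainder pair recovers exactly the (outer, inner) coordinates the old 2-D launch provided:

#include <cassert>
#include <cstddef>

int main()
{
    const std::ptrdiff_t inner_range = 5, outer_range = 3;
    for (std::ptrdiff_t flat = 0; flat < inner_range * outer_range; ++flat) {
        // same arithmetic as the kernel body above
        std::ptrdiff_t outer_gid = flat / inner_range;
        std::ptrdiff_t inner_gid = flat - inner_range * outer_gid;
        assert(outer_gid >= 0 && outer_gid < outer_range);
        assert(inner_gid >= 0 && inner_gid < inner_range);
        assert(flat == outer_gid * inner_range + inner_gid); // round-trips
    }
    return 0;
}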
@@ -1986,7 +1989,7 @@ tri(sycl::queue &exec_q,
    dpctl::tensor::usm_ndarray src,
    dpctl::tensor::usm_ndarray dst,
    char part,
-   int k = 0,
+   py::ssize_t k = 0,
    const std::vector<sycl::event> &depends = {})
{
    // array dimensions must be the same
@@ -2007,7 +2010,7 @@ tri(sycl::queue &exec_q,
    bool shapes_equal(true);
    size_t src_nelems(1);

-   for (int i = 0; i < src_nd; ++i) {
+   for (int i = 0; shapes_equal && i < src_nd; ++i) {
        src_nelems *= static_cast<size_t>(src_shape[i]);
        shapes_equal = shapes_equal && (src_shape[i] == dst_shape[i]);
    }
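The loop condition now short-circuits at the first shape mismatch. That leaves src_nelems only partially accumulated, which is safe solely because a mismatch leads to a throw before the count is consumed. A self-contained restatement of the combined scan (hypothetical helper, not dpctl API):

#include <cstddef>
#include <stdexcept>

std::size_t checked_nelems(const std::ptrdiff_t *src_shape,
                           const std::ptrdiff_t *dst_shape, int nd)
{
    bool shapes_equal = true;
    std::size_t nelems = 1;
    for (int i = 0; shapes_equal && i < nd; ++i) { // stop at first mismatch
        nelems *= static_cast<std::size_t>(src_shape[i]);
        shapes_equal = (src_shape[i] == dst_shape[i]);
    }
    if (!shapes_equal) {
        throw std::invalid_argument("src and dst must have the same shape");
    }
    return nelems; // only meaningful when the shapes matched throughout
}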
@@ -2032,11 +2035,7 @@ tri(sycl::queue &exec_q,
    sycl::queue src_q = src.get_queue();
    sycl::queue dst_q = dst.get_queue();

-   sycl::context exec_ctx = exec_q.get_context();
-   sycl::device exec_d = exec_q.get_device();
-   if (src_q.get_context() != exec_ctx || dst_q.get_context() != exec_ctx ||
-       src_q.get_device() != exec_d || dst_q.get_device() != exec_d)
-   {
+   if (!dpctl::utils::queues_are_compatible(exec_q, {src_q, dst_q})) {
        throw py::value_error(
            "Execution queue context is not the same as allocation contexts");
    }
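dpctl::utils::queues_are_compatible packages up the hand-rolled context/device comparison deleted above. A sketch of the equivalent check it stands in for (the real helper's signature may differ; this paraphrases the removed lines):

#include <sycl/sycl.hpp>
#include <initializer_list>

bool same_context_and_device(const sycl::queue &exec_q,
                             std::initializer_list<sycl::queue> alloc_qs)
{
    const sycl::context ctx = exec_q.get_context();
    const sycl::device dev = exec_q.get_device();
    for (const sycl::queue &q : alloc_qs) {
        if (q.get_context() != ctx || q.get_device() != dev) {
            return false;
        }
    }
    return true;
}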
@@ -2060,9 +2059,8 @@ tri(sycl::queue &exec_q,
        }
    }
    else {
-       for (ssize_t i = 0; i < src_nd; i++) {
-           src_strides[i] = src_strides_raw[i];
-       }
+       std::copy(src_strides_raw, src_strides_raw + src_nd,
+                 src_strides.begin());
    }

    int dst_flags = dst.get_flags();
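std::copy expresses the raw-pointer-to-vector copy more idiomatically than the index loop it replaces; note the dst variant in the next hunk also switches the bound from src_nd to dst_nd, the self-consistent choice even though the two are verified equal earlier. A tiny equivalence check:

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

int main()
{
    const std::ptrdiff_t raw[] = {12, 4, 1}; // made-up strides
    std::vector<std::ptrdiff_t> by_loop(3), by_copy(3);
    for (std::ptrdiff_t i = 0; i < 3; i++) {
        by_loop[i] = raw[i]; // the removed loop
    }
    std::copy(raw, raw + 3, by_copy.begin()); // the replacement
    assert(by_loop == by_copy);
    return 0;
}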
@@ -2083,9 +2081,8 @@ tri(sycl::queue &exec_q,
        }
    }
    else {
-       for (ssize_t i = 0; i < src_nd; i++) {
-           dst_strides[i] = dst_strides_raw[i];
-       }
+       std::copy(dst_strides_raw, dst_strides_raw + dst_nd,
+                 dst_strides.begin());
    }

    shT simplified_shape;
@@ -2099,14 +2096,20 @@ tri(sycl::queue &exec_q,

    int nd = src_nd - 2;
    const py::ssize_t *shape = src_shape;
-   const py::ssize_t *p_src_strides = &src_strides[0];
-   const py::ssize_t *p_dst_strides = &dst_strides[0];
+   const py::ssize_t *p_src_strides = src_strides.data();
+   ;
+   const py::ssize_t *p_dst_strides = dst_strides.data();
+   ;
    simplify_iteration_space(nd, shape, p_src_strides, src_itemsize,
                             is_src_c_contig, is_src_f_contig, p_dst_strides,
                             dst_itemsize, is_dst_c_contig, is_dst_f_contig,
                             simplified_shape, simplified_src_strides,
                             simplified_dst_strides, src_offset, dst_offset);

+   if (src_offset != 0 || dst_offset != 0) {
+       throw py::value_error("Reversed slice for dst is not supported");
+   }
+
    nd += 2;
    std::vector<py::ssize_t> shape_and_strides(3 * nd);
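The new guard rejects views whose simplified iteration space starts at a nonzero displacement, which is what reversed (negative-stride) slices produce. Illustrative arithmetic only, not dpctl's simplify_iteration_space:

#include <cassert>
#include <cstddef>

int main()
{
    // A length-n view with stride -1 places element i at offset + i * (-1);
    // keeping every index in bounds forces the base offset to be n - 1.
    const std::ptrdiff_t n = 7, stride = -1;
    const std::ptrdiff_t offset = n - 1;
    for (std::ptrdiff_t i = 0; i < n; ++i) {
        const std::ptrdiff_t pos = offset + i * stride;
        assert(pos >= 0 && pos < n); // in bounds, but offset != 0
    }
    return 0;
}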
@@ -2123,7 +2126,7 @@ tri(sycl::queue &exec_q,
    shape_and_strides[3 * nd - 2] = dst_strides[src_nd - 2];
    shape_and_strides[3 * nd - 1] = dst_strides[src_nd - 1];

-   std::shared_ptr<shT> shp_shape_and_strides =
+   std::shared_ptr<shT> shp_host_shape_and_strides =
        std::make_shared<shT>(shape_and_strides);

    py::ssize_t *dev_shape_and_strides =
@@ -2132,7 +2135,7 @@ tri(sycl::queue &exec_q,
        throw std::runtime_error("Unabled to allocate device memory");
    }
    sycl::event copy_shape_and_strides = exec_q.copy<ssize_t>(
-       shp_shape_and_strides->data(), dev_shape_and_strides, 3 * nd);
+       shp_host_shape_and_strides->data(), dev_shape_and_strides, 3 * nd);

    py::ssize_t inner_range =
        shape_and_strides[nd - 1] * shape_and_strides[nd - 2];
@@ -2155,9 +2158,12 @@ tri(sycl::queue &exec_q,
    exec_q.submit([&](sycl::handler &cgh) {
        cgh.depends_on({tri_ev});
        auto ctx = exec_q.get_context();
-       cgh.host_task([shp_shape_and_strides, dev_shape_and_strides, ctx]() {
-           sycl::free(dev_shape_and_strides, ctx);
-       });
+       cgh.host_task(
+           [shp_host_shape_and_strides, dev_shape_and_strides, ctx]() {
+               // capturing shp_host_shape_and_strides ensures the underlying
+               // vector exists for the entire execution of the copying kernel
+               sycl::free(dev_shape_and_strides, ctx);
+           });
    });
    return std::make_pair(keep_args_alive(exec_q, {src, dst}, {tri_ev}),
                          tri_ev);
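The reformatted host_task keeps the same resource-management pattern: the task is ordered after the kernel and frees the USM allocation, while the by-value shared_ptr capture pins the host vector until then. A generic sketch of that pattern (names are mine, not dpctl's):

#include <sycl/sycl.hpp>
#include <cstdint>
#include <memory>
#include <vector>

sycl::event
schedule_cleanup(sycl::queue &q, sycl::event dep,
                 std::shared_ptr<std::vector<std::int64_t>> host_data,
                 std::int64_t *device_data)
{
    return q.submit([&](sycl::handler &cgh) {
        cgh.depends_on(dep); // do not free before the dependent kernel finishes
        auto ctx = q.get_context();
        cgh.host_task([host_data, device_data, ctx]() {
            // host_data's refcount keeps the vector alive up to this point
            sycl::free(device_data, ctx);
        });
    });
}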
@@ -2373,23 +2379,25 @@ PYBIND11_MODULE(_tensor_impl, m)
        });
    m.def(
        "_tril",
-       [](sycl::queue exec_q, dpctl::tensor::usm_ndarray src,
-          dpctl::tensor::usm_ndarray dst, int k,
+       [](dpctl::tensor::usm_ndarray src, dpctl::tensor::usm_ndarray dst,
+          py::ssize_t k, sycl::queue exec_q,
           const std::vector<sycl::event> depends)
            -> std::pair<sycl::event, sycl::event> {
            return tri(exec_q, src, dst, 'l', k, depends);
        },
-       "Tril helper function.", py::arg("sycl_queue"), py::arg("src"),
-       py::arg("dst"), py::arg("k") = 0, py::arg("depends") = py::list());
+       "Tril helper function.", py::arg("src"), py::arg("dst"),
+       py::arg("k") = 0, py::arg("sycl_queue"),
+       py::arg("depends") = py::list());

    m.def(
        "_triu",
-       [](sycl::queue exec_q, dpctl::tensor::usm_ndarray src,
-          dpctl::tensor::usm_ndarray dst, int k,
+       [](dpctl::tensor::usm_ndarray src, dpctl::tensor::usm_ndarray dst,
+          py::ssize_t k, sycl::queue exec_q,
           const std::vector<sycl::event> depends)
            -> std::pair<sycl::event, sycl::event> {
            return tri(exec_q, src, dst, 'u', k, depends);
        },
-       "Triu helper function.", py::arg("sycl_queue"), py::arg("src"),
-       py::arg("dst"), py::arg("k") = 0, py::arg("depends") = py::list());
+       "Triu helper function.", py::arg("src"), py::arg("dst"),
+       py::arg("k") = 0, py::arg("sycl_queue"),
+       py::arg("depends") = py::list());
}
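The bindings move sycl_queue after the defaulted k, so the Python-visible order becomes (src, dst, k=0, sycl_queue, depends=[]). pybind11 performs its own argument matching rather than enforcing Python's "no required parameter after a defaulted one" rule, so sycl_queue simply remains required. A toy binding (hypothetical module and function names) exercising the same layout:

#include <pybind11/pybind11.h>

namespace py = pybind11;

PYBIND11_MODULE(toy_args, m)
{
    // k carries a default; queue_id after it does not, so callers must
    // still supply queue_id, either positionally or by keyword.
    m.def(
        "describe",
        [](int src, int dst, py::ssize_t k, int queue_id) {
            return src * 1000 + dst * 100 + static_cast<int>(k) * 10 +
                   queue_id;
        },
        py::arg("src"), py::arg("dst"), py::arg("k") = 0,
        py::arg("queue_id"));
}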