Skip to content

Commit 82b4186

Browse files
committed
Minor changes
1 parent 1c67f18 commit 82b4186

File tree

1 file changed

+41
-33
lines changed

1 file changed

+41
-33
lines changed

dpctl/tensor/libtensor/source/tensor_py.cpp

Lines changed: 41 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1888,7 +1888,7 @@ typedef sycl::event (*tri_fn_ptr_t)(sycl::queue,
18881888
char *, // dst_data_ptr
18891889
py::ssize_t, // nd
18901890
py::ssize_t *, // shape_and_strides
1891-
int, // k
1891+
py::ssize_t, // k
18921892
const std::vector<sycl::event> &,
18931893
const std::vector<sycl::event> &);
18941894

@@ -1901,7 +1901,7 @@ sycl::event tri_impl(sycl::queue exec_q,
19011901
char *dst_p,
19021902
py::ssize_t nd,
19031903
py::ssize_t *shape_and_strides,
1904-
int k,
1904+
py::ssize_t k,
19051905
const std::vector<sycl::event> &depends,
19061906
const std::vector<sycl::event> &additional_depends)
19071907
{
@@ -1917,12 +1917,15 @@ sycl::event tri_impl(sycl::queue exec_q,
19171917
cgh.depends_on(depends);
19181918
cgh.depends_on(additional_depends);
19191919
cgh.parallel_for<tri_kernel<Ty, l>>(
1920-
sycl::range<2>(inner_range, outer_range), [=](sycl::item<2> idx) {
1920+
sycl::range<1>(inner_range * outer_range), [=](sycl::id<1> idx) {
1921+
py::ssize_t outer_gid = idx[0] / inner_range;
1922+
py::ssize_t inner_gid = idx[0] - inner_range * outer_gid;
1923+
19211924
py::ssize_t src_inner_offset, dst_inner_offset;
19221925
bool to_copy;
19231926

19241927
{
1925-
py::ssize_t inner_gid = idx.get_id(0);
1928+
// py::ssize_t inner_gid = idx.get_id(0);
19261929
CIndexer_array<d2, py::ssize_t> indexer_i(
19271930
{shape_and_strides[nd_2], shape_and_strides[nd_1]});
19281931
indexer_i.set(inner_gid);
@@ -1943,7 +1946,7 @@ sycl::event tri_impl(sycl::queue exec_q,
19431946
py::ssize_t src_offset = 0;
19441947
py::ssize_t dst_offset = 0;
19451948
{
1946-
py::ssize_t outer_gid = idx.get_id(1);
1949+
// py::ssize_t outer_gid = idx.get_id(1);
19471950
CIndexer_vector<py::ssize_t> outer(nd - d2);
19481951
outer.get_displacement(
19491952
outer_gid, shape_and_strides, shape_and_strides + src_s,
@@ -1986,7 +1989,7 @@ tri(sycl::queue &exec_q,
19861989
dpctl::tensor::usm_ndarray src,
19871990
dpctl::tensor::usm_ndarray dst,
19881991
char part,
1989-
int k = 0,
1992+
py::ssize_t k = 0,
19901993
const std::vector<sycl::event> &depends = {})
19911994
{
19921995
// array dimensions must be the same
@@ -2007,7 +2010,7 @@ tri(sycl::queue &exec_q,
20072010
bool shapes_equal(true);
20082011
size_t src_nelems(1);
20092012

2010-
for (int i = 0; i < src_nd; ++i) {
2013+
for (int i = 0; shapes_equal && i < src_nd; ++i) {
20112014
src_nelems *= static_cast<size_t>(src_shape[i]);
20122015
shapes_equal = shapes_equal && (src_shape[i] == dst_shape[i]);
20132016
}
@@ -2032,11 +2035,7 @@ tri(sycl::queue &exec_q,
20322035
sycl::queue src_q = src.get_queue();
20332036
sycl::queue dst_q = dst.get_queue();
20342037

2035-
sycl::context exec_ctx = exec_q.get_context();
2036-
sycl::device exec_d = exec_q.get_device();
2037-
if (src_q.get_context() != exec_ctx || dst_q.get_context() != exec_ctx ||
2038-
src_q.get_device() != exec_d || dst_q.get_device() != exec_d)
2039-
{
2038+
if (!dpctl::utils::queues_are_compatible(exec_q, {src_q, dst_q})) {
20402039
throw py::value_error(
20412040
"Execution queue context is not the same as allocation contexts");
20422041
}
@@ -2060,9 +2059,8 @@ tri(sycl::queue &exec_q,
20602059
}
20612060
}
20622061
else {
2063-
for (ssize_t i = 0; i < src_nd; i++) {
2064-
src_strides[i] = src_strides_raw[i];
2065-
}
2062+
std::copy(src_strides_raw, src_strides_raw + src_nd,
2063+
src_strides.begin());
20662064
}
20672065

20682066
int dst_flags = dst.get_flags();
@@ -2083,9 +2081,8 @@ tri(sycl::queue &exec_q,
20832081
}
20842082
}
20852083
else {
2086-
for (ssize_t i = 0; i < src_nd; i++) {
2087-
dst_strides[i] = dst_strides_raw[i];
2088-
}
2084+
std::copy(dst_strides_raw, dst_strides_raw + dst_nd,
2085+
dst_strides.begin());
20892086
}
20902087

20912088
shT simplified_shape;
@@ -2099,14 +2096,20 @@ tri(sycl::queue &exec_q,
20992096

21002097
int nd = src_nd - 2;
21012098
const py::ssize_t *shape = src_shape;
2102-
const py::ssize_t *p_src_strides = &src_strides[0];
2103-
const py::ssize_t *p_dst_strides = &dst_strides[0];
2099+
const py::ssize_t *p_src_strides = src_strides.data();
2100+
;
2101+
const py::ssize_t *p_dst_strides = dst_strides.data();
2102+
;
21042103
simplify_iteration_space(nd, shape, p_src_strides, src_itemsize,
21052104
is_src_c_contig, is_src_f_contig, p_dst_strides,
21062105
dst_itemsize, is_dst_c_contig, is_dst_f_contig,
21072106
simplified_shape, simplified_src_strides,
21082107
simplified_dst_strides, src_offset, dst_offset);
21092108

2109+
if (src_offset != 0 || dst_offset != 0) {
2110+
throw py::value_error("Reversed slice for dst is not supported");
2111+
}
2112+
21102113
nd += 2;
21112114
std::vector<py::ssize_t> shape_and_strides(3 * nd);
21122115

@@ -2123,7 +2126,7 @@ tri(sycl::queue &exec_q,
21232126
shape_and_strides[3 * nd - 2] = dst_strides[src_nd - 2];
21242127
shape_and_strides[3 * nd - 1] = dst_strides[src_nd - 1];
21252128

2126-
std::shared_ptr<shT> shp_shape_and_strides =
2129+
std::shared_ptr<shT> shp_host_shape_and_strides =
21272130
std::make_shared<shT>(shape_and_strides);
21282131

21292132
py::ssize_t *dev_shape_and_strides =
@@ -2132,7 +2135,7 @@ tri(sycl::queue &exec_q,
21322135
throw std::runtime_error("Unabled to allocate device memory");
21332136
}
21342137
sycl::event copy_shape_and_strides = exec_q.copy<ssize_t>(
2135-
shp_shape_and_strides->data(), dev_shape_and_strides, 3 * nd);
2138+
shp_host_shape_and_strides->data(), dev_shape_and_strides, 3 * nd);
21362139

21372140
py::ssize_t inner_range =
21382141
shape_and_strides[nd - 1] * shape_and_strides[nd - 2];
@@ -2155,9 +2158,12 @@ tri(sycl::queue &exec_q,
21552158
exec_q.submit([&](sycl::handler &cgh) {
21562159
cgh.depends_on({tri_ev});
21572160
auto ctx = exec_q.get_context();
2158-
cgh.host_task([shp_shape_and_strides, dev_shape_and_strides, ctx]() {
2159-
sycl::free(dev_shape_and_strides, ctx);
2160-
});
2161+
cgh.host_task(
2162+
[shp_host_shape_and_strides, dev_shape_and_strides, ctx]() {
2163+
// capturing shp_host_shape_and_strides ensures the underlying
2164+
// vector exists for the entire execution of the copying kernel
2165+
sycl::free(dev_shape_and_strides, ctx);
2166+
});
21612167
});
21622168
return std::make_pair(keep_args_alive(exec_q, {src, dst}, {tri_ev}),
21632169
tri_ev);
@@ -2373,23 +2379,25 @@ PYBIND11_MODULE(_tensor_impl, m)
23732379
});
23742380
m.def(
23752381
"_tril",
2376-
[](sycl::queue exec_q, dpctl::tensor::usm_ndarray src,
2377-
dpctl::tensor::usm_ndarray dst, int k,
2382+
[](dpctl::tensor::usm_ndarray src, dpctl::tensor::usm_ndarray dst,
2383+
py::ssize_t k, sycl::queue exec_q,
23782384
const std::vector<sycl::event> depends)
23792385
-> std::pair<sycl::event, sycl::event> {
23802386
return tri(exec_q, src, dst, 'l', k, depends);
23812387
},
2382-
"Tril helper function.", py::arg("sycl_queue"), py::arg("src"),
2383-
py::arg("dst"), py::arg("k") = 0, py::arg("depends") = py::list());
2388+
"Tril helper function.", py::arg("src"), py::arg("dst"),
2389+
py::arg("k") = 0, py::arg("sycl_queue"),
2390+
py::arg("depends") = py::list());
23842391

23852392
m.def(
23862393
"_triu",
2387-
[](sycl::queue exec_q, dpctl::tensor::usm_ndarray src,
2388-
dpctl::tensor::usm_ndarray dst, int k,
2394+
[](dpctl::tensor::usm_ndarray src, dpctl::tensor::usm_ndarray dst,
2395+
py::ssize_t k, sycl::queue exec_q,
23892396
const std::vector<sycl::event> depends)
23902397
-> std::pair<sycl::event, sycl::event> {
23912398
return tri(exec_q, src, dst, 'u', k, depends);
23922399
},
2393-
"Triu helper function.", py::arg("sycl_queue"), py::arg("src"),
2394-
py::arg("dst"), py::arg("k") = 0, py::arg("depends") = py::list());
2400+
"Triu helper function.", py::arg("src"), py::arg("dst"),
2401+
py::arg("k") = 0, py::arg("sycl_queue"),
2402+
py::arg("depends") = py::list());
23952403
}

0 commit comments

Comments
 (0)