Skip to content

Commit ffa8c92

Browse files
committed
rebase and resolve conflicts
1 parent 1253889 commit ffa8c92

File tree

2 files changed

+30
-29
lines changed

2 files changed

+30
-29
lines changed

dpnp/backend/extensions/vm/types_matrix.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,14 +81,15 @@ struct CeilOutputType
8181
dpctl_td_ns::TypeMapResultEntry<T, double, double>,
8282
dpctl_td_ns::TypeMapResultEntry<T, float, float>,
8383
dpctl_td_ns::DefaultResultEntry<void>>::result_type;
84-
}
84+
};
8585

8686
/**
8787
* @brief A factory to define pairs of supported types for which
8888
* MKL VM library provides support in oneapi::mkl::vm::conj<T> function.
8989
*
9090
* @tparam T Type of input vector `a` and of result vector `y`.
9191
*/
92+
template <typename T>
9293
struct ConjOutputType
9394
{
9495
using value_type = typename std::disjunction<

dpnp/backend/extensions/vm/vm_py.cpp

Lines changed: 28 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,34 @@ PYBIND11_MODULE(_vm_impl, m)
124124
py::arg("sycl_queue"), py::arg("src"), py::arg("dst"));
125125
}
126126

127+
// UnaryUfunc: ==== Conj(x) ====
128+
{
129+
vm_ext::init_ufunc_dispatch_vector<unary_impl_fn_ptr_t,
130+
vm_ext::ConjContigFactory>(
131+
conj_dispatch_vector);
132+
133+
auto conj_pyapi = [&](sycl::queue exec_q, arrayT src, arrayT dst,
134+
const event_vecT &depends = {}) {
135+
return vm_ext::unary_ufunc(exec_q, src, dst, depends,
136+
conj_dispatch_vector);
137+
};
138+
m.def("_conj", conj_pyapi,
139+
"Call `conj` function from OneMKL VM library to compute "
140+
"conjugate of vector elements",
141+
py::arg("sycl_queue"), py::arg("src"), py::arg("dst"),
142+
py::arg("depends") = py::list());
143+
144+
auto conj_need_to_call_pyapi = [&](sycl::queue exec_q, arrayT src,
145+
arrayT dst) {
146+
return vm_ext::need_to_call_unary_ufunc(exec_q, src, dst,
147+
conj_dispatch_vector);
148+
};
149+
m.def("_mkl_conj_to_call", conj_need_to_call_pyapi,
150+
"Check input arguments to answer if `conj` function from "
151+
"OneMKL VM library can be used",
152+
py::arg("sycl_queue"), py::arg("src"), py::arg("dst"));
153+
}
154+
127155
// UnaryUfunc: ==== Cos(x) ====
128156
{
129157
vm_ext::init_ufunc_dispatch_vector<unary_impl_fn_ptr_t,
@@ -180,34 +208,6 @@ PYBIND11_MODULE(_vm_impl, m)
180208
py::arg("sycl_queue"), py::arg("src"), py::arg("dst"));
181209
}
182210

183-
// UnaryUfunc: ==== Conj(x) ====
184-
{
185-
vm_ext::init_ufunc_dispatch_vector<unary_impl_fn_ptr_t,
186-
vm_ext::ConjContigFactory>(
187-
conj_dispatch_vector);
188-
189-
auto conj_pyapi = [&](sycl::queue exec_q, arrayT src, arrayT dst,
190-
const event_vecT &depends = {}) {
191-
return vm_ext::unary_ufunc(exec_q, src, dst, depends,
192-
conj_dispatch_vector);
193-
};
194-
m.def("_conj", conj_pyapi,
195-
"Call `conj` function from OneMKL VM library to compute "
196-
"conjugate of vector elements",
197-
py::arg("sycl_queue"), py::arg("src"), py::arg("dst"),
198-
py::arg("depends") = py::list());
199-
200-
auto conj_need_to_call_pyapi = [&](sycl::queue exec_q, arrayT src,
201-
arrayT dst) {
202-
return vm_ext::need_to_call_unary_ufunc(exec_q, src, dst,
203-
conj_dispatch_vector);
204-
};
205-
m.def("_mkl_conj_to_call", conj_need_to_call_pyapi,
206-
"Check input arguments to answer if `conj` function from "
207-
"OneMKL VM library can be used",
208-
py::arg("sycl_queue"), py::arg("src"), py::arg("dst"));
209-
}
210-
211211
// UnaryUfunc: ==== Ln(x) ====
212212
{
213213
vm_ext::init_ufunc_dispatch_vector<unary_impl_fn_ptr_t,

0 commit comments

Comments
 (0)