Skip to content

Commit e98de4d

Browse files
Implemented and deployed validate_type_for_device<T>(q)
This check produces a more succinct error message: ``` In [1]: import dpctl.tensor as dpt In [2]: dpt.arange(0, 10, dtype='f8') --------------------------------------------------------------------------- RuntimeError Traceback (most recent call last) Input In [2], in <cell line: 1>() ----> 1 dpt.arange(0, 10, dtype='f8') File ~/repos/dpctl/dpctl/tensor/_ctors.py:603, in arange(start, stop, step, dtype, device, usm_type, sycl_queue) 601 _step = sc_ty(1) 602 _start = _first --> 603 hev, _ = ti._linspace_step(_start, _step, res, sycl_queue) 604 hev.wait() 605 if is_bool: RuntimeError: Device Intel(R) Graphics [0x9a49] does not support type 'double' ```
1 parent 7d1904c commit e98de4d

File tree

3 files changed

+44
-0
lines changed

3 files changed

+44
-0
lines changed

dpctl/tensor/libtensor/include/kernels/constructors.hpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ sycl::event lin_space_step_impl(sycl::queue exec_q,
129129
char *array_data,
130130
const std::vector<sycl::event> &depends)
131131
{
132+
dpctl::tensor::type_utils::validate_type_for_device<Ty>(exec_q);
132133
sycl::event lin_space_step_event = exec_q.submit([&](sycl::handler &cgh) {
133134
cgh.depends_on(depends);
134135
cgh.parallel_for<linear_sequence_step_kernel<Ty>>(
@@ -270,6 +271,8 @@ sycl::event lin_space_affine_impl(sycl::queue exec_q,
270271
char *array_data,
271272
const std::vector<sycl::event> &depends)
272273
{
274+
dpctl::tensor::type_utils::validate_type_for_device<Ty>(exec_q);
275+
273276
bool device_supports_doubles = exec_q.get_device().has(sycl::aspect::fp64);
274277
sycl::event lin_space_affine_event = exec_q.submit([&](sycl::handler &cgh) {
275278
cgh.depends_on(depends);
@@ -378,6 +381,7 @@ sycl::event full_contig_impl(sycl::queue q,
378381
char *dst_p,
379382
const std::vector<sycl::event> &depends)
380383
{
384+
dpctl::tensor::type_utils::validate_type_for_device<dstTy>(q);
381385
sycl::event fill_ev = q.submit([&](sycl::handler &cgh) {
382386
cgh.depends_on(depends);
383387
dstTy *p = reinterpret_cast<dstTy *>(dst_p);
@@ -496,6 +500,7 @@ sycl::event eye_impl(sycl::queue exec_q,
496500
char *array_data,
497501
const std::vector<sycl::event> &depends)
498502
{
503+
dpctl::tensor::type_utils::validate_type_for_device<Ty>(exec_q);
499504
sycl::event eye_event = exec_q.submit([&](sycl::handler &cgh) {
500505
cgh.depends_on(depends);
501506
cgh.parallel_for<eye_kernel<Ty>>(
@@ -576,6 +581,8 @@ sycl::event tri_impl(sycl::queue exec_q,
576581
Ty *src = reinterpret_cast<Ty *>(src_p);
577582
Ty *dst = reinterpret_cast<Ty *>(dst_p);
578583

584+
dpctl::tensor::type_utils::validate_type_for_device<Ty>(exec_q);
585+
579586
sycl::event tri_ev = exec_q.submit([&](sycl::handler &cgh) {
580587
cgh.depends_on(depends);
581588
cgh.depends_on(additional_depends);

dpctl/tensor/libtensor/include/kernels/copy_and_cast.hpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,9 @@ copy_and_cast_generic_impl(sycl::queue q,
215215
const std::vector<sycl::event> &depends,
216216
const std::vector<sycl::event> &additional_depends)
217217
{
218+
dpctl::tensor::type_utils::validate_type_for_device<dstTy>(q);
219+
dpctl::tensor::type_utils::validate_type_for_device<srcTy>(q);
220+
218221
sycl::event copy_and_cast_ev = q.submit([&](sycl::handler &cgh) {
219222
cgh.depends_on(depends);
220223
cgh.depends_on(additional_depends);
@@ -317,6 +320,9 @@ copy_and_cast_nd_specialized_impl(sycl::queue q,
317320
py::ssize_t dst_offset,
318321
const std::vector<sycl::event> &depends)
319322
{
323+
dpctl::tensor::type_utils::validate_type_for_device<dstTy>(q);
324+
dpctl::tensor::type_utils::validate_type_for_device<srcTy>(q);
325+
320326
sycl::event copy_and_cast_ev = q.submit([&](sycl::handler &cgh) {
321327
cgh.depends_on(depends);
322328
cgh.parallel_for<copy_cast_spec_kernel<srcTy, dstTy, nd>>(
@@ -486,6 +492,10 @@ void copy_and_cast_from_host_impl(
486492
const std::vector<sycl::event> &additional_depends)
487493
{
488494
py::ssize_t nelems_range = src_max_nelem_offset - src_min_nelem_offset + 1;
495+
496+
dpctl::tensor::type_utils::validate_type_for_device<dstTy>(q);
497+
dpctl::tensor::type_utils::validate_type_for_device<srcTy>(q);
498+
489499
sycl::buffer<srcTy, 1> npy_buf(
490500
reinterpret_cast<const srcTy *>(host_src_p) + src_min_nelem_offset,
491501
sycl::range<1>(nelems_range), {sycl::property::buffer::use_host_ptr{}});
@@ -637,6 +647,8 @@ copy_for_reshape_generic_impl(sycl::queue q,
637647
char *dst_p,
638648
const std::vector<sycl::event> &depends)
639649
{
650+
dpctl::tensor::type_utils::validate_type_for_device<Ty>(q);
651+
640652
sycl::event copy_for_reshape_ev = q.submit([&](sycl::handler &cgh) {
641653
cgh.depends_on(depends);
642654
cgh.parallel_for<copy_for_reshape_generic_kernel<Ty>>(

dpctl/tensor/libtensor/include/utils/type_utils.hpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,9 @@
2323
//===----------------------------------------------------------------------===//
2424

2525
#pragma once
26+
#include <CL/sycl.hpp>
2627
#include <complex>
28+
#include <exception>
2729

2830
namespace dpctl
2931
{
@@ -68,6 +70,29 @@ template <typename dstTy, typename srcTy> dstTy convert_impl(const srcTy &v)
6870
}
6971
}
7072

73+
/// @brief Throws std::runtime_error with a readable message when device @p d
/// cannot execute kernels that use type @p T.
///
/// Checks the optional SYCL floating-point aspects: `double` and
/// `std::complex<double>` require `aspect::fp64`; `sycl::half` requires
/// `aspect::fp16`. All other types pass through silently.
///
/// NOTE(review): std::runtime_error is declared in <stdexcept>, not
/// <exception> — confirm the header's include list pulls it in.
template <typename T> void validate_type_for_device(const sycl::device &d)
{
    if constexpr (std::is_same_v<T, double>) {
        // fp64 is an optional device aspect in SYCL 2020
        if (!d.has(sycl::aspect::fp64)) {
            throw std::runtime_error("Device " +
                                     d.get_info<sycl::info::device::name>() +
                                     " does not support type 'double'");
        }
    }
    else if constexpr (std::is_same_v<T, std::complex<double>>) {
        // complex<double> kernels also require fp64 hardware support;
        // without this branch they fail with an opaque JIT/compiler error
        if (!d.has(sycl::aspect::fp64)) {
            throw std::runtime_error("Device " +
                                     d.get_info<sycl::info::device::name>() +
                                     " does not support type 'complex<double>'");
        }
    }
    else if constexpr (std::is_same_v<T, sycl::half>) {
        if (!d.has(sycl::aspect::fp16)) {
            throw std::runtime_error("Device " +
                                     d.get_info<sycl::info::device::name>() +
                                     " does not support type 'half'");
        }
    }
}
90+
91+
/// @brief Convenience overload: validates type @p T against the device
/// underlying queue @p q by forwarding to the device overload.
template <typename T> void validate_type_for_device(const sycl::queue &q)
{
    const sycl::device &dev = q.get_device();
    validate_type_for_device<T>(dev);
}
95+
7196
} // namespace type_utils
7297
} // namespace tensor
7398
} // namespace dpctl

0 commit comments

Comments
 (0)