Skip to content

Commit 1bfd30d

Browse files
authored
[SYCL][InvokeSIMD] Add error for return type and subgroup size mismatch (#8741)
The spec doesn't allow this, but right now we don't error at all. Signed-off-by: Sarnie, Nick <nick.sarnie@intel.com>
1 parent fc039fd commit 1bfd30d

File tree

1 file changed

+15
-0
lines changed

1 file changed

+15
-0
lines changed

sycl/include/sycl/ext/oneapi/experimental/invoke_simd.hpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,19 @@ template <typename T> struct unwrap_uniform<uniform<T>> {
194194
static T impl(uniform<T> val) { return val; }
195195
};
196196

197+
// Verify the callee return type matches the subgroup size as is required by the
198+
// spec. For example: simd<int, 8> foo(simd<int,16>); The return type vector
199+
// length (8) does not match the subgroup size (16).
200+
template <auto SgSize, typename SimdRet>
201+
constexpr void verify_return_type_matches_sg_size() {
202+
if constexpr (is_simd_or_mask_type<SimdRet>::value) {
203+
constexpr auto RetVecLength = SimdRet::size();
204+
static_assert(RetVecLength == SgSize,
205+
"invoke_simd callee return type vector length must match "
206+
"kernel subgroup size");
207+
}
208+
}
209+
197210
// Deduces subgroup size of the caller based on given SIMD callable and
198211
// corresponding SPMD arguments it is being invoke with via invoke_simd.
199212
// Basically, for each supported subgroup size, this meta-function finds out if
@@ -349,6 +362,8 @@ __attribute__((always_inline)) auto invoke_simd(sycl::sub_group sg,
349362
// is fine in this case.
350363
constexpr int N = detail::get_sg_size<Callable, T...>();
351364
using RetSpmd = detail::SpmdRetType<N, Callable, T...>;
365+
detail::verify_return_type_matches_sg_size<
366+
N, detail::SimdRetType<N, Callable, T...>>();
352367
constexpr bool is_function = detail::is_function_ptr_or_ref_v<Callable>;
353368

354369
if constexpr (is_function) {

0 commit comments

Comments
 (0)