@@ -194,6 +194,19 @@ template <typename T> struct unwrap_uniform<uniform<T>> {
194
194
static T impl (uniform<T> val) { return val; }
195
195
};
196
196
197
+ // Verify the callee return type matches the subgroup size as is required by the
198
+ // spec. For example: simd<int, 8> foo(simd<int,16>); The return type vector
199
+ // length (8) does not match the subgroup size (16).
200
+ template <auto SgSize, typename SimdRet>
201
+ constexpr void verify_return_type_matches_sg_size () {
202
+ if constexpr (is_simd_or_mask_type<SimdRet>::value) {
203
+ constexpr auto RetVecLength = SimdRet::size ();
204
+ static_assert (RetVecLength == SgSize,
205
+ " invoke_simd callee return type vector length must match "
206
+ " kernel subgroup size" );
207
+ }
208
+ }
209
+
197
210
// Deduces subgroup size of the caller based on given SIMD callable and
198
211
// corresponding SPMD arguments it is being invoked with via invoke_simd.
199
212
// Basically, for each supported subgroup size, this meta-function finds out if
@@ -349,6 +362,8 @@ __attribute__((always_inline)) auto invoke_simd(sycl::sub_group sg,
349
362
// is fine in this case.
350
363
constexpr int N = detail::get_sg_size<Callable, T...>();
351
364
using RetSpmd = detail::SpmdRetType<N, Callable, T...>;
365
+ detail::verify_return_type_matches_sg_size<
366
+ N, detail::SimdRetType<N, Callable, T...>>();
352
367
constexpr bool is_function = detail::is_function_ptr_or_ref_v<Callable>;
353
368
354
369
if constexpr (is_function) {
0 commit comments