Skip to content

[SYCL][InvokeSIMD] Add error for return type and subgroup size mismatch #8741

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 24, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions sycl/include/sycl/ext/oneapi/experimental/invoke_simd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,19 @@ template <typename T> struct unwrap_uniform<uniform<T>> {
static T impl(uniform<T> val) { return val; }
};

// Verify the callee return type matches the subgroup size as is required by the
// spec. For example: simd<int, 8> foo(simd<int,16>); The return type vector
// length (8) does not match the subgroup size (16).
template <auto SgSize, typename SimdRet>
constexpr void verify_return_type_matches_sg_size() {
if constexpr (is_simd_or_mask_type<SimdRet>::value) {
constexpr auto RetVecLength = SimdRet::size();
static_assert(RetVecLength == SgSize,
"invoke_simd callee return type vector length must match "
"kernel subgroup size");
}
}

// Deduces subgroup size of the caller based on given SIMD callable and
// corresponding SPMD arguments it is being invoke with via invoke_simd.
// Basically, for each supported subgroup size, this meta-function finds out if
Expand Down Expand Up @@ -349,6 +362,8 @@ __attribute__((always_inline)) auto invoke_simd(sycl::sub_group sg,
// is fine in this case.
constexpr int N = detail::get_sg_size<Callable, T...>();
using RetSpmd = detail::SpmdRetType<N, Callable, T...>;
detail::verify_return_type_matches_sg_size<
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The arguments need to pass similar check:
if sg-size is N, and one of operands is simd<T, N*2>, then it is the reportable error as well.

Copy link
Contributor Author

@sarnex sarnex Mar 23, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I originally implemented that, but it turns out we already check for this when picking a subgroup size.

Example:

void SIMD_CALLEE( simd<float, VL> b,
                               simd<float, VL*2> c )  SYCL_ESIMD_FUNCTION {

Error:

error: static assertion failed due to requirement '__MP11_NS::integral_constant<unsigned long, 0>::value == 1'
    static_assert((__MP11_NS::mp_size<InvocableSgSizes>::value == 1) &&
                  "no or multiple invoke_simd targets found");

We didn't do the return type check there because we need the finalized subgroup size for the return type.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, thank you for explanations.
The comment on the line 361 concerns me a bit.
It says 0 is fine. If N is 0 (i.e. all parameters are uniform), then verify_return_type_matches_sg_size would fail, right? If Yes, then this special case needs to be handled specially in the static_assert that was added in this PR.

Copy link
Contributor Author

@sarnex sarnex Mar 23, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It shouldn't fail because in the case where all arguments are uniform but the return type is simd or simd mask, we pick the return type VL as subgroup size already. See the below code:

if constexpr (all_uniform_types<SpmdArgs...>()) {
    using SimdRet = std::invoke_result_t<SimdCallable, SpmdArgs...>;

    if constexpr (is_simd_or_mask_type<SimdRet>::value) {
      return simd_size<SimdRet>::value;
    } else {
      // fully uniform function - subgroup size does not matter
      return 0;
    }

So I don't think there is a way to hit the added static_assert because sg size will never be 0 if we go into the constexpr if I added

N, detail::SimdRetType<N, Callable, T...>>();
constexpr bool is_function = detail::is_function_ptr_or_ref_v<Callable>;

if constexpr (is_function) {
Expand Down