
[SYCL] Relax device code recursion ban for constexpr calls in constexpr context. #2104

Closed
@kbobrovs

Description


The SYCL spec prohibits recursion in device code. However, when the recursion can be fully evaluated at compile time, the restriction could be lifted for programmer convenience. Even though the spec says

Recursion is not allowed in a SYCL kernel or any code called from the kernel

its intent is most likely to prevent recursive functions from being code-generated for the device, so the spec seems to need clarification.
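The reasoning behind the relaxation is that, in a context that requires a constant expression, the compiler evaluates the entire recursive chain itself, so no recursive call is left to code-generate. A minimal plain-C++ sketch of this (no SYCL headers; `getNextPowerOf2` is copied from the examples in this issue, and `kFolded` is a hypothetical name for illustration):

```cpp
// Plain C++ (no SYCL): getNextPowerOf2 as written in this issue.
static constexpr unsigned int getNextPowerOf2(unsigned int n,
                                              unsigned int k = 1) {
  return (k >= n) ? k : getNextPowerOf2(n, k * 2);
}

// A constexpr variable initializer requires a constant expression, so the
// compiler must evaluate every recursive call while compiling.
constexpr unsigned int kFolded = getNextPowerOf2(3);  // 1 -> 2 -> 4

// static_assert is another such context: if the recursion could not be
// evaluated at compile time, this would be a compile error, not a runtime call.
static_assert(kFolded == 4, "next power of two >= 3 is 4");
static_assert(getNextPowerOf2(17) == 32, "next power of two >= 17 is 32");
```

Only the folded constant reaches the generated code; there is no recursive function body that a device would have to execute.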

With the proposed relaxation, the following code should compile:

```cpp
template <typename name, typename Func>
__attribute__((sycl_kernel)) void kernel_single_task(Func kernelFunc) {
  kernelFunc();
}

static constexpr unsigned int getNextPowerOf2(unsigned int n,
                                              unsigned int k = 1) {
  return (k >= n) ? k : getNextPowerOf2(n, k * 2);
}

unsigned test_constexpr_recursion(unsigned int val) {
  unsigned int res = val;
  unsigned int *addr = &res;

  kernel_single_task<class ConstexprRecursionKernel>([=]() {
    // The constexpr initializer forces the compiler to evaluate the
    // recursion at compile time, so no errors are expected.
    constexpr unsigned int x = getNextPowerOf2(3);
    *addr += x;
  });
  return res;
}
```
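The example above uses a constexpr variable initializer. Other places where C++ requires a constant expression, such as a non-type template argument or an array bound, also force compile-time evaluation. It seems reasonable to assume, although the issue does not say so, that the relaxation would cover these contexts too. A plain-C++ sketch without SYCL (`Buffer`, `PaddedBuffer` and `scratch` are hypothetical names):

```cpp
#include <array>

// getNextPowerOf2 as written in this issue.
static constexpr unsigned int getNextPowerOf2(unsigned int n,
                                              unsigned int k = 1) {
  return (k >= n) ? k : getNextPowerOf2(n, k * 2);
}

// Hypothetical type whose template argument must be a constant expression.
template <unsigned int N>
struct Buffer {
  static constexpr unsigned int size = N;
};

// Template argument: the recursive call is folded to 8 at compile time.
using PaddedBuffer = Buffer<getNextPowerOf2(5)>;

// Array bound (via std::array's size parameter): also folded, here to 8.
std::array<int, getNextPowerOf2(6)> scratch{};
```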

However, if getNextPowerOf2 is called from a non-constexpr context, the code should still fail to compile:

```cpp
template <typename name, typename Func>
__attribute__((sycl_kernel)) void kernel_single_task(Func kernelFunc) {
  kernelFunc(); // Diagnostics include a "called by" note for this call.
}

static constexpr unsigned int getNextPowerOf2(unsigned int n,
                                              unsigned int k = 1) {
  return (k >= n) ? k : getNextPowerOf2(n, k * 2);
}

unsigned test_constexpr_recursion(unsigned int val) {
  unsigned int res = val;
  unsigned int *addr = &res;

  kernel_single_task<class ConstexprRecursionKernel>([=]() {
    // This call should still produce an error: x is not constexpr, so
    // getNextPowerOf2 is not required to be evaluated at compile time and
    // the recursive function will be code-generated for the device.
    unsigned int x = getNextPowerOf2(3);
    *addr += x;
  });
  return res;
}
```
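One possible workaround, which is not part of this issue, applies when the value is needed in a context that cannot be constant-evaluated. An iterative version avoids recursion altogether and should therefore be acceptable in device code whether or not the compiler folds the call. The sketch below is plain C++14, because loops inside constexpr functions require C++14. `nextPowerOf2Iterative` and `runtimeNextPowerOf2` are hypothetical names.

```cpp
// Hypothetical loop-based replacement for getNextPowerOf2 (C++14 or later).
// Precondition, assuming a 32-bit unsigned int: n <= 2^31. For larger n the
// doubling overflows to 0 and the loop would never terminate.
constexpr unsigned int nextPowerOf2Iterative(unsigned int n) {
  unsigned int k = 1;
  while (k < n) {
    k *= 2;  // same doubling step as the recursive version, but no call
  }
  return k;
}

// Still usable in a constant expression...
static_assert(nextPowerOf2Iterative(3) == 4, "same result as getNextPowerOf2(3)");

// ...and because there is no recursion, an ordinary runtime call like this
// one, if placed inside a kernel, should not trigger the recursion diagnostic.
unsigned int runtimeNextPowerOf2(unsigned int val) {
  return nextPowerOf2Iterative(val);
}
```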


Labels: enhancement (New feature or request)
