cbrt ufunc does not work for dpnp #1295

Open
@ZzEeKkAa

Reproduction

  1. Add these lines to the end of numba_dpex/tests/dpjit_tests/parfors/test_dpnp_ufuncs.py:

         if __name__ == "__main__":
             test_unary_ops("cbrt", dpnp.float32)

  2. Comment out cbrt in _unsupported in numba_dpex/core/typing/dpnpdecl.py.
  3. Run:

         ONEAPI_DEVICE_SELECTOR=opencl:cpu python -m numba_dpex.tests.dpjit_tests.parfors.test_dpnp_ufuncs

  4. Output:
numba.core.errors.TypingError: Failed in dpex_kernel_nopython mode pipeline (step: numba_dpex qualified name disambiguation)
Failed in nopython mode pipeline (step: nopython frontend)
No implementation of function Function(<function np_real_cbrt_impl.<locals>.cbrt at 0x7f0f545c1940>) found for signature:

 >>> cbrt(float32)

There are 2 candidate implementations:
  - Of which 2 did not match due to:
  Overload in function 'register_jitable.<locals>.wrap.<locals>.ov_wrap': File: numba/core/extending.py: Line 161.
    With argument(s): '(float64)':
   Rejected as the implementation raised a specific error:
     KeyError: <class 'numba_dpex.core.targets.kernel_target.SyclDevice'>
  raised from /home/yevhenii/Projects/numba/numba/core/registry.py:100

During: resolving callee type: Function(<function np_real_cbrt_impl.<locals>.cbrt at 0x7f6e67bd9300>)
During: typing of call at /home/yevhenii/Projects/numba/numba/np/npyfuncs.py (839)


File "../numba/numba/np/npyfuncs.py", line 839:
    def _cbrt(x):
        <source elided>
            return np.nan
        return cbrt(x)
        ^

During: lowering "$expr_out_var.12 = call $8load_deref.1.14($arg_out_var.13, func=$8load_deref.1.14, args=[Var($arg_out_var.13, test_dpnp_ufuncs.py:169)], kws=(), vararg=None, varkwarg=None, target=None)" at /home/yevhenii/Projects/numba-dpex/numba_dpex/tests/dpjit_tests/parfors/test_dpnp_ufuncs.py (169)

Possible cause of issue

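cbrt is defined with @register_jitable(fastmath=True), so it may not be registered for the kernel target. The KeyError for numba_dpex.core.targets.kernel_target.SyclDevice raised from numba/core/registry.py in the output above is consistent with that.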
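
To illustrate the suspected failure mode, here is a toy model of a registry keyed by target class. It is only a sketch: TargetRegistry, CPUTarget, KernelTarget and compile_helper are hypothetical names made up for this example, not numba or numba_dpex APIs, and the real lookup path may differ.

```python
# Toy model only: every name below is hypothetical and is not part of
# numba or numba_dpex.


class CPUTarget:
    """Stand-in for a CPU target class."""


class KernelTarget:
    """Stand-in for a device (kernel) target class."""


class TargetRegistry:
    """Maps a target class to the callable that compiles for that target."""

    def __init__(self):
        self._entries = {}

    def register(self, target, compiler):
        self._entries[target] = compiler

    def lookup(self, target):
        # Plain dict lookup: an unregistered target raises KeyError,
        # and the missing key is the target class itself.
        return self._entries[target]


def compile_helper(registry, target, name):
    """Resolve the compiler for `target` and use it on helper `name`."""
    compiler = registry.lookup(target)
    return compiler(name)


registry = TargetRegistry()
# Only the CPU target is registered, mirroring the assumption that the
# helper is set up for the CPU target but not for the kernel target.
registry.register(CPUTarget, lambda name: f"{name} compiled for CPU")

cpu_result = compile_helper(registry, CPUTarget, "cbrt")

try:
    compile_helper(registry, KernelTarget, "cbrt")
    kernel_error = None
except KeyError as exc:
    # The KeyError carries the unregistered target class as its argument.
    kernel_error = exc.args[0]
```

In this model the CPU lookup succeeds while the kernel lookup fails with a KeyError whose key is the target class, which has the same shape as the `KeyError: <class '...SyclDevice'>` in the traceback. Whether the real cause is a missing registration for the kernel target still needs to be confirmed against the numba and numba_dpex sources.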

Numba version

The equivalent code using Numba with NumPy works as expected:

    import numba
    import numpy as np

    @numba.njit
    def func(a):
        return np.cbrt(a)
    
    a = np.ones(10)

    print(func(a))

Found in #1283

Labels: bug, enhancement, kernel API
