Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions numba_dpex/core/typing/dpnpdecl.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,12 +179,16 @@ def install_operations(cls):
"ldexp",
"spacing",
"isnat",
"cbrt",
]
)

# A list of ufuncs that are in fact aliases of other ufuncs. They need to insert
# the resolve method, but not register the ufunc itself
_aliases = set(["bitwise_not", "mod", "abs"])
# TODO: A list of ufuncs that are in fact aliases of other ufuncs. They need
# to insert the resolve method, but not register the ufunc itself.
# In a meantime let's just register them as user functions:
# TODO: for some reason it affects "mod", but does not affect "bitwise_not" and
# "abs". May be mod is not an alias?
_aliases = {"bitwise_not", "abs"}

all_ufuncs = sum(
[
Expand Down
31 changes: 28 additions & 3 deletions numba_dpex/tests/_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@
# SPDX-License-Identifier: Apache-2.0

import contextlib
import inspect
import shutil
from functools import cache

import dpctl
import dpnp
import pytest

from numba_dpex import dpjit, numba_sem_version
from numba_dpex.core import config
from numba_dpex import config, dpjit, numba_sem_version


@cache
Expand Down Expand Up @@ -179,6 +179,14 @@ def get_complex_dtypes(device=None):
return dtypes


def get_int_dtypes(device=None):
"""
Build a list of integer types supported by DPNP based on device capabilities.
"""

return [dpnp.int32, dpnp.int64]


def get_float_dtypes(no_float16=True, device=None):
"""
Build a list of floating types supported by DPNP based on device capabilities.
Expand Down Expand Up @@ -227,7 +235,7 @@ def get_all_dtypes(

# add integer types
if not no_int:
dtypes.extend([dpnp.int32, dpnp.int64])
dtypes.extend(get_int_dtypes(device=dev))

# add floating types
if not no_float:
Expand Down Expand Up @@ -276,3 +284,20 @@ def skip_if_dtype_not_supported(dt, q_or_dev):
pytest.skip(
f"{dev.name} does not support half precision floating point type"
)


def num_required_arguments(func):
"""Returns number of required arguments of the functions. Does not work
with kwargs arguments."""
if func == dpnp.true_divide:
func = dpnp.divide

sig = inspect.signature(func)
params = sig.parameters
required_args = [
p
for p in params
if params[p].default == inspect._empty and p != "kwargs"
]

return len(required_args)
Loading