Skip to content

Commit 5e7b627

Browse files
committed
Argmin and argmax now handle identities correctly
Adds a test for this behavior. Fixes a typo in argmin and argmax that caused the shared local memory variant to be used for more types than expected.
1 parent 9402742 commit 5e7b627

File tree

2 files changed

+127
-86
lines changed

2 files changed

+127
-86
lines changed

dpctl/tensor/libtensor/include/kernels/reductions.hpp

Lines changed: 117 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -1665,25 +1665,37 @@ struct SearchReduction
16651665
auto inp_offset = inp_iter_offset + inp_reduction_offset;
16661666

16671667
argT val = inp_[inp_offset];
1668-
if constexpr (su_ns::IsMinimum<argT, ReductionOp>::value) {
1669-
if (val < local_red_val) {
1670-
local_red_val = val;
1671-
if constexpr (!First) {
1672-
local_idx = inds_[inp_offset];
1673-
}
1674-
else {
1675-
local_idx = static_cast<outT>(arg_reduce_gid);
1676-
}
1668+
if (val == local_red_val) {
1669+
if constexpr (!First) {
1670+
local_idx = std::min(local_idx, inds_[inp_offset]);
1671+
}
1672+
else {
1673+
local_idx = std::min(local_idx,
1674+
static_cast<outT>(arg_reduce_gid));
16771675
}
16781676
}
1679-
else if constexpr (su_ns::IsMaximum<argT, ReductionOp>::value) {
1680-
if (val > local_red_val) {
1681-
local_red_val = val;
1682-
if constexpr (!First) {
1683-
local_idx = inds_[inp_offset];
1677+
else {
1678+
if constexpr (su_ns::IsMinimum<argT, ReductionOp>::value) {
1679+
if (val < local_red_val) {
1680+
local_red_val = val;
1681+
if constexpr (!First) {
1682+
local_idx = inds_[inp_offset];
1683+
}
1684+
else {
1685+
local_idx = static_cast<outT>(arg_reduce_gid);
1686+
}
16841687
}
1685-
else {
1686-
local_idx = static_cast<outT>(arg_reduce_gid);
1688+
}
1689+
else if constexpr (su_ns::IsMaximum<argT,
1690+
ReductionOp>::value) {
1691+
if (val > local_red_val) {
1692+
local_red_val = val;
1693+
if constexpr (!First) {
1694+
local_idx = inds_[inp_offset];
1695+
}
1696+
else {
1697+
local_idx = static_cast<outT>(arg_reduce_gid);
1698+
}
16871699
}
16881700
}
16891701
}
@@ -1808,83 +1820,102 @@ struct CustomSearchReduction
18081820
auto inp_offset = inp_iter_offset + inp_reduction_offset;
18091821

18101822
argT val = inp_[inp_offset];
1811-
if constexpr (su_ns::IsMinimum<argT, ReductionOp>::value) {
1812-
using dpctl::tensor::type_utils::is_complex;
1813-
if constexpr (is_complex<argT>::value) {
1814-
using dpctl::tensor::math_utils::less_complex;
1815-
// less_complex always returns false for NaNs, so check
1816-
if (less_complex<argT>(val, local_red_val) ||
1817-
std::isnan(std::real(val)) ||
1818-
std::isnan(std::imag(val)))
1819-
{
1820-
local_red_val = val;
1821-
if constexpr (!First) {
1822-
local_idx = inds_[inp_offset];
1823-
}
1824-
else {
1825-
local_idx = static_cast<outT>(arg_reduce_gid);
1826-
}
1827-
}
1828-
}
1829-
else if constexpr (std::is_floating_point_v<argT>) {
1830-
if (val < local_red_val || std::isnan(val)) {
1831-
local_red_val = val;
1832-
if constexpr (!First) {
1833-
local_idx = inds_[inp_offset];
1834-
}
1835-
else {
1836-
local_idx = static_cast<outT>(arg_reduce_gid);
1837-
}
1838-
}
1823+
if (val == local_red_val) {
1824+
if constexpr (!First) {
1825+
local_idx = std::min(local_idx, inds_[inp_offset]);
18391826
}
18401827
else {
1841-
if (val < local_red_val) {
1842-
local_red_val = val;
1843-
if constexpr (!First) {
1844-
local_idx = inds_[inp_offset];
1845-
}
1846-
else {
1847-
local_idx = static_cast<outT>(arg_reduce_gid);
1848-
}
1849-
}
1828+
local_idx = std::min(local_idx,
1829+
static_cast<outT>(arg_reduce_gid));
18501830
}
18511831
}
1852-
else if constexpr (su_ns::IsMaximum<argT, ReductionOp>::value) {
1853-
using dpctl::tensor::type_utils::is_complex;
1854-
if constexpr (is_complex<argT>::value) {
1855-
using dpctl::tensor::math_utils::greater_complex;
1856-
if (greater_complex<argT>(val, local_red_val) ||
1857-
std::isnan(std::real(val)) ||
1858-
std::isnan(std::imag(val)))
1859-
{
1860-
local_red_val = val;
1861-
if constexpr (!First) {
1862-
local_idx = inds_[inp_offset];
1863-
}
1864-
else {
1865-
local_idx = static_cast<outT>(arg_reduce_gid);
1832+
else {
1833+
if constexpr (su_ns::IsMinimum<argT, ReductionOp>::value) {
1834+
using dpctl::tensor::type_utils::is_complex;
1835+
if constexpr (is_complex<argT>::value) {
1836+
using dpctl::tensor::math_utils::less_complex;
1837+
// less_complex always returns false for NaNs, so
1838+
// check
1839+
if (less_complex<argT>(val, local_red_val) ||
1840+
std::isnan(std::real(val)) ||
1841+
std::isnan(std::imag(val)))
1842+
{
1843+
local_red_val = val;
1844+
if constexpr (!First) {
1845+
local_idx = inds_[inp_offset];
1846+
}
1847+
else {
1848+
local_idx =
1849+
static_cast<outT>(arg_reduce_gid);
1850+
}
18661851
}
18671852
}
1868-
}
1869-
else if constexpr (std::is_floating_point_v<argT>) {
1870-
if (val > local_red_val || std::isnan(val)) {
1871-
local_red_val = val;
1872-
if constexpr (!First) {
1873-
local_idx = inds_[inp_offset];
1853+
else if constexpr (std::is_floating_point_v<argT>) {
1854+
if (val < local_red_val || std::isnan(val)) {
1855+
local_red_val = val;
1856+
if constexpr (!First) {
1857+
local_idx = inds_[inp_offset];
1858+
}
1859+
else {
1860+
local_idx =
1861+
static_cast<outT>(arg_reduce_gid);
1862+
}
18741863
}
1875-
else {
1876-
local_idx = static_cast<outT>(arg_reduce_gid);
1864+
}
1865+
else {
1866+
if (val < local_red_val) {
1867+
local_red_val = val;
1868+
if constexpr (!First) {
1869+
local_idx = inds_[inp_offset];
1870+
}
1871+
else {
1872+
local_idx =
1873+
static_cast<outT>(arg_reduce_gid);
1874+
}
18771875
}
18781876
}
18791877
}
1880-
else {
1881-
if (val > local_red_val) {
1882-
local_red_val = val;
1883-
if constexpr (!First) {
1884-
local_idx = inds_[inp_offset];
1878+
else if constexpr (su_ns::IsMaximum<argT,
1879+
ReductionOp>::value) {
1880+
using dpctl::tensor::type_utils::is_complex;
1881+
if constexpr (is_complex<argT>::value) {
1882+
using dpctl::tensor::math_utils::greater_complex;
1883+
if (greater_complex<argT>(val, local_red_val) ||
1884+
std::isnan(std::real(val)) ||
1885+
std::isnan(std::imag(val)))
1886+
{
1887+
local_red_val = val;
1888+
if constexpr (!First) {
1889+
local_idx = inds_[inp_offset];
1890+
}
1891+
else {
1892+
local_idx =
1893+
static_cast<outT>(arg_reduce_gid);
1894+
}
18851895
}
1886-
else {
1887-
local_idx = static_cast<outT>(arg_reduce_gid);
1896+
}
1897+
else if constexpr (std::is_floating_point_v<argT>) {
1898+
if (val > local_red_val || std::isnan(val)) {
1899+
local_red_val = val;
1900+
if constexpr (!First) {
1901+
local_idx = inds_[inp_offset];
1902+
}
1903+
else {
1904+
local_idx =
1905+
static_cast<outT>(arg_reduce_gid);
1906+
}
1907+
}
1908+
}
1909+
else {
1910+
if (val > local_red_val) {
1911+
local_red_val = val;
1912+
if constexpr (!First) {
1913+
local_idx = inds_[inp_offset];
1914+
}
1915+
else {
1916+
local_idx =
1917+
static_cast<outT>(arg_reduce_gid);
1918+
}
18881919
}
18891920
}
18901921
}
@@ -2037,7 +2068,7 @@ sycl::event search_reduction_over_group_temps_strided_impl(
20372068
sycl::range<1>{iter_nelems * reduction_groups * wg};
20382069
auto localRange = sycl::range<1>{wg};
20392070

2040-
if constexpr (su_ns::IsSyclOp<resTy, ReductionOpT>::value) {
2071+
if constexpr (su_ns::IsSyclOp<argTy, ReductionOpT>::value) {
20412072
using KernelName = class search_reduction_over_group_temps_krn<
20422073
argTy, resTy, ReductionOpT, IndexOpT,
20432074
InputOutputIterIndexerT, ReductionIndexerT, true, true>;
@@ -2136,7 +2167,7 @@ sycl::event search_reduction_over_group_temps_strided_impl(
21362167
sycl::range<1>{iter_nelems * reduction_groups * wg};
21372168
auto localRange = sycl::range<1>{wg};
21382169

2139-
if constexpr (su_ns::IsSyclOp<resTy, ReductionOpT>::value) {
2170+
if constexpr (su_ns::IsSyclOp<argTy, ReductionOpT>::value) {
21402171
using KernelName = class search_reduction_over_group_temps_krn<
21412172
argTy, resTy, ReductionOpT, IndexOpT,
21422173
InputOutputIterIndexerT, ReductionIndexerT, true, false>;
@@ -2216,7 +2247,7 @@ sycl::event search_reduction_over_group_temps_strided_impl(
22162247
auto globalRange =
22172248
sycl::range<1>{iter_nelems * reduction_groups_ * wg};
22182249
auto localRange = sycl::range<1>{wg};
2219-
if constexpr (su_ns::IsSyclOp<resTy, ReductionOpT>::value) {
2250+
if constexpr (su_ns::IsSyclOp<argTy, ReductionOpT>::value) {
22202251
using KernelName =
22212252
class search_reduction_over_group_temps_krn<
22222253
argTy, resTy, ReductionOpT, IndexOpT,
@@ -2299,7 +2330,7 @@ sycl::event search_reduction_over_group_temps_strided_impl(
22992330
sycl::range<1>{iter_nelems * reduction_groups * wg};
23002331
auto localRange = sycl::range<1>{wg};
23012332

2302-
if constexpr (su_ns::IsSyclOp<resTy, ReductionOpT>::value) {
2333+
if constexpr (su_ns::IsSyclOp<argTy, ReductionOpT>::value) {
23032334
using KernelName = class search_reduction_over_group_temps_krn<
23042335
argTy, resTy, ReductionOpT, IndexOpT,
23052336
InputOutputIterIndexerT, ReductionIndexerT, false, true>;

dpctl/tests/test_usm_ndarray_reductions.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,3 +201,13 @@ def test_argmax_argmin_nan_propagation():
201201
x[idx] = complex(0, dpt.nan)
202202
assert dpt.argmax(x) == idx
203203
assert dpt.argmin(x) == idx
204+
205+
206+
def test_argmax_argmin_identities():
207+
# make sure that identity arrays work as expected
208+
get_queue_or_skip()
209+
210+
x = dpt.full(3, dpt.iinfo(dpt.int32).min, dtype="i4")
211+
assert dpt.argmax(x) == 0
212+
x = dpt.full(3, dpt.iinfo(dpt.int32).max, dtype="i4")
213+
assert dpt.argmin(x) == 0

0 commit comments

Comments
 (0)