@@ -1665,25 +1665,37 @@ struct SearchReduction
1665
1665
auto inp_offset = inp_iter_offset + inp_reduction_offset;
1666
1666
1667
1667
argT val = inp_[inp_offset];
1668
- if constexpr (su_ns::IsMinimum<argT, ReductionOp>::value) {
1669
- if (val < local_red_val) {
1670
- local_red_val = val;
1671
- if constexpr (!First) {
1672
- local_idx = inds_[inp_offset];
1673
- }
1674
- else {
1675
- local_idx = static_cast <outT>(arg_reduce_gid);
1676
- }
1668
+ if (val == local_red_val) {
1669
+ if constexpr (!First) {
1670
+ local_idx = std::min (local_idx, inds_[inp_offset]);
1671
+ }
1672
+ else {
1673
+ local_idx = std::min (local_idx,
1674
+ static_cast <outT>(arg_reduce_gid));
1677
1675
}
1678
1676
}
1679
- else if constexpr (su_ns::IsMaximum<argT, ReductionOp>::value) {
1680
- if (val > local_red_val) {
1681
- local_red_val = val;
1682
- if constexpr (!First) {
1683
- local_idx = inds_[inp_offset];
1677
+ else {
1678
+ if constexpr (su_ns::IsMinimum<argT, ReductionOp>::value) {
1679
+ if (val < local_red_val) {
1680
+ local_red_val = val;
1681
+ if constexpr (!First) {
1682
+ local_idx = inds_[inp_offset];
1683
+ }
1684
+ else {
1685
+ local_idx = static_cast <outT>(arg_reduce_gid);
1686
+ }
1684
1687
}
1685
- else {
1686
- local_idx = static_cast <outT>(arg_reduce_gid);
1688
+ }
1689
+ else if constexpr (su_ns::IsMaximum<argT,
1690
+ ReductionOp>::value) {
1691
+ if (val > local_red_val) {
1692
+ local_red_val = val;
1693
+ if constexpr (!First) {
1694
+ local_idx = inds_[inp_offset];
1695
+ }
1696
+ else {
1697
+ local_idx = static_cast <outT>(arg_reduce_gid);
1698
+ }
1687
1699
}
1688
1700
}
1689
1701
}
@@ -1808,83 +1820,102 @@ struct CustomSearchReduction
1808
1820
auto inp_offset = inp_iter_offset + inp_reduction_offset;
1809
1821
1810
1822
argT val = inp_[inp_offset];
1811
- if constexpr (su_ns::IsMinimum<argT, ReductionOp>::value) {
1812
- using dpctl::tensor::type_utils::is_complex;
1813
- if constexpr (is_complex<argT>::value) {
1814
- using dpctl::tensor::math_utils::less_complex;
1815
- // less_complex always returns false for NaNs, so check
1816
- if (less_complex<argT>(val, local_red_val) ||
1817
- std::isnan (std::real (val)) ||
1818
- std::isnan (std::imag (val)))
1819
- {
1820
- local_red_val = val;
1821
- if constexpr (!First) {
1822
- local_idx = inds_[inp_offset];
1823
- }
1824
- else {
1825
- local_idx = static_cast <outT>(arg_reduce_gid);
1826
- }
1827
- }
1828
- }
1829
- else if constexpr (std::is_floating_point_v<argT>) {
1830
- if (val < local_red_val || std::isnan (val)) {
1831
- local_red_val = val;
1832
- if constexpr (!First) {
1833
- local_idx = inds_[inp_offset];
1834
- }
1835
- else {
1836
- local_idx = static_cast <outT>(arg_reduce_gid);
1837
- }
1838
- }
1823
+ if (val == local_red_val) {
1824
+ if constexpr (!First) {
1825
+ local_idx = std::min (local_idx, inds_[inp_offset]);
1839
1826
}
1840
1827
else {
1841
- if (val < local_red_val) {
1842
- local_red_val = val;
1843
- if constexpr (!First) {
1844
- local_idx = inds_[inp_offset];
1845
- }
1846
- else {
1847
- local_idx = static_cast <outT>(arg_reduce_gid);
1848
- }
1849
- }
1828
+ local_idx = std::min (local_idx,
1829
+ static_cast <outT>(arg_reduce_gid));
1850
1830
}
1851
1831
}
1852
- else if constexpr (su_ns::IsMaximum<argT, ReductionOp>::value) {
1853
- using dpctl::tensor::type_utils::is_complex;
1854
- if constexpr (is_complex<argT>::value) {
1855
- using dpctl::tensor::math_utils::greater_complex;
1856
- if (greater_complex<argT>(val, local_red_val) ||
1857
- std::isnan (std::real (val)) ||
1858
- std::isnan (std::imag (val)))
1859
- {
1860
- local_red_val = val;
1861
- if constexpr (!First) {
1862
- local_idx = inds_[inp_offset];
1863
- }
1864
- else {
1865
- local_idx = static_cast <outT>(arg_reduce_gid);
1832
+ else {
1833
+ if constexpr (su_ns::IsMinimum<argT, ReductionOp>::value) {
1834
+ using dpctl::tensor::type_utils::is_complex;
1835
+ if constexpr (is_complex<argT>::value) {
1836
+ using dpctl::tensor::math_utils::less_complex;
1837
+ // less_complex always returns false for NaNs, so
1838
+ // check
1839
+ if (less_complex<argT>(val, local_red_val) ||
1840
+ std::isnan (std::real (val)) ||
1841
+ std::isnan (std::imag (val)))
1842
+ {
1843
+ local_red_val = val;
1844
+ if constexpr (!First) {
1845
+ local_idx = inds_[inp_offset];
1846
+ }
1847
+ else {
1848
+ local_idx =
1849
+ static_cast <outT>(arg_reduce_gid);
1850
+ }
1866
1851
}
1867
1852
}
1868
- }
1869
- else if constexpr (std::is_floating_point_v<argT>) {
1870
- if (val > local_red_val || std::isnan (val)) {
1871
- local_red_val = val;
1872
- if constexpr (!First) {
1873
- local_idx = inds_[inp_offset];
1853
+ else if constexpr (std::is_floating_point_v<argT>) {
1854
+ if (val < local_red_val || std::isnan (val)) {
1855
+ local_red_val = val;
1856
+ if constexpr (!First) {
1857
+ local_idx = inds_[inp_offset];
1858
+ }
1859
+ else {
1860
+ local_idx =
1861
+ static_cast <outT>(arg_reduce_gid);
1862
+ }
1874
1863
}
1875
- else {
1876
- local_idx = static_cast <outT>(arg_reduce_gid);
1864
+ }
1865
+ else {
1866
+ if (val < local_red_val) {
1867
+ local_red_val = val;
1868
+ if constexpr (!First) {
1869
+ local_idx = inds_[inp_offset];
1870
+ }
1871
+ else {
1872
+ local_idx =
1873
+ static_cast <outT>(arg_reduce_gid);
1874
+ }
1877
1875
}
1878
1876
}
1879
1877
}
1880
- else {
1881
- if (val > local_red_val) {
1882
- local_red_val = val;
1883
- if constexpr (!First) {
1884
- local_idx = inds_[inp_offset];
1878
+ else if constexpr (su_ns::IsMaximum<argT,
1879
+ ReductionOp>::value) {
1880
+ using dpctl::tensor::type_utils::is_complex;
1881
+ if constexpr (is_complex<argT>::value) {
1882
+ using dpctl::tensor::math_utils::greater_complex;
1883
+ if (greater_complex<argT>(val, local_red_val) ||
1884
+ std::isnan (std::real (val)) ||
1885
+ std::isnan (std::imag (val)))
1886
+ {
1887
+ local_red_val = val;
1888
+ if constexpr (!First) {
1889
+ local_idx = inds_[inp_offset];
1890
+ }
1891
+ else {
1892
+ local_idx =
1893
+ static_cast <outT>(arg_reduce_gid);
1894
+ }
1885
1895
}
1886
- else {
1887
- local_idx = static_cast <outT>(arg_reduce_gid);
1896
+ }
1897
+ else if constexpr (std::is_floating_point_v<argT>) {
1898
+ if (val > local_red_val || std::isnan (val)) {
1899
+ local_red_val = val;
1900
+ if constexpr (!First) {
1901
+ local_idx = inds_[inp_offset];
1902
+ }
1903
+ else {
1904
+ local_idx =
1905
+ static_cast <outT>(arg_reduce_gid);
1906
+ }
1907
+ }
1908
+ }
1909
+ else {
1910
+ if (val > local_red_val) {
1911
+ local_red_val = val;
1912
+ if constexpr (!First) {
1913
+ local_idx = inds_[inp_offset];
1914
+ }
1915
+ else {
1916
+ local_idx =
1917
+ static_cast <outT>(arg_reduce_gid);
1918
+ }
1888
1919
}
1889
1920
}
1890
1921
}
@@ -2037,7 +2068,7 @@ sycl::event search_reduction_over_group_temps_strided_impl(
2037
2068
sycl::range<1 >{iter_nelems * reduction_groups * wg};
2038
2069
auto localRange = sycl::range<1 >{wg};
2039
2070
2040
- if constexpr (su_ns::IsSyclOp<resTy , ReductionOpT>::value) {
2071
+ if constexpr (su_ns::IsSyclOp<argTy , ReductionOpT>::value) {
2041
2072
using KernelName = class search_reduction_over_group_temps_krn <
2042
2073
argTy, resTy, ReductionOpT, IndexOpT,
2043
2074
InputOutputIterIndexerT, ReductionIndexerT, true , true >;
@@ -2136,7 +2167,7 @@ sycl::event search_reduction_over_group_temps_strided_impl(
2136
2167
sycl::range<1 >{iter_nelems * reduction_groups * wg};
2137
2168
auto localRange = sycl::range<1 >{wg};
2138
2169
2139
- if constexpr (su_ns::IsSyclOp<resTy , ReductionOpT>::value) {
2170
+ if constexpr (su_ns::IsSyclOp<argTy , ReductionOpT>::value) {
2140
2171
using KernelName = class search_reduction_over_group_temps_krn <
2141
2172
argTy, resTy, ReductionOpT, IndexOpT,
2142
2173
InputOutputIterIndexerT, ReductionIndexerT, true , false >;
@@ -2216,7 +2247,7 @@ sycl::event search_reduction_over_group_temps_strided_impl(
2216
2247
auto globalRange =
2217
2248
sycl::range<1 >{iter_nelems * reduction_groups_ * wg};
2218
2249
auto localRange = sycl::range<1 >{wg};
2219
- if constexpr (su_ns::IsSyclOp<resTy , ReductionOpT>::value) {
2250
+ if constexpr (su_ns::IsSyclOp<argTy , ReductionOpT>::value) {
2220
2251
using KernelName =
2221
2252
class search_reduction_over_group_temps_krn <
2222
2253
argTy, resTy, ReductionOpT, IndexOpT,
@@ -2299,7 +2330,7 @@ sycl::event search_reduction_over_group_temps_strided_impl(
2299
2330
sycl::range<1 >{iter_nelems * reduction_groups * wg};
2300
2331
auto localRange = sycl::range<1 >{wg};
2301
2332
2302
- if constexpr (su_ns::IsSyclOp<resTy , ReductionOpT>::value) {
2333
+ if constexpr (su_ns::IsSyclOp<argTy , ReductionOpT>::value) {
2303
2334
using KernelName = class search_reduction_over_group_temps_krn <
2304
2335
argTy, resTy, ReductionOpT, IndexOpT,
2305
2336
InputOutputIterIndexerT, ReductionIndexerT, false , true >;
0 commit comments