@@ -1586,6 +1586,249 @@ struct MinOverAxis0AtomicContigFactory
1586
1586
}
1587
1587
};
1588
1588
1589
+ // Sum
1590
+
1591
+ /* @brief Types supported by plus-reduction code based on atomic_ref */
1592
+ template <typename argTy, typename outTy>
1593
+ struct TypePairSupportDataForSumReductionAtomic
1594
+ {
1595
+
1596
+ /* value if true a kernel for <argTy, outTy> must be instantiated, false
1597
+ * otherwise */
1598
+ static constexpr bool is_defined = std::disjunction< // disjunction is C++17
1599
+ // feature, supported
1600
+ // by DPC++ input bool
1601
+ td_ns::TypePairDefinedEntry<argTy, bool , outTy, std::int32_t >,
1602
+ td_ns::TypePairDefinedEntry<argTy, bool , outTy, std::uint32_t >,
1603
+ td_ns::TypePairDefinedEntry<argTy, bool , outTy, std::int64_t >,
1604
+ td_ns::TypePairDefinedEntry<argTy, bool , outTy, std::uint64_t >,
1605
+ td_ns::TypePairDefinedEntry<argTy, bool , outTy, float >,
1606
+ td_ns::TypePairDefinedEntry<argTy, bool , outTy, double >,
1607
+ // input int8
1608
+ td_ns::TypePairDefinedEntry<argTy, std::int8_t , outTy, std::int32_t >,
1609
+ td_ns::TypePairDefinedEntry<argTy, std::int8_t , outTy, std::int64_t >,
1610
+ td_ns::TypePairDefinedEntry<argTy, std::int8_t , outTy, float >,
1611
+ td_ns::TypePairDefinedEntry<argTy, std::int8_t , outTy, double >,
1612
+ // input uint8
1613
+ td_ns::TypePairDefinedEntry<argTy, std::uint8_t , outTy, std::int32_t >,
1614
+ td_ns::TypePairDefinedEntry<argTy, std::uint8_t , outTy, std::uint32_t >,
1615
+ td_ns::TypePairDefinedEntry<argTy, std::uint8_t , outTy, std::int64_t >,
1616
+ td_ns::TypePairDefinedEntry<argTy, std::uint8_t , outTy, std::uint64_t >,
1617
+ td_ns::TypePairDefinedEntry<argTy, std::uint8_t , outTy, float >,
1618
+ td_ns::TypePairDefinedEntry<argTy, std::uint8_t , outTy, double >,
1619
+ // input int16
1620
+ td_ns::TypePairDefinedEntry<argTy, std::int16_t , outTy, std::int32_t >,
1621
+ td_ns::TypePairDefinedEntry<argTy, std::int16_t , outTy, std::int64_t >,
1622
+ td_ns::TypePairDefinedEntry<argTy, std::int16_t , outTy, float >,
1623
+ td_ns::TypePairDefinedEntry<argTy, std::int16_t , outTy, double >,
1624
+ // input uint16
1625
+ td_ns::TypePairDefinedEntry<argTy, std::uint16_t , outTy, std::int32_t >,
1626
+ td_ns::TypePairDefinedEntry<argTy, std::uint16_t , outTy, std::uint32_t >,
1627
+ td_ns::TypePairDefinedEntry<argTy, std::uint16_t , outTy, std::int64_t >,
1628
+ td_ns::TypePairDefinedEntry<argTy, std::uint16_t , outTy, std::uint64_t >,
1629
+ td_ns::TypePairDefinedEntry<argTy, std::uint16_t , outTy, float >,
1630
+ td_ns::TypePairDefinedEntry<argTy, std::uint16_t , outTy, double >,
1631
+ // input int32
1632
+ td_ns::TypePairDefinedEntry<argTy, std::int32_t , outTy, std::int32_t >,
1633
+ td_ns::TypePairDefinedEntry<argTy, std::int32_t , outTy, std::int64_t >,
1634
+ td_ns::TypePairDefinedEntry<argTy, std::int32_t , outTy, float >,
1635
+ td_ns::TypePairDefinedEntry<argTy, std::int32_t , outTy, double >,
1636
+ // input uint32
1637
+ td_ns::TypePairDefinedEntry<argTy, std::uint32_t , outTy, std::uint32_t >,
1638
+ td_ns::TypePairDefinedEntry<argTy, std::uint32_t , outTy, std::int64_t >,
1639
+ td_ns::TypePairDefinedEntry<argTy, std::uint32_t , outTy, std::uint64_t >,
1640
+ td_ns::TypePairDefinedEntry<argTy, std::uint32_t , outTy, float >,
1641
+ td_ns::TypePairDefinedEntry<argTy, std::uint32_t , outTy, double >,
1642
+ // input int64
1643
+ td_ns::TypePairDefinedEntry<argTy, std::int64_t , outTy, std::int64_t >,
1644
+ td_ns::TypePairDefinedEntry<argTy, std::int64_t , outTy, double >,
1645
+ // input uint64
1646
+ td_ns::TypePairDefinedEntry<argTy, std::uint64_t , outTy, std::uint64_t >,
1647
+ td_ns::TypePairDefinedEntry<argTy, std::uint64_t , outTy, double >,
1648
+ // input half
1649
+ td_ns::TypePairDefinedEntry<argTy, sycl::half, outTy, float >,
1650
+ td_ns::TypePairDefinedEntry<argTy, float , outTy, double >,
1651
+ // input float
1652
+ td_ns::TypePairDefinedEntry<argTy, float , outTy, float >,
1653
+ td_ns::TypePairDefinedEntry<argTy, float , outTy, double >,
1654
+ // input double
1655
+ td_ns::TypePairDefinedEntry<argTy, double , outTy, double >,
1656
+ // fall-through
1657
+ td_ns::NotDefinedEntry>::is_defined;
1658
+ };
1659
+
1660
+ template <typename argTy, typename outTy>
1661
+ struct TypePairSupportDataForSumReductionTemps
1662
+ {
1663
+
1664
+ static constexpr bool is_defined = std::disjunction< // disjunction is C++17
1665
+ // feature, supported
1666
+ // by DPC++ input bool
1667
+ td_ns::TypePairDefinedEntry<argTy, bool , outTy, std::int8_t >,
1668
+ td_ns::TypePairDefinedEntry<argTy, bool , outTy, std::uint8_t >,
1669
+ td_ns::TypePairDefinedEntry<argTy, bool , outTy, std::int16_t >,
1670
+ td_ns::TypePairDefinedEntry<argTy, bool , outTy, std::uint16_t >,
1671
+ td_ns::TypePairDefinedEntry<argTy, bool , outTy, std::int32_t >,
1672
+ td_ns::TypePairDefinedEntry<argTy, bool , outTy, std::uint32_t >,
1673
+ td_ns::TypePairDefinedEntry<argTy, bool , outTy, std::int64_t >,
1674
+ td_ns::TypePairDefinedEntry<argTy, bool , outTy, std::uint64_t >,
1675
+
1676
+ // input int8_t
1677
+ td_ns::TypePairDefinedEntry<argTy, std::int8_t , outTy, std::int8_t >,
1678
+ td_ns::TypePairDefinedEntry<argTy, std::int8_t , outTy, std::int16_t >,
1679
+ td_ns::TypePairDefinedEntry<argTy, std::int8_t , outTy, std::int32_t >,
1680
+ td_ns::TypePairDefinedEntry<argTy, std::int8_t , outTy, std::int64_t >,
1681
+
1682
+ // input uint8_t
1683
+ td_ns::TypePairDefinedEntry<argTy, std::uint8_t , outTy, std::uint8_t >,
1684
+ td_ns::TypePairDefinedEntry<argTy, std::uint8_t , outTy, std::int16_t >,
1685
+ td_ns::TypePairDefinedEntry<argTy, std::uint8_t , outTy, std::uint16_t >,
1686
+ td_ns::TypePairDefinedEntry<argTy, std::uint8_t , outTy, std::int32_t >,
1687
+ td_ns::TypePairDefinedEntry<argTy, std::uint8_t , outTy, std::uint32_t >,
1688
+ td_ns::TypePairDefinedEntry<argTy, std::uint8_t , outTy, std::int64_t >,
1689
+ td_ns::TypePairDefinedEntry<argTy, std::uint8_t , outTy, std::uint64_t >,
1690
+
1691
+ // input int16_t
1692
+ td_ns::TypePairDefinedEntry<argTy, std::int16_t , outTy, std::int16_t >,
1693
+ td_ns::TypePairDefinedEntry<argTy, std::int16_t , outTy, std::int32_t >,
1694
+ td_ns::TypePairDefinedEntry<argTy, std::int16_t , outTy, std::int64_t >,
1695
+
1696
+ // input uint16_t
1697
+ td_ns::TypePairDefinedEntry<argTy, std::uint16_t , outTy, std::uint16_t >,
1698
+ td_ns::TypePairDefinedEntry<argTy, std::uint16_t , outTy, std::int32_t >,
1699
+ td_ns::TypePairDefinedEntry<argTy, std::uint16_t , outTy, std::uint32_t >,
1700
+ td_ns::TypePairDefinedEntry<argTy, std::uint16_t , outTy, std::int64_t >,
1701
+ td_ns::TypePairDefinedEntry<argTy, std::uint16_t , outTy, std::uint64_t >,
1702
+
1703
+ // input int32_t
1704
+ td_ns::TypePairDefinedEntry<argTy, std::int32_t , outTy, std::int32_t >,
1705
+ td_ns::TypePairDefinedEntry<argTy, std::int32_t , outTy, std::int64_t >,
1706
+
1707
+ // input uint32_t
1708
+ td_ns::TypePairDefinedEntry<argTy, std::uint32_t , outTy, std::uint32_t >,
1709
+ td_ns::TypePairDefinedEntry<argTy, std::uint32_t , outTy, std::uint64_t >,
1710
+
1711
+ // input int64_t
1712
+ td_ns::TypePairDefinedEntry<argTy, std::int64_t , outTy, std::int64_t >,
1713
+
1714
+ // input uint32_t
1715
+ td_ns::TypePairDefinedEntry<argTy, std::uint64_t , outTy, std::uint64_t >,
1716
+
1717
+ // input half
1718
+ td_ns::TypePairDefinedEntry<argTy, sycl::half, outTy, sycl::half>,
1719
+ td_ns::TypePairDefinedEntry<argTy, sycl::half, outTy, float >,
1720
+ td_ns::TypePairDefinedEntry<argTy, sycl::half, outTy, double >,
1721
+ td_ns::
1722
+ TypePairDefinedEntry<argTy, sycl::half, outTy, std::complex<float >>,
1723
+ td_ns::TypePairDefinedEntry<argTy,
1724
+ sycl::half,
1725
+ outTy,
1726
+ std::complex<double >>,
1727
+
1728
+ // input float
1729
+ td_ns::TypePairDefinedEntry<argTy, float , outTy, float >,
1730
+ td_ns::TypePairDefinedEntry<argTy, float , outTy, double >,
1731
+ td_ns::TypePairDefinedEntry<argTy, float , outTy, std::complex<float >>,
1732
+ td_ns::TypePairDefinedEntry<argTy, float , outTy, std::complex<double >>,
1733
+
1734
+ // input double
1735
+ td_ns::TypePairDefinedEntry<argTy, double , outTy, double >,
1736
+ td_ns::TypePairDefinedEntry<argTy, double , outTy, std::complex<double >>,
1737
+
1738
+ // input std::complex
1739
+ td_ns::TypePairDefinedEntry<argTy,
1740
+ std::complex<float >,
1741
+ outTy,
1742
+ std::complex<float >>,
1743
+ td_ns::TypePairDefinedEntry<argTy,
1744
+ std::complex<float >,
1745
+ outTy,
1746
+ std::complex<double >>,
1747
+
1748
+ td_ns::TypePairDefinedEntry<argTy,
1749
+ std::complex<double >,
1750
+ outTy,
1751
+ std::complex<double >>,
1752
+
1753
+ // fall-throug
1754
+ td_ns::NotDefinedEntry>::is_defined;
1755
+ };
1756
+
1757
+ template <typename fnT, typename srcTy, typename dstTy>
1758
+ struct SumOverAxisAtomicStridedFactory
1759
+ {
1760
+ fnT get () const
1761
+ {
1762
+ if constexpr (TypePairSupportDataForSumReductionAtomic<
1763
+ srcTy, dstTy>::is_defined)
1764
+ {
1765
+ using ReductionOpT = sycl::plus<dstTy>;
1766
+ return dpctl::tensor::kernels::
1767
+ reduction_over_group_with_atomics_strided_impl<srcTy, dstTy,
1768
+ ReductionOpT>;
1769
+ }
1770
+ else {
1771
+ return nullptr ;
1772
+ }
1773
+ }
1774
+ };
1775
+
1776
+ template <typename fnT, typename srcTy, typename dstTy>
1777
+ struct SumOverAxisTempsStridedFactory
1778
+ {
1779
+ fnT get () const
1780
+ {
1781
+ if constexpr (TypePairSupportDataForSumReductionTemps<
1782
+ srcTy, dstTy>::is_defined) {
1783
+ using ReductionOpT = sycl::plus<dstTy>;
1784
+ return dpctl::tensor::kernels::
1785
+ reduction_over_group_temps_strided_impl<srcTy, dstTy,
1786
+ ReductionOpT>;
1787
+ }
1788
+ else {
1789
+ return nullptr ;
1790
+ }
1791
+ }
1792
+ };
1793
+
1794
+ template <typename fnT, typename srcTy, typename dstTy>
1795
+ struct SumOverAxis1AtomicContigFactory
1796
+ {
1797
+ fnT get () const
1798
+ {
1799
+ if constexpr (TypePairSupportDataForSumReductionAtomic<
1800
+ srcTy, dstTy>::is_defined)
1801
+ {
1802
+ using ReductionOpT = sycl::plus<dstTy>;
1803
+ return dpctl::tensor::kernels::
1804
+ reduction_axis1_over_group_with_atomics_contig_impl<
1805
+ srcTy, dstTy, ReductionOpT>;
1806
+ }
1807
+ else {
1808
+ return nullptr ;
1809
+ }
1810
+ }
1811
+ };
1812
+
1813
+ template <typename fnT, typename srcTy, typename dstTy>
1814
+ struct SumOverAxis0AtomicContigFactory
1815
+ {
1816
+ fnT get () const
1817
+ {
1818
+ if constexpr (TypePairSupportDataForSumReductionAtomic<
1819
+ srcTy, dstTy>::is_defined)
1820
+ {
1821
+ using ReductionOpT = sycl::plus<dstTy>;
1822
+ return dpctl::tensor::kernels::
1823
+ reduction_axis0_over_group_with_atomics_contig_impl<
1824
+ srcTy, dstTy, ReductionOpT>;
1825
+ }
1826
+ else {
1827
+ return nullptr ;
1828
+ }
1829
+ }
1830
+ };
1831
+
1589
1832
// Argmax and Argmin
1590
1833
1591
1834
/* = Search reduction using reduce_over_group*/
0 commit comments