Skip to content

Commit 257bc03

Browse files
committed
Refactors sum to use generic reduction templates
1 parent d1de259 commit 257bc03

File tree

9 files changed

+451
-1806
lines changed

9 files changed

+451
-1806
lines changed

dpctl/tensor/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,6 @@ pybind11_add_module(${python_module_name} MODULE
4949
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/boolean_reductions.cpp
5050
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/device_support_queries.cpp
5151
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions.cpp
52-
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sum_reductions.cpp
5352
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/repeat.cpp
5453
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reduction_over_axis.cpp
5554
)

dpctl/tensor/libtensor/include/kernels/reductions.hpp

Lines changed: 243 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1586,6 +1586,249 @@ struct MinOverAxis0AtomicContigFactory
15861586
}
15871587
};
15881588

1589+
// Sum
1590+
1591+
/* @brief Types supported by plus-reduction code based on atomic_ref */
1592+
template <typename argTy, typename outTy>
1593+
struct TypePairSupportDataForSumReductionAtomic
1594+
{
1595+
1596+
/* value if true a kernel for <argTy, outTy> must be instantiated, false
1597+
* otherwise */
1598+
static constexpr bool is_defined = std::disjunction< // disjunction is C++17
1599+
// feature, supported
1600+
// by DPC++ input bool
1601+
td_ns::TypePairDefinedEntry<argTy, bool, outTy, std::int32_t>,
1602+
td_ns::TypePairDefinedEntry<argTy, bool, outTy, std::uint32_t>,
1603+
td_ns::TypePairDefinedEntry<argTy, bool, outTy, std::int64_t>,
1604+
td_ns::TypePairDefinedEntry<argTy, bool, outTy, std::uint64_t>,
1605+
td_ns::TypePairDefinedEntry<argTy, bool, outTy, float>,
1606+
td_ns::TypePairDefinedEntry<argTy, bool, outTy, double>,
1607+
// input int8
1608+
td_ns::TypePairDefinedEntry<argTy, std::int8_t, outTy, std::int32_t>,
1609+
td_ns::TypePairDefinedEntry<argTy, std::int8_t, outTy, std::int64_t>,
1610+
td_ns::TypePairDefinedEntry<argTy, std::int8_t, outTy, float>,
1611+
td_ns::TypePairDefinedEntry<argTy, std::int8_t, outTy, double>,
1612+
// input uint8
1613+
td_ns::TypePairDefinedEntry<argTy, std::uint8_t, outTy, std::int32_t>,
1614+
td_ns::TypePairDefinedEntry<argTy, std::uint8_t, outTy, std::uint32_t>,
1615+
td_ns::TypePairDefinedEntry<argTy, std::uint8_t, outTy, std::int64_t>,
1616+
td_ns::TypePairDefinedEntry<argTy, std::uint8_t, outTy, std::uint64_t>,
1617+
td_ns::TypePairDefinedEntry<argTy, std::uint8_t, outTy, float>,
1618+
td_ns::TypePairDefinedEntry<argTy, std::uint8_t, outTy, double>,
1619+
// input int16
1620+
td_ns::TypePairDefinedEntry<argTy, std::int16_t, outTy, std::int32_t>,
1621+
td_ns::TypePairDefinedEntry<argTy, std::int16_t, outTy, std::int64_t>,
1622+
td_ns::TypePairDefinedEntry<argTy, std::int16_t, outTy, float>,
1623+
td_ns::TypePairDefinedEntry<argTy, std::int16_t, outTy, double>,
1624+
// input uint16
1625+
td_ns::TypePairDefinedEntry<argTy, std::uint16_t, outTy, std::int32_t>,
1626+
td_ns::TypePairDefinedEntry<argTy, std::uint16_t, outTy, std::uint32_t>,
1627+
td_ns::TypePairDefinedEntry<argTy, std::uint16_t, outTy, std::int64_t>,
1628+
td_ns::TypePairDefinedEntry<argTy, std::uint16_t, outTy, std::uint64_t>,
1629+
td_ns::TypePairDefinedEntry<argTy, std::uint16_t, outTy, float>,
1630+
td_ns::TypePairDefinedEntry<argTy, std::uint16_t, outTy, double>,
1631+
// input int32
1632+
td_ns::TypePairDefinedEntry<argTy, std::int32_t, outTy, std::int32_t>,
1633+
td_ns::TypePairDefinedEntry<argTy, std::int32_t, outTy, std::int64_t>,
1634+
td_ns::TypePairDefinedEntry<argTy, std::int32_t, outTy, float>,
1635+
td_ns::TypePairDefinedEntry<argTy, std::int32_t, outTy, double>,
1636+
// input uint32
1637+
td_ns::TypePairDefinedEntry<argTy, std::uint32_t, outTy, std::uint32_t>,
1638+
td_ns::TypePairDefinedEntry<argTy, std::uint32_t, outTy, std::int64_t>,
1639+
td_ns::TypePairDefinedEntry<argTy, std::uint32_t, outTy, std::uint64_t>,
1640+
td_ns::TypePairDefinedEntry<argTy, std::uint32_t, outTy, float>,
1641+
td_ns::TypePairDefinedEntry<argTy, std::uint32_t, outTy, double>,
1642+
// input int64
1643+
td_ns::TypePairDefinedEntry<argTy, std::int64_t, outTy, std::int64_t>,
1644+
td_ns::TypePairDefinedEntry<argTy, std::int64_t, outTy, double>,
1645+
// input uint64
1646+
td_ns::TypePairDefinedEntry<argTy, std::uint64_t, outTy, std::uint64_t>,
1647+
td_ns::TypePairDefinedEntry<argTy, std::uint64_t, outTy, double>,
1648+
// input half
1649+
td_ns::TypePairDefinedEntry<argTy, sycl::half, outTy, float>,
1650+
td_ns::TypePairDefinedEntry<argTy, float, outTy, double>,
1651+
// input float
1652+
td_ns::TypePairDefinedEntry<argTy, float, outTy, float>,
1653+
td_ns::TypePairDefinedEntry<argTy, float, outTy, double>,
1654+
// input double
1655+
td_ns::TypePairDefinedEntry<argTy, double, outTy, double>,
1656+
// fall-through
1657+
td_ns::NotDefinedEntry>::is_defined;
1658+
};
1659+
1660+
template <typename argTy, typename outTy>
1661+
struct TypePairSupportDataForSumReductionTemps
1662+
{
1663+
1664+
static constexpr bool is_defined = std::disjunction< // disjunction is C++17
1665+
// feature, supported
1666+
// by DPC++ input bool
1667+
td_ns::TypePairDefinedEntry<argTy, bool, outTy, std::int8_t>,
1668+
td_ns::TypePairDefinedEntry<argTy, bool, outTy, std::uint8_t>,
1669+
td_ns::TypePairDefinedEntry<argTy, bool, outTy, std::int16_t>,
1670+
td_ns::TypePairDefinedEntry<argTy, bool, outTy, std::uint16_t>,
1671+
td_ns::TypePairDefinedEntry<argTy, bool, outTy, std::int32_t>,
1672+
td_ns::TypePairDefinedEntry<argTy, bool, outTy, std::uint32_t>,
1673+
td_ns::TypePairDefinedEntry<argTy, bool, outTy, std::int64_t>,
1674+
td_ns::TypePairDefinedEntry<argTy, bool, outTy, std::uint64_t>,
1675+
1676+
// input int8_t
1677+
td_ns::TypePairDefinedEntry<argTy, std::int8_t, outTy, std::int8_t>,
1678+
td_ns::TypePairDefinedEntry<argTy, std::int8_t, outTy, std::int16_t>,
1679+
td_ns::TypePairDefinedEntry<argTy, std::int8_t, outTy, std::int32_t>,
1680+
td_ns::TypePairDefinedEntry<argTy, std::int8_t, outTy, std::int64_t>,
1681+
1682+
// input uint8_t
1683+
td_ns::TypePairDefinedEntry<argTy, std::uint8_t, outTy, std::uint8_t>,
1684+
td_ns::TypePairDefinedEntry<argTy, std::uint8_t, outTy, std::int16_t>,
1685+
td_ns::TypePairDefinedEntry<argTy, std::uint8_t, outTy, std::uint16_t>,
1686+
td_ns::TypePairDefinedEntry<argTy, std::uint8_t, outTy, std::int32_t>,
1687+
td_ns::TypePairDefinedEntry<argTy, std::uint8_t, outTy, std::uint32_t>,
1688+
td_ns::TypePairDefinedEntry<argTy, std::uint8_t, outTy, std::int64_t>,
1689+
td_ns::TypePairDefinedEntry<argTy, std::uint8_t, outTy, std::uint64_t>,
1690+
1691+
// input int16_t
1692+
td_ns::TypePairDefinedEntry<argTy, std::int16_t, outTy, std::int16_t>,
1693+
td_ns::TypePairDefinedEntry<argTy, std::int16_t, outTy, std::int32_t>,
1694+
td_ns::TypePairDefinedEntry<argTy, std::int16_t, outTy, std::int64_t>,
1695+
1696+
// input uint16_t
1697+
td_ns::TypePairDefinedEntry<argTy, std::uint16_t, outTy, std::uint16_t>,
1698+
td_ns::TypePairDefinedEntry<argTy, std::uint16_t, outTy, std::int32_t>,
1699+
td_ns::TypePairDefinedEntry<argTy, std::uint16_t, outTy, std::uint32_t>,
1700+
td_ns::TypePairDefinedEntry<argTy, std::uint16_t, outTy, std::int64_t>,
1701+
td_ns::TypePairDefinedEntry<argTy, std::uint16_t, outTy, std::uint64_t>,
1702+
1703+
// input int32_t
1704+
td_ns::TypePairDefinedEntry<argTy, std::int32_t, outTy, std::int32_t>,
1705+
td_ns::TypePairDefinedEntry<argTy, std::int32_t, outTy, std::int64_t>,
1706+
1707+
// input uint32_t
1708+
td_ns::TypePairDefinedEntry<argTy, std::uint32_t, outTy, std::uint32_t>,
1709+
td_ns::TypePairDefinedEntry<argTy, std::uint32_t, outTy, std::uint64_t>,
1710+
1711+
// input int64_t
1712+
td_ns::TypePairDefinedEntry<argTy, std::int64_t, outTy, std::int64_t>,
1713+
1714+
// input uint32_t
1715+
td_ns::TypePairDefinedEntry<argTy, std::uint64_t, outTy, std::uint64_t>,
1716+
1717+
// input half
1718+
td_ns::TypePairDefinedEntry<argTy, sycl::half, outTy, sycl::half>,
1719+
td_ns::TypePairDefinedEntry<argTy, sycl::half, outTy, float>,
1720+
td_ns::TypePairDefinedEntry<argTy, sycl::half, outTy, double>,
1721+
td_ns::
1722+
TypePairDefinedEntry<argTy, sycl::half, outTy, std::complex<float>>,
1723+
td_ns::TypePairDefinedEntry<argTy,
1724+
sycl::half,
1725+
outTy,
1726+
std::complex<double>>,
1727+
1728+
// input float
1729+
td_ns::TypePairDefinedEntry<argTy, float, outTy, float>,
1730+
td_ns::TypePairDefinedEntry<argTy, float, outTy, double>,
1731+
td_ns::TypePairDefinedEntry<argTy, float, outTy, std::complex<float>>,
1732+
td_ns::TypePairDefinedEntry<argTy, float, outTy, std::complex<double>>,
1733+
1734+
// input double
1735+
td_ns::TypePairDefinedEntry<argTy, double, outTy, double>,
1736+
td_ns::TypePairDefinedEntry<argTy, double, outTy, std::complex<double>>,
1737+
1738+
// input std::complex
1739+
td_ns::TypePairDefinedEntry<argTy,
1740+
std::complex<float>,
1741+
outTy,
1742+
std::complex<float>>,
1743+
td_ns::TypePairDefinedEntry<argTy,
1744+
std::complex<float>,
1745+
outTy,
1746+
std::complex<double>>,
1747+
1748+
td_ns::TypePairDefinedEntry<argTy,
1749+
std::complex<double>,
1750+
outTy,
1751+
std::complex<double>>,
1752+
1753+
// fall-throug
1754+
td_ns::NotDefinedEntry>::is_defined;
1755+
};
1756+
1757+
template <typename fnT, typename srcTy, typename dstTy>
1758+
struct SumOverAxisAtomicStridedFactory
1759+
{
1760+
fnT get() const
1761+
{
1762+
if constexpr (TypePairSupportDataForSumReductionAtomic<
1763+
srcTy, dstTy>::is_defined)
1764+
{
1765+
using ReductionOpT = sycl::plus<dstTy>;
1766+
return dpctl::tensor::kernels::
1767+
reduction_over_group_with_atomics_strided_impl<srcTy, dstTy,
1768+
ReductionOpT>;
1769+
}
1770+
else {
1771+
return nullptr;
1772+
}
1773+
}
1774+
};
1775+
1776+
template <typename fnT, typename srcTy, typename dstTy>
1777+
struct SumOverAxisTempsStridedFactory
1778+
{
1779+
fnT get() const
1780+
{
1781+
if constexpr (TypePairSupportDataForSumReductionTemps<
1782+
srcTy, dstTy>::is_defined) {
1783+
using ReductionOpT = sycl::plus<dstTy>;
1784+
return dpctl::tensor::kernels::
1785+
reduction_over_group_temps_strided_impl<srcTy, dstTy,
1786+
ReductionOpT>;
1787+
}
1788+
else {
1789+
return nullptr;
1790+
}
1791+
}
1792+
};
1793+
1794+
template <typename fnT, typename srcTy, typename dstTy>
1795+
struct SumOverAxis1AtomicContigFactory
1796+
{
1797+
fnT get() const
1798+
{
1799+
if constexpr (TypePairSupportDataForSumReductionAtomic<
1800+
srcTy, dstTy>::is_defined)
1801+
{
1802+
using ReductionOpT = sycl::plus<dstTy>;
1803+
return dpctl::tensor::kernels::
1804+
reduction_axis1_over_group_with_atomics_contig_impl<
1805+
srcTy, dstTy, ReductionOpT>;
1806+
}
1807+
else {
1808+
return nullptr;
1809+
}
1810+
}
1811+
};
1812+
1813+
template <typename fnT, typename srcTy, typename dstTy>
1814+
struct SumOverAxis0AtomicContigFactory
1815+
{
1816+
fnT get() const
1817+
{
1818+
if constexpr (TypePairSupportDataForSumReductionAtomic<
1819+
srcTy, dstTy>::is_defined)
1820+
{
1821+
using ReductionOpT = sycl::plus<dstTy>;
1822+
return dpctl::tensor::kernels::
1823+
reduction_axis0_over_group_with_atomics_contig_impl<
1824+
srcTy, dstTy, ReductionOpT>;
1825+
}
1826+
else {
1827+
return nullptr;
1828+
}
1829+
}
1830+
};
1831+
15891832
// Argmax and Argmin
15901833

15911834
/* = Search reduction using reduce_over_group*/

0 commit comments

Comments
 (0)