Skip to content

Commit

Permalink
fix wrong A/BLayout in MMA_Traits for binary mma and append other MMA…
Browse files Browse the repository at this point in the history
…_Traits support (NVIDIA#1856)

* fix the wrong ALayout/BLayout in MMA_Traits<SM80_16x8x256_S32U1U1S32_TN_XORPOPC> and add MMA_Traits specializations for the m8n8k128 and m16n8k128 mma.and.popc instructions

* add a "print" overload for subbyte_reference<T>
  • Loading branch information
CalebDu authored Oct 24, 2024
1 parent be692b4 commit a424ca6
Show file tree
Hide file tree
Showing 3 changed files with 155 additions and 4 deletions.
99 changes: 99 additions & 0 deletions include/cute/arch/mma_sm80.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2141,4 +2141,103 @@ struct SM80_16x8x256_S32U1U1S32_TN_XORPOPC

////////////////////////////////////////////////////////////////////////////////////////////////////

// MMA 8x8x128 TN
//
// Wraps the PTX instruction
//   mma.sync.aligned.m8n8k128.row.col.s32.b1.b1.s32.and.popc
// Register footprint per call: D and C take two 32-bit registers each,
// A and B take one 32-bit register each.
struct SM80_8x8x128_S32U1U1S32_TN_ANDPOPC
{
  using DRegisters = uint32_t[2];
  using ARegisters = uint32_t[1];
  using BRegisters = uint32_t[1];
  using CRegisters = uint32_t[2];

  // Issues a single MMA instruction.
  //   d0, d1 : outputs written by the instruction
  //   a0     : A operand register (read only)
  //   b0     : B operand register (read only)
  //   c0, c1 : C operand registers (read only)
  // When CUTE_ARCH_MMA_B1_AND_SM80_ENABLED is not defined the instruction is
  // not emitted and CUTE_INVALID_CONTROL_PATH is invoked instead.
  CUTE_HOST_DEVICE static void
  fma(uint32_t      & d0, uint32_t      & d1,
      uint32_t const& a0,
      uint32_t const& b0,
      uint32_t const& c0, uint32_t const& c1)
  {
#if defined(CUTE_ARCH_MMA_B1_AND_SM80_ENABLED)
    asm volatile(
      "mma.sync.aligned.m8n8k128.row.col.s32.b1.b1.s32.and.popc "
      "{%0, %1},"
      "{%2},"
      "{%3},"
      "{%4, %5};\n"
      : "=r"(d0), "=r"(d1)
      :  "r"(a0),
         "r"(b0),
         "r"(c0),  "r"(c1));
#else
    // The diagnostic names the macro that the #if above actually tests, so a
    // user hitting this path knows which feature switch is missing.
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x128_S32U1U1S32_TN_ANDPOPC without CUTE_ARCH_MMA_B1_AND_SM80_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// MMA 16x8x128 TN
//
// Wraps the PTX instruction
//   mma.sync.aligned.m16n8k128.row.col.s32.b1.b1.s32.and.popc
// Register footprint per call: D and C take four 32-bit registers each,
// A takes two and B takes one.
struct SM80_16x8x128_S32U1U1S32_TN_ANDPOPC
{
  using DRegisters = uint32_t[4];
  using ARegisters = uint32_t[2];
  using BRegisters = uint32_t[1];
  using CRegisters = uint32_t[4];

  // Issues a single MMA instruction.
  //   d0..d3 : outputs written by the instruction
  //   a0, a1 : A operand registers (read only)
  //   b0     : B operand register (read only)
  //   c0..c3 : C operand registers (read only)
  // When CUTE_ARCH_MMA_B1_AND_SM80_ENABLED is not defined the instruction is
  // not emitted and CUTE_INVALID_CONTROL_PATH is invoked instead.
  CUTE_HOST_DEVICE static void
  fma(uint32_t      & d0, uint32_t      & d1, uint32_t      & d2, uint32_t      & d3,
      uint32_t const& a0, uint32_t const& a1,
      uint32_t const& b0,
      uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3)
  {
#if defined(CUTE_ARCH_MMA_B1_AND_SM80_ENABLED)
    asm volatile(
      "mma.sync.aligned.m16n8k128.row.col.s32.b1.b1.s32.and.popc "
      "{%0, %1, %2, %3},"
      "{%4, %5},"
      "{%6},"
      "{%7, %8, %9, %10};\n"
      : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3)
      :  "r"(a0),  "r"(a1),
         "r"(b0),
         "r"(c0),  "r"(c1),  "r"(c2),  "r"(c3));
#else
    // The diagnostic names the macro that the #if above actually tests, so a
    // user hitting this path knows which feature switch is missing.
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x128_S32U1U1S32_TN_ANDPOPC without CUTE_ARCH_MMA_B1_AND_SM80_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// MMA 16x8x256 TN
//
// Wraps the PTX instruction
//   mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.and.popc
// Register footprint per call: D, A and C take four 32-bit registers each,
// B takes two.
struct SM80_16x8x256_S32U1U1S32_TN_ANDPOPC
{
  using DRegisters = uint32_t[4];
  using ARegisters = uint32_t[4];
  using BRegisters = uint32_t[2];
  using CRegisters = uint32_t[4];

  // Issues a single MMA instruction.
  //   d0..d3 : outputs written by the instruction
  //   a0..a3 : A operand registers (read only)
  //   b0, b1 : B operand registers (read only)
  //   c0..c3 : C operand registers (read only)
  // When CUTE_ARCH_MMA_B1_AND_SM80_ENABLED is not defined the instruction is
  // not emitted and CUTE_INVALID_CONTROL_PATH is invoked instead.
  CUTE_HOST_DEVICE static void
  fma(uint32_t      & d0, uint32_t      & d1, uint32_t      & d2, uint32_t      & d3,
      uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3,
      uint32_t const& b0, uint32_t const& b1,
      uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3)
  {
#if defined(CUTE_ARCH_MMA_B1_AND_SM80_ENABLED)
    asm volatile(
      "mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.and.popc "
      "{%0, %1, %2, %3},"
      "{%4, %5, %6, %7},"
      "{%8, %9},"
      "{%10, %11, %12, %13};\n"
      : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3)
      :  "r"(a0),  "r"(a1),  "r"(a2),  "r"(a3),
         "r"(b0),  "r"(b1),
         "r"(c0),  "r"(c1),  "r"(c2),  "r"(c3));
#else
    // The diagnostic names the macro that the #if above actually tests, so a
    // user hitting this path knows which feature switch is missing.
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x256_S32U1U1S32_TN_ANDPOPC without CUTE_ARCH_MMA_B1_AND_SM80_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace cute
55 changes: 51 additions & 4 deletions include/cute/atom/mma_traits_sm80.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -433,10 +433,57 @@ struct MMA_Traits<SM80_16x8x256_S32U1U1S32_TN_XORPOPC>

using Shape_MNK = Shape<_16,_8,_256>;
using ThrID = Layout<_32>;
using ALayout = Layout<Shape <_32,Shape < _8, _4,_2, _2>>,
Stride<_64,Stride<_64,_16,_8,_2048>>>;
using BLayout = Layout<Shape <_32,Shape <_32, _2>>,
Stride<_32,Stride< _1,_1024>>>;
using ALayout = Layout<Shape<Shape<_4,_8>,Shape<_32,_2,_2>>,
Stride<Stride<_512,_1>,Stride<_16,_8,_2048>>>;
using BLayout = Layout<Shape<Shape <_4,_8>,Shape<_32,_2>>,
Stride<Stride<_256,_1>,Stride< _8,_1024>>>;
using CLayout = SM80_16x8_Row;
};

// The AND.POPC 16x8x256 atom takes every trait (value types, Shape_MNK, ThrID
// and the A/B/C layouts) from the XOR.POPC specialization of the same shape
// by inheriting from it; nothing is overridden here.
template <>
struct MMA_Traits<SM80_16x8x256_S32U1U1S32_TN_ANDPOPC>
    : MMA_Traits<SM80_16x8x256_S32U1U1S32_TN_XORPOPC>
{
};

// Traits of the 8x8x128 binary (1-bit) XOR.POPC MMA atom.
template <>
struct MMA_Traits<SM80_8x8x128_S32U1U1S32_TN_XORPOPC>
{
  // Logical element types: 1-bit A/B operands, 32-bit integer C/D.
  using ValTypeD = int32_t;
  using ValTypeA = cute::uint1b_t;
  using ValTypeB = cute::uint1b_t;
  using ValTypeC = int32_t;

  // Tile extent (M, N, K) and the 32 thread ids taking part in the atom.
  using Shape_MNK = Shape<_8, _8, _128>;
  using ThrID     = Layout<_32>;

  // A and B use the same (thread, value) layout: threads are split as (4, 8)
  // with strides (256, 1), and each thread holds 32 values with stride 8.
  // NOTE(review): presumably these offsets are column-major coordinates into
  // the 8x128 operand tile; confirm against the PTX m8n8k128 fragment layout.
  using ALayout = Layout<Shape <Shape <  _4, _8>, _32>,
                         Stride<Stride<_256, _1>,  _8>>;
  using BLayout = Layout<Shape <Shape <  _4, _8>, _32>,
                         Stride<Stride<_256, _1>,  _8>>;
  using CLayout = SM80_8x8_Row;
};

// The AND.POPC 8x8x128 atom takes every trait (value types, Shape_MNK, ThrID
// and the A/B/C layouts) from the XOR.POPC specialization of the same shape
// by inheriting from it; nothing is overridden here.
template <>
struct MMA_Traits<SM80_8x8x128_S32U1U1S32_TN_ANDPOPC>
    : MMA_Traits<SM80_8x8x128_S32U1U1S32_TN_XORPOPC>
{
};

// Traits of the 16x8x128 binary (1-bit) XOR.POPC MMA atom.
template <>
struct MMA_Traits<SM80_16x8x128_S32U1U1S32_TN_XORPOPC>
{
  // Logical element types: 1-bit A/B operands, 32-bit integer C/D.
  using ValTypeD = int32_t;
  using ValTypeA = cute::uint1b_t;
  using ValTypeB = cute::uint1b_t;
  using ValTypeC = int32_t;

  // Tile extent (M, N, K) and the 32 thread ids taking part in the atom.
  using Shape_MNK = Shape<_16, _8, _128>;
  using ThrID     = Layout<_32>;

  // A: threads split as (4, 8) with strides (512, 1); each thread holds
  // (32, 2) values with strides (16, 8). The value-mode stride must be the
  // flat pair Stride<_16,_8> so that it is congruent with Shape<_32,_2>; an
  // extra nested Stride<> wrapper makes shape and stride ranks disagree.
  // NOTE(review): presumably these offsets are column-major coordinates into
  // the 16x128 A tile; confirm against the PTX m16n8k128 fragment layout.
  using ALayout = Layout<Shape <Shape <  _4, _8>, Shape <_32, _2>>,
                         Stride<Stride<_512, _1>, Stride<_16, _8>>>;
  // B: threads split as (4, 8) with strides (256, 1); 32 values with stride 8.
  using BLayout = Layout<Shape <Shape <  _4, _8>, _32>,
                         Stride<Stride<_256, _1>,  _8>>;
  using CLayout = SM80_16x8_Row;
};

// The AND.POPC 16x8x128 atom takes every trait (value types, Shape_MNK, ThrID
// and the A/B/C layouts) from the XOR.POPC specialization of the same shape
// by inheriting from it; nothing is overridden here.
template <>
struct MMA_Traits<SM80_16x8x128_S32U1U1S32_TN_ANDPOPC>
    : MMA_Traits<SM80_16x8x128_S32U1U1S32_TN_XORPOPC>
{
};

} // end namespace cute
5 changes: 5 additions & 0 deletions include/cute/container/array_subbyte.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,11 @@ print(subbyte_iterator<T> const& x) {
printf("subptr[%db](%p.%u)", int(sizeof_bits_v<T>), x.ptr_, x.idx_);
}

// Prints the element a subbyte_reference refers to: the value is read with
// x.get() and handed to whichever print overload matches that value.
template <class T>
CUTE_HOST_DEVICE void
print(subbyte_reference<T> const& x) {
print(x.get());
}
//
// array_subbyte
// Statically sized array for non-byte-aligned data types
Expand Down

0 comments on commit a424ca6

Please sign in to comment.