forked from NVIDIA/cutlass
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsm100_mma_warpspecialized.hpp
723 lines (614 loc) · 29.9 KB
/
sm100_mma_warpspecialized.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
/***************************************************************************************************
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/detail/collective.hpp"
#include "cutlass/detail/cluster.hpp"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/numeric_types.h"
#include "cutlass/pipeline/pipeline.hpp"
#include "cutlass/gemm/gemm.h"
#include "cutlass/trace.h"
#include "cutlass/kernel_hardware_info.hpp"
#include "cutlass/detail/sm100_tmem_helper.hpp"
#include "cute/algorithm/functional.hpp"
#include "cute/arch/cluster_sm90.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cute/algorithm/gemm.hpp"
#include "cute/tensor_predicate.hpp"
#include "cute/numeric/arithmetic_tuple.hpp"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass::gemm::collective {
using namespace cute;
/////////////////////////////////////////////////////////////////////////////////////////////////
// WarpSpecialized Mainloop (SM100: TMA loads into smem, UMMA consuming A and B from smem descriptors)
// Both DMA Load and MMA methods of this class must be run by a single thread that's picked by elect_one
// NOTE(review): load() additionally gates its TMA copies with cute::elect_one_sync(), which suggests it is
// entered by more than one thread (e.g. a whole warp) -- confirm the calling convention at the kernel layer.
template <
  int Stages,
  int SchedulerPipelineStageCount,
  int AccumulatorPipelineStageCount,
  class ClusterShape,   // Static cluster shape or dynamic (int, int, _1)
  class TileShape_,     // (MmaAtomShapeM, MmaAtomShapeN, TileK)
  class ElementA_,
  class StrideA_,
  class ElementB_,
  class StrideB_,
  class TiledMma_,
  class GmemTiledCopyA_,
  class SmemLayoutAtomA_,
  class SmemCopyAtomA_,
  class TransformA_,
  class GmemTiledCopyB_,
  class SmemLayoutAtomB_,
  class SmemCopyAtomB_,
  class TransformB_>
struct CollectiveMma<
    MainloopSm100TmaUmmaWarpSpecialized<
      Stages,
      SchedulerPipelineStageCount,
      AccumulatorPipelineStageCount,
      ClusterShape>,
    TileShape_,
    ElementA_,
    StrideA_,
    ElementB_,
    StrideB_,
    TiledMma_,
    GmemTiledCopyA_,
    SmemLayoutAtomA_,
    SmemCopyAtomA_,
    TransformA_,
    GmemTiledCopyB_,
    SmemLayoutAtomB_,
    SmemCopyAtomB_,
    TransformB_>
{
  //
  // Type Aliases
  //
  using TiledMma = TiledMma_;
  // Number of CTAs cooperating on one MMA atom along M (1 for 1SM, 2 for 2SM UMMA; see TMA asserts below).
  using AtomThrShapeMNK = Shape<decltype(shape<0>(typename TiledMma::ThrLayoutVMNK{})), _1, _1>;

  using DispatchPolicy = MainloopSm100TmaUmmaWarpSpecialized<
                          Stages,
                          SchedulerPipelineStageCount,
                          AccumulatorPipelineStageCount,
                          ClusterShape>;
  using TileShape = TileShape_;

  // A cluster shape that is not fully static is resolved at runtime (see select_cluster_shape below).
  static constexpr bool IsDynamicCluster = not cute::is_static_v<ClusterShape>;

  CUTE_STATIC_ASSERT_V(evenly_divides(TileShape{}, tile_shape(TiledMma{})),
                       "Static cluster shape used: TileShape should be evenly divided by TiledMma");

  using CtaShape_MNK = decltype(shape_div(TileShape{}, AtomThrShapeMNK{}));

  // Define A and B block shapes for reduced size TMA_LOADs
  using MmaShapeA_MK = decltype(partition_shape_A(TiledMma{}, make_shape(size<0>(TileShape{}), size<2>(TileShape{}))));
  using MmaShapeB_NK = decltype(partition_shape_B(TiledMma{}, make_shape(size<1>(TileShape{}), size<2>(TileShape{}))));

  using ElementA = ElementA_;
  using ElementAMma = typename TiledMma::ValTypeA;
  using StrideA = StrideA_;
  using ElementB = ElementB_;
  using ElementBMma = typename TiledMma::ValTypeB;
  using StrideB = StrideB_;

  // Runtime-typed (f8f6f4) operands carry their concrete format in Arguments instead of in the type.
  static constexpr bool IsRuntimeDataTypeA = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4<ElementA>();
  static constexpr bool IsRuntimeDataTypeB = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4<ElementB>();

  static_assert((IsRuntimeDataTypeA && IsRuntimeDataTypeB) ||
                (!IsRuntimeDataTypeA && !IsRuntimeDataTypeB),
                "ElementA and ElementB should be both runtime or both static.");

  static constexpr bool IsRuntimeDataType = IsRuntimeDataTypeA && IsRuntimeDataTypeB;

  using ElementAccumulator = typename TiledMma::ValTypeC;
  using GmemTiledCopyA = GmemTiledCopyA_;
  using GmemTiledCopyB = GmemTiledCopyB_;
  using SmemLayoutAtomA = SmemLayoutAtomA_;
  using SmemLayoutAtomB = SmemLayoutAtomB_;
  using SmemCopyAtomA = SmemCopyAtomA_;
  using SmemCopyAtomB = SmemCopyAtomB_;
  using TransformA = TransformA_;
  using TransformB = TransformB_;
  using ArchTag = typename DispatchPolicy::ArchTag;

  using MainloopPipeline = cutlass::PipelineTmaUmmaAsync<
                             DispatchPolicy::Stages,
                             ClusterShape,
                             AtomThrShapeMNK>;
  using MainloopPipelineState = typename MainloopPipeline::PipelineState;

  static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtomA must be rank 2 (M,K)");
  static_assert(((size<0,0>(MmaShapeA_MK{}) * size<1>(MmaShapeA_MK{})) % size<0>(SmemLayoutAtomA{})) == 0,
                "SmemLayoutAtom must evenly divide tile shape.");
  static_assert(((size<0,1>(MmaShapeA_MK{}) * size<2>(MmaShapeA_MK{})) % size<1>(SmemLayoutAtomA{})) == 0,
                "SmemLayoutAtom must evenly divide tile shape.");
  static_assert(cute::is_void_v<SmemCopyAtomA>,
                "SM100 UMMA cannot have a non-void copy atom for smem sourced instructions.");

  static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtomB must be rank 2 (N,K)");
  static_assert(((size<0,0>(MmaShapeB_NK{}) * size<1>(MmaShapeB_NK{})) % size<0>(SmemLayoutAtomB{})) == 0,
                "SmemLayoutAtom must evenly divide tile shape.");
  static_assert(((size<0,1>(MmaShapeB_NK{}) * size<2>(MmaShapeB_NK{})) % size<1>(SmemLayoutAtomB{})) == 0,
                "SmemLayoutAtom must evenly divide tile shape.");
  static_assert(cute::is_void_v<SmemCopyAtomB>,
                "SM100 UMMA cannot have a non-void copy atom for smem sourced instructions.");

  // Tile along K mode first before tiling over MN. PIPE mode last as usual.
  // This maximizes TMA boxes due to better smem-K vectorization, reducing total issued TMAs.
  // ((MMA_TILE_M,MMA_TILE_K),MMA_M,MMA_K,PIPE)
  using SmemLayoutA = decltype(UMMA::tile_to_mma_shape(
      SmemLayoutAtomA{},
      append(MmaShapeA_MK{}, Int<DispatchPolicy::Stages>{}),
      cute::conditional_t<cutlass::gemm::detail::is_mn_major<StrideA>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}));
  // ((MMA_TILE_N,MMA_TILE_K),MMA_N,MMA_K,PIPE)
  using SmemLayoutB = decltype(UMMA::tile_to_mma_shape(
      SmemLayoutAtomB{},
      append(MmaShapeB_NK{}, Int<DispatchPolicy::Stages>{}),
      cute::conditional_t<cutlass::gemm::detail::is_mn_major<StrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}));

  // Diagnostic text now matches the asserted condition (>= 2).
  static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 2 or more.");
  static_assert(cute::is_base_of<cute::UMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value &&
                cute::is_base_of<cute::UMMA::DescriptorIterator, typename TiledMma::FrgTypeB>::value,
                "MMA atom must source both A and B operand from smem_desc for this mainloop.");

  // 1SM MMA atoms must use SM90 TMA loads; 2SM MMA atoms must use the SM100 2SM TMA loads.
  static_assert(
      (size(AtomThrShapeMNK{}) == 1 &&
        (cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD> || cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>)) ||
      (size(AtomThrShapeMNK{}) == 2 &&
        (cute::is_same_v<GmemTiledCopyA, SM100_TMA_2SM_LOAD> || cute::is_same_v<GmemTiledCopyA, SM100_TMA_2SM_LOAD_MULTICAST>)),
      "GmemTiledCopy - invalid TMA copy atom specified.");
  static_assert(
      (size(AtomThrShapeMNK{}) == 1 &&
        (cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD> || cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>)) ||
      (size(AtomThrShapeMNK{}) == 2 &&
        (cute::is_same_v<GmemTiledCopyB, SM100_TMA_2SM_LOAD> || cute::is_same_v<GmemTiledCopyB, SM100_TMA_2SM_LOAD_MULTICAST>)),
      "GmemTiledCopy - invalid TMA copy atom specified.");

  // fp32 inputs are loaded through TMA as tf32.
  using TmaInternalElementA = cute::conditional_t<cute::is_same_v<ElementA, float>, cutlass::tfloat32_t, ElementAMma>;
  using TmaInternalElementB = cute::conditional_t<cute::is_same_v<ElementB, float>, cutlass::tfloat32_t, ElementBMma>;

  // Sub-byte MMA element types are stored in smem as bytes.
  using SmemAllocTypeA = cute::conditional_t<cute::sizeof_bits_v<ElementAMma> < 8, uint8_t, ElementAMma>;
  using SmemAllocTypeB = cute::conditional_t<cute::sizeof_bits_v<ElementBMma> < 8, uint8_t, ElementBMma>;

  using BitTypeElementA = cute::uint_bit_t<cute::sizeof_bits_v<ElementA>>;
  using BitTypeElementB = cute::uint_bit_t<cute::sizeof_bits_v<ElementB>>;

  using ArrayElementA = cute::conditional_t<IsRuntimeDataTypeA, BitTypeElementA, ElementA>;
  using ArrayElementB = cute::conditional_t<IsRuntimeDataTypeB, BitTypeElementB, ElementB>;

  // Placeholder type (void*) when the operand format is static.
  using RuntimeDataTypeA = cute::conditional_t<IsRuntimeDataTypeA, cute::UMMA::MXF8F6F4Format, void*>;
  using RuntimeDataTypeB = cute::conditional_t<IsRuntimeDataTypeB, cute::UMMA::MXF8F6F4Format, void*>;

  struct SharedStorage {
    struct TensorStorage : cute::aligned_struct<128, _0> {
      cute::ArrayEngine<SmemAllocTypeA, cute::cosize_v<SmemLayoutA>> smem_A;
      cute::ArrayEngine<SmemAllocTypeB, cute::cosize_v<SmemLayoutB>> smem_B;
    } tensors;

    using PipelineStorage = typename MainloopPipeline::SharedStorage;
    PipelineStorage pipeline;
  };

  // Expose shared storage for tensors/pipelines separately to allow kernel layer to reorder them.
  using TensorStorage = typename SharedStorage::TensorStorage;
  using PipelineStorage = typename SharedStorage::PipelineStorage;

  // Only one thread issues the TMA and updates the barriers in a 2SM MMA, adjust bytes accordingly:
  // bytes of one pipeline stage of A and B, scaled by the number of CTAs per MMA atom.
  static constexpr uint32_t TmaTransactionBytes =
    cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutA{})) * cute::sizeof_bits_v<ElementA>) +
    cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutB{})) * cute::sizeof_bits_v<ElementB>);

  /// Holds the accumulator tensor produced by init_tmem_tensors().
  template<class AccTensor>
  struct TmemStorage {
    AccTensor accumulators;
  };

  /// Bundle of everything load() needs, produced by load_init().
  template<
    class KTileCount,
    class GTensorPartitionedA, class GTensorPartitionedB,
    class STensorA, class STensorB
  >
  struct LoadParams {
    // for scheduler
    KTileCount k_tiles;
    // for input tensor values
    GTensorPartitionedA tAgA_mkl;
    GTensorPartitionedB tBgB_nkl;
    STensorA tAsA;
    STensorB tBsB;
    // the TMA multicast masks
    uint16_t mcast_mask_a;
    uint16_t mcast_mask_b;

    CUTLASS_DEVICE
    LoadParams (
        KTileCount k_tiles_,
        GTensorPartitionedA tAgA_mkl_, GTensorPartitionedB tBgB_nkl_,
        STensorA tAsA_, STensorB tBsB_,
        uint16_t mcast_mask_a_, uint16_t mcast_mask_b_)
    : k_tiles(k_tiles_)
    , tAgA_mkl(tAgA_mkl_), tBgB_nkl(tBgB_nkl_)
    , tAsA(tAsA_), tBsB(tBsB_)
    , mcast_mask_a(mcast_mask_a_), mcast_mask_b(mcast_mask_b_) {}
  };

  /// Bundle of everything mma() needs, produced by mma_init().
  template<class FragmentA, class FragmentB>
  struct MmaParams {
    TiledMma tiled_mma;
    FragmentA tCrA;
    FragmentB tCrB;

    CUTLASS_DEVICE
    MmaParams (
        TiledMma tiled_mma_,
        FragmentA tCrA_, FragmentB tCrB_)
    : tiled_mma(tiled_mma_)
    , tCrA(tCrA_), tCrB(tCrB_) {}
  };

  // Host side kernel arguments
  struct Arguments {
    ArrayElementA const* ptr_A{nullptr};
    StrideA dA{};
    ArrayElementB const* ptr_B{nullptr};
    StrideB dB{};
    RuntimeDataTypeA runtime_data_type_a{};
    RuntimeDataTypeB runtime_data_type_b{};
  };

  // Device side kernel params
  struct Params {
    using ClusterLayout_VMNK = decltype(tiled_divide(make_layout(conditional_return<IsDynamicCluster>(make_shape(uint32_t(0), uint32_t(0), Int<1>{}), ClusterShape{})),
                                                     make_tile(typename TiledMma::AtomThrID{})));

    using TMA_A = decltype(make_tma_atom_A_sm100<TmaInternalElementA>(
        GmemTiledCopyA{},
        make_tensor(recast_ptr<TmaInternalElementA>(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}),
        SmemLayoutA{}(_,_,_,cute::Int<0>{}),
        TileShape{},
        TiledMma{},
        ClusterLayout_VMNK{})
      );

    using TMA_B = decltype(make_tma_atom_B_sm100<TmaInternalElementB>(
        GmemTiledCopyB{},
        make_tensor(recast_ptr<TmaInternalElementB>(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}),
        SmemLayoutB{}(_,_,_,cute::Int<0>{}),
        TileShape{},
        TiledMma{},
        ClusterLayout_VMNK{})
      );

    TMA_A tma_load_a;
    TMA_B tma_load_b;
    // Descriptors built for hw_info.cluster_shape_fallback; selected in the constructor.
    TMA_A tma_load_a_fallback;
    TMA_B tma_load_b_fallback;
    dim3 cluster_shape_fallback;
    RuntimeDataTypeA runtime_data_type_a;
    RuntimeDataTypeB runtime_data_type_b;
  };

  /// Device-side constructor: records the cluster geometry and picks which TMA descriptors to use.
  /// With a dynamic cluster shape, the fallback descriptors are chosen when the runtime cluster's
  /// first two extents equal params.cluster_shape_fallback.{x,y}; otherwise the primary ones are used.
  /// The collective keeps pointers into `params`, so `params` must outlive this object.
  CUTLASS_DEVICE
  CollectiveMma(Params const& params, ClusterShape cluster_shape, uint32_t block_rank_in_cluster)
    // Initializers listed in member-declaration order (avoids -Wreorder; they are independent).
    : runtime_data_type_a_(params.runtime_data_type_a)
    , runtime_data_type_b_(params.runtime_data_type_b)
    , cluster_shape_(cluster_shape)
    , block_rank_in_cluster_(block_rank_in_cluster) {
    if constexpr (IsDynamicCluster) {
      const bool is_fallback_cluster = (cute::size<0>(cluster_shape_) == params.cluster_shape_fallback.x &&
                                        cute::size<1>(cluster_shape_) == params.cluster_shape_fallback.y);
      observed_tma_load_a_ = is_fallback_cluster ? &params.tma_load_a_fallback : &params.tma_load_a;
      observed_tma_load_b_ = is_fallback_cluster ? &params.tma_load_b_fallback : &params.tma_load_b;
    }
    else {
      observed_tma_load_a_ = &params.tma_load_a;
      observed_tma_load_b_ = &params.tma_load_b;
    }
  }

  /// Host-side conversion of Arguments into Params: builds the primary TMA descriptors for
  /// hw_info.cluster_shape and the fallback descriptors for hw_info.cluster_shape_fallback.
  template <class ProblemShape>
  static constexpr Params
  to_underlying_arguments(
      ProblemShape const& problem_shape,
      Arguments const& args,
      [[maybe_unused]] void* workspace,
      cutlass::KernelHardwareInfo const& hw_info = cutlass::KernelHardwareInfo{}) {

    // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK)
    auto problem_shape_MNKL = append<4>(problem_shape, 1);
    auto [M,N,K,L] = problem_shape_MNKL;

    auto ptr_A = recast_ptr<TmaInternalElementA>(args.ptr_A);
    auto ptr_B = recast_ptr<TmaInternalElementB>(args.ptr_B);

    Tensor tensor_a = make_tensor(ptr_A, make_layout(make_shape(M,K,L), args.dA));
    Tensor tensor_b = make_tensor(ptr_B, make_layout(make_shape(N,K,L), args.dB));

    auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape);
    // Cluster layout for TMA construction
    auto cluster_layout_vmnk = tiled_divide(make_layout(cluster_shape), make_tile(typename TiledMma::AtomThrID{}));
    auto cluster_shape_fallback = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape_fallback);
    auto cluster_layout_vmnk_fallback = tiled_divide(make_layout(cluster_shape_fallback), make_tile(typename TiledMma::AtomThrID{}));

    typename Params::TMA_A tma_load_a = make_tma_atom_A_sm100<TmaInternalElementA>(
      GmemTiledCopyA{},
      tensor_a,
      SmemLayoutA{}(_,_,_,cute::Int<0>{}),
      TileShape{},
      TiledMma{},
      cluster_layout_vmnk);

    typename Params::TMA_B tma_load_b = make_tma_atom_B_sm100<TmaInternalElementB>(
      GmemTiledCopyB{},
      tensor_b,
      SmemLayoutB{}(_,_,_,cute::Int<0>{}),
      TileShape{},
      TiledMma{},
      cluster_layout_vmnk);

    typename Params::TMA_A tma_load_a_fallback = make_tma_atom_A_sm100<TmaInternalElementA>(
      GmemTiledCopyA{},
      tensor_a,
      SmemLayoutA{}(_,_,_,cute::Int<0>{}),
      TileShape{},
      TiledMma{},
      cluster_layout_vmnk_fallback);

    typename Params::TMA_B tma_load_b_fallback = make_tma_atom_B_sm100<TmaInternalElementB>(
      GmemTiledCopyB{},
      tensor_b,
      SmemLayoutB{}(_,_,_,cute::Int<0>{}),
      TileShape{},
      TiledMma{},
      cluster_layout_vmnk_fallback);

    return {
      tma_load_a,
      tma_load_b,
      tma_load_a_fallback,
      tma_load_b_fallback,
      hw_info.cluster_shape_fallback,
      args.runtime_data_type_a,
      args.runtime_data_type_b
    };
  }

  /// Host-side check that A and B satisfy the minimum TMA alignment for their element types.
  /// Returns false (and emits a host trace) if either operand is misaligned.
  template <class ProblemShape>
  static bool
  can_implement(
      ProblemShape const& problem_shape,
      [[maybe_unused]] Arguments const& args) {
    auto problem_shape_MNKL = append<4>(problem_shape, 1);
    auto [M,N,K,L] = problem_shape_MNKL;

    static constexpr bool IsF8F6F4 = detail::is_sm100_mma_f8f6f4<TiledMma, ElementA, ElementB>();
    constexpr int tma_alignment_bits_A = cutlass::detail::get_input_alignment_bits<ElementA, IsF8F6F4>();
    constexpr int tma_alignment_bits_B = cutlass::detail::get_input_alignment_bits<ElementB, IsF8F6F4>();
    constexpr int min_tma_aligned_elements_A = tma_alignment_bits_A / cute::sizeof_bits<ElementA>::value;

    bool implementable = true;
    implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_A>(cute::make_shape(M,K,L), StrideA{});
    constexpr int min_tma_aligned_elements_B = tma_alignment_bits_B / cute::sizeof_bits<ElementB>::value;
    implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_B>(cute::make_shape(N,K,L), StrideB{});

    if (!implementable) {
      CUTLASS_TRACE_HOST("  CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n");
    }
    return implementable;
  }

  /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
  CUTLASS_DEVICE void
  prefetch_tma_descriptors() {
    cute::prefetch_tma_descriptor(observed_tma_load_a_->get_tma_descriptor());
    cute::prefetch_tma_descriptor(observed_tma_load_b_->get_tma_descriptor());
  }

  /// Construct A Single Stage's Accumulator Shape
  CUTLASS_DEVICE static
  auto
  partition_accumulator_shape() {
    auto acc_shape = partition_shape_C(TiledMma{}, take<0,2>(TileShape{}));  // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N)

    return acc_shape;
  }

  /// Returns a 1-tuple holding the accumulator slice for the given accumulator pipeline stage.
  template <class TmemStorage>
  CUTLASS_DEVICE static
  auto
  slice_accumulator(TmemStorage tmem_storage, int stage) {
    return cute::make_tuple(tmem_storage.accumulators(_,_,_,stage));
  }

  /// Builds the (not yet address-bound) accumulator tensor; set_tmem_offsets() binds its base address.
  template<class EpilogueTile, bool IsOverlappingAccum = false>
  CUTLASS_DEVICE static
  auto
  init_tmem_tensors(EpilogueTile epi_tile) {
    TiledMma tiled_mma;
    auto acc_shape = partition_accumulator_shape();
    // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N,ACC_PIPE) where ACC_PIPE = AccumulatorPipelineStageCount,
    // so the mainloop and the epilogue can work on different accumulator buffers.
    Tensor accumulators = cutlass::detail::make_sm100_accumulator<AccumulatorPipelineStageCount, IsOverlappingAccum>(
        tiled_mma, acc_shape, EpilogueTile{});
    TmemStorage<decltype(accumulators)> tmem_storage;
    tmem_storage.accumulators = accumulators;

    return tmem_storage;
  }

  /// Binds the accumulator tensor to the given tmem base address.
  template<class AccTensor>
  CUTLASS_DEVICE static
  void
  set_tmem_offsets(TmemStorage<AccTensor>& tmem_storage, uint32_t tmem_base_addr) {
    tmem_storage.accumulators.data() = tmem_base_addr;
  }

  /// Set up the data needed by this collective for load.
  /// Returns a LoadParams holding
  ///   k_tiles      - number of K tiles (shape<3> of the tiled gA_mkl), for the scheduler
  ///   tAgA_mkl     - TMA-partitioned gmem tensor for input A
  ///   tBgB_nkl     - TMA-partitioned gmem tensor for input B
  ///   tAsA         - partitioned smem tensor for A
  ///   tBsB         - partitioned smem tensor for B
  ///   mcast_mask_a - tma multicast mask for A
  ///   mcast_mask_b - tma multicast mask for B
  template <class ProblemShape_MNKL>
  CUTLASS_DEVICE auto
  load_init(
      ProblemShape_MNKL const& problem_shape_MNKL,
      TensorStorage& shared_tensors) const {
    using X = Underscore;

    // Separate out problem shape for convenience
    auto [M,N,K,L] = problem_shape_MNKL;

    // Represent the full tensors -- get these from TMA
    Tensor mA_mkl = observed_tma_load_a_->get_tma_tensor(make_shape(M,K,L));
    Tensor mB_nkl = observed_tma_load_b_->get_tma_tensor(make_shape(N,K,L));

    // Tile the tensors and defer the slice
    Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{});  // (BLK_M, BLK_K, m, k, l)
    Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{});  // (BLK_N, BLK_K, n, k, l)

    // Partition for this CTA (its peer index within the MMA atom)
    ThrMMA cta_mma = TiledMma{}.get_slice(blockIdx.x % size(typename TiledMma::AtomThrID{}));

    Tensor tCgA_mkl = cta_mma.partition_A(gA_mkl);  // (MMA, MMA_M, MMA_K, m, k, l)
    Tensor tCgB_nkl = cta_mma.partition_B(gB_nkl);  // (MMA, MMA_N, MMA_K, n, k, l)

    Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{});  // (MMA,MMA_M,MMA_K,PIPE)
    Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{});  // (MMA,MMA_N,MMA_K,PIPE)

    // Define the CTA-in-cluster Layout and Coord
    Layout cta_layout_mnk  = make_layout(cluster_shape_);
    Layout cta_layout_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMma::AtomThrID{}));
    auto cta_coord_vmnk    = cta_layout_vmnk.get_flat_coord(block_rank_in_cluster_);

    // Project the cta_layout for tma_a along the n-modes
    auto [tAgA_mkl, tAsA] = tma_partition(*observed_tma_load_a_,
                                          get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)),
                                          group_modes<0,3>(sA), group_modes<0,3>(tCgA_mkl));

    // Project the cta_layout for tma_b along the m-modes
    auto [tBgB_nkl, tBsB] = tma_partition(*observed_tma_load_b_,
                                          get<1>(cta_coord_vmnk), make_layout(size<1>(cta_layout_vmnk)),
                                          group_modes<0,3>(sB), group_modes<0,3>(tCgB_nkl));

    // TMA Multicast Masks
    uint16_t mcast_mask_a = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk);
    uint16_t mcast_mask_b = create_tma_multicast_mask<1>(cta_layout_vmnk, cta_coord_vmnk);

    LoadParams load_params {
      shape<3>(gA_mkl),                 // for scheduler
      tAgA_mkl, tBgB_nkl, tAsA, tBsB,   // for input tensor values
      mcast_mask_a, mcast_mask_b        // multicast masks
    };
    return load_params;
  }

  /// Set up the data needed by this collective for mma compute.
  /// Builds smem descriptor fragments for A/B and, for runtime-typed operands, patches the
  /// instruction descriptor formats from the runtime arguments.
  template <class AccTensor>
  CUTLASS_DEVICE auto
  mma_init(
      [[maybe_unused]] TmemStorage<AccTensor> tmem_tensors,
      TensorStorage& shared_tensors) const {
    Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{});  // (MMA,MMA_M,MMA_K,PIPE)
    Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{});  // (MMA,MMA_N,MMA_K,PIPE)

    // Allocate "fragments/descriptors" for A and B matrices
    Tensor tCrA = TiledMma::make_fragment_A(sA);  // (MMA,MMA_M,MMA_K,PIPE)
    Tensor tCrB = TiledMma::make_fragment_B(sB);  // (MMA,MMA_N,MMA_K,PIPE)

    CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<3>(sA));  // PIPE
    CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<3>(sB));

    TiledMma tiled_mma;

    if constexpr (IsRuntimeDataType) {
      // Update instruction descriptor according to runtime argument.
      // Applying bitmask (0b111) to help compiler deduce that the conversion and assignment are safe.
      tiled_mma.idesc_.a_format_ = uint8_t(runtime_data_type_a_) & 0b111;
      tiled_mma.idesc_.b_format_ = uint8_t(runtime_data_type_b_) & 0b111;
    }

    MmaParams<decltype(tCrA), decltype(tCrB)> mma_params {
      tiled_mma,
      tCrA, tCrB
    };
    return mma_params;
  }

  /// Perform a collective-scoped matrix multiply-accumulate
  /// Producer Perspective
  /// Issues TMA loads of A and B for k_tile_count K tiles starting at k_tile_iter, one pipeline
  /// stage per K tile. Returns the advanced (producer state, k_tile_iter) pair.
  template <
    class LoadParams,
    class TileCoordMNKL,
    class KTileIterator
  >
  CUTLASS_DEVICE auto
  load(
      MainloopPipeline mainloop_pipeline,
      MainloopPipelineState mainloop_pipe_producer_state,
      LoadParams const& load_inputs,
      TileCoordMNKL const& cta_coord_mnkl,
      KTileIterator k_tile_iter, int k_tile_count) {

    auto [unused_k_tiles,
          tAgA_mkl, tBgB_nkl, tAsA, tBsB,
          mcast_mask_a, mcast_mask_b] = load_inputs;

    // slice out the work coord from partitioned tensors
    // (A's M coordinate is divided by the MMA atom's CTA count, since one atom spans that many CTAs in M)
    Tensor tAgA = tAgA_mkl(_, get<0>(cta_coord_mnkl) / size(typename TiledMma::AtomThrID{}), _, get<3>(cta_coord_mnkl));
    Tensor tBgB = tBgB_nkl(_, get<1>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl));

    auto barrier_token = mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state);

    // Issue the Mainloop loads
    CUTLASS_PRAGMA_NO_UNROLL
    while (k_tile_count > 0) {
      // LOCK mainloop_pipe_producer_state for _writing_
      mainloop_pipeline.producer_acquire(mainloop_pipe_producer_state, barrier_token);

      using BarrierType = typename MainloopPipeline::ProducerBarrierType;
      BarrierType* tma_barrier = mainloop_pipeline.producer_get_barrier(mainloop_pipe_producer_state);

      int write_stage = mainloop_pipe_producer_state.index();
      ++mainloop_pipe_producer_state;
      // Peek at the next stage early so the wait overlaps with issuing this stage's copies.
      barrier_token = mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state);

      if (cute::elect_one_sync()) {
        copy(observed_tma_load_a_->with(*tma_barrier, mcast_mask_a), tAgA(_,*k_tile_iter), tAsA(_,write_stage));
        copy(observed_tma_load_b_->with(*tma_barrier, mcast_mask_b), tBgB(_,*k_tile_iter), tBsB(_,write_stage));
      }

      --k_tile_count;
      ++k_tile_iter;
    }

    return cute::make_tuple(mainloop_pipe_producer_state, k_tile_iter);
  }

  /// Perform a Producer Epilogue to prevent early exit of ctas in a Cluster
  CUTLASS_DEVICE void
  load_tail(MainloopPipeline mainloop_pipeline, MainloopPipelineState mainloop_pipe_producer_state) {
    // Issue the epilogue waits
    // This helps avoid early exit of ctas in Cluster
    // Waits for all stages to either be released (all
    // Consumer UNLOCKs), or if the stage was never used
    // then would just be acquired since the phase was
    // still inverted from make_producer_start_state
    mainloop_pipeline.producer_tail(mainloop_pipe_producer_state);
  }

  /// Perform a collective-scoped matrix multiply-accumulate
  /// Consumer Perspective
  /// Consumes k_tile_count pipeline stages and accumulates into `accumulators`; the first k-block
  /// uses ScaleOut::Zero so the tile overwrites prior contents. If k_tile_count <= 0 no MMA is issued.
  /// The accumulator pipeline/state in the tuples are not used here. Returns the advanced consumer state.
  template <
    class AccumulatorPipeline,
    class FrgEngine, class FrgLayout,
    class MmaParams,
    class CtaTileCoord
  >
  CUTLASS_DEVICE auto
  mma(cute::tuple<MainloopPipeline,
                  AccumulatorPipeline> pipelines,
      cute::tuple<MainloopPipelineState,
                  typename AccumulatorPipeline::PipelineState> pipeline_states,
      cute::tuple<cute::Tensor<FrgEngine, FrgLayout>> const& accumulators_pair,
      MmaParams const& mma_inputs,
      CtaTileCoord cta_tile_coord,
      int k_tile_count
  ) {
    static_assert(is_tmem<FrgEngine>::value, "Accumulator must be tmem resident.");
    static_assert(rank(FrgLayout{}) == 3, "Accumulator must be MMA-partitioned: (MMA, MMA_M, MMA_N)");
    auto accumulators = get<0>(accumulators_pair);
    auto [tiled_mma, tCrA, tCrB] = mma_inputs;

    auto [mainloop_pipeline, accumulator_pipeline] = pipelines;
    auto [mainloop_pipe_consumer_state, accumulator_pipe_producer_state] = pipeline_states;

    uint32_t skip_wait = k_tile_count <= 0;
    auto barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait);

    //
    // PIPELINED MAIN LOOP
    //
    tiled_mma.accumulate_ = UMMA::ScaleOut::Zero;

    CUTLASS_PRAGMA_NO_UNROLL
    while (k_tile_count > 0) {
      // WAIT on mainloop_pipe_consumer_state until its data are available
      // (phase bit flips from mainloop_pipe_consumer_state.phase() value)
      mainloop_pipeline.consumer_wait(mainloop_pipe_consumer_state, barrier_token);

      // Compute on k_tile
      int read_stage = mainloop_pipe_consumer_state.index();
      // Save current mainloop pipeline read state (released after this stage's MMAs are issued)
      auto curr_mainloop_pipe_consumer_state = mainloop_pipe_consumer_state;

      // Advance mainloop_pipe
      ++mainloop_pipe_consumer_state;
      --k_tile_count;
      skip_wait = k_tile_count <= 0;
      // Peek at next iteration
      barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait);

      // Unroll the K mode manually so we can set scale C to 1
      CUTLASS_PRAGMA_UNROLL
      for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
        // (V,M) x (V,N) => (V,M,N)
        cute::gemm(tiled_mma,
                   tCrA(_,_,k_block,read_stage),
                   tCrB(_,_,k_block,read_stage),
                   accumulators);
        tiled_mma.accumulate_ = UMMA::ScaleOut::One;
      }

      mainloop_pipeline.consumer_release(curr_mainloop_pipe_consumer_state);
    }

    return mainloop_pipe_consumer_state;
  }

private:
  // Point into the Params object passed to the constructor (primary or fallback descriptors).
  typename Params::TMA_A const* observed_tma_load_a_{nullptr};
  typename Params::TMA_B const* observed_tma_load_b_{nullptr};

  RuntimeDataTypeA runtime_data_type_a_{};
  RuntimeDataTypeB runtime_data_type_b_{};

  ClusterShape cluster_shape_;
  uint32_t block_rank_in_cluster_;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::gemm::collective
/////////////////////////////////////////////////////////////////////////////////////////////////