forked from NVIDIA/cutlass
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsm100_mma_array_warpspecialized.hpp
864 lines (741 loc) · 37.5 KB
/
sm100_mma_array_warpspecialized.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
/***************************************************************************************************
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/detail/collective.hpp"
#include "cutlass/detail/cluster.hpp"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/numeric_types.h"
#include "cutlass/pipeline/pipeline.hpp"
#include "cutlass/gemm/gemm.h"
#include "cutlass/trace.h"
#include "cutlass/kernel_hardware_info.hpp"
#include "cutlass/cuda_host_adapter.hpp"
#include "cute/algorithm/functional.hpp"
#include "cute/arch/cluster_sm90.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cute/algorithm/gemm.hpp"
#include "cute/tensor_predicate.hpp"
#include "cute/numeric/arithmetic_tuple.hpp"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass::gemm::collective {
using namespace cute;
/////////////////////////////////////////////////////////////////////////////////////////////////
// WarpSpecialized Mainloop
// Both DMA Load and MMA methods of this class must be run by a single thread that's picked by elect_one
template <
int Stages,
int SchedulerPipelineStageCount,
int AccumulatorPipelineStageCount,
class ClusterShape, // Static cluster shape or dynamic (int, int, _1)
class TileShape_, // (MmaAtomShapeM, MmaAtomShapeN, TileK)
class ElementA_,
class StrideA_,
class ElementB_,
class StrideB_,
class TiledMma_,
class GmemTiledCopyA_,
class SmemLayoutAtomA_,
class SmemCopyAtomA_,
class TransformA_,
class GmemTiledCopyB_,
class SmemLayoutAtomB_,
class SmemCopyAtomB_,
class TransformB_>
struct CollectiveMma<
MainloopSm100ArrayTmaUmmaWarpSpecialized<
Stages,
SchedulerPipelineStageCount,
AccumulatorPipelineStageCount,
ClusterShape>,
TileShape_,
ElementA_,
StrideA_,
ElementB_,
StrideB_,
TiledMma_,
GmemTiledCopyA_,
SmemLayoutAtomA_,
SmemCopyAtomA_,
TransformA_,
GmemTiledCopyB_,
SmemLayoutAtomB_,
SmemCopyAtomB_,
TransformB_>
{
//
// Type Aliases
//
// The TiledMma (UMMA) that drives most of the derived types below.
using TiledMma = TiledMma_;
// (V, _1, _1) where V = extent of mode 0 of TiledMma::ThrLayoutVMNK, i.e. the number of CTAs that
// cooperate on one MMA atom (1 or 2, per the TMA copy-atom static_asserts further down).
using AtomThrShapeMNK = Shape<decltype(shape<0>(typename TiledMma::ThrLayoutVMNK{})), _1, _1>;
using DispatchPolicy = MainloopSm100ArrayTmaUmmaWarpSpecialized<
Stages,
SchedulerPipelineStageCount,
AccumulatorPipelineStageCount,
ClusterShape>;
using TileShape = TileShape_;
// True when ClusterShape is not fully static (e.g. (int, int, _1)); the cluster shape is then a
// runtime value and fallback TMA atoms are selected in the constructor.
static constexpr bool IsDynamicCluster = not cute::is_static_v<ClusterShape>;
CUTE_STATIC_ASSERT_V(evenly_divides(TileShape{}, tile_shape(TiledMma{})),
"Static cluster shape used: TileShape should be evenly divided by TiledMma");
// Per-CTA tile: TileShape with M divided by the number of CTAs per MMA atom.
using CtaShape_MNK = decltype(shape_div(TileShape{}, AtomThrShapeMNK{}));
// Define A and B block shapes for reduced size TMA_LOADs
// MmaShapeA_MK: TiledMma partition of the (BLK_M, BLK_K) tile -> ((MMA_TILE_M,MMA_TILE_K),MMA_M,MMA_K)
using MmaShapeA_MK = decltype(partition_shape_A(TiledMma{}, make_shape(size<0>(TileShape{}), size<2>(TileShape{}))));
// MmaShapeB_NK: TiledMma partition of the (BLK_N, BLK_K) tile -> ((MMA_TILE_N,MMA_TILE_K),MMA_N,MMA_K)
using MmaShapeB_NK = decltype(partition_shape_B(TiledMma{}, make_shape(size<1>(TileShape{}), size<2>(TileShape{}))));
// ElementA is the user-facing element type; ElementAMma is the type the MMA atom consumes.
using ElementA = ElementA_;
using ElementAMma = typename TiledMma::ValTypeA;
// StrideA may be a pointer to per-group strides (grouped GEMM); InternalStrideA is always the
// stride type itself (see IsGroupedGemmKernel).
using StrideA = StrideA_;
using InternalStrideA = cute::remove_pointer_t<StrideA>;
using ElementB = ElementB_;
using ElementBMma = typename TiledMma::ValTypeB;
using StrideB = StrideB_;
using InternalStrideB = cute::remove_pointer_t<StrideB>;
// Runtime data types: the concrete operand format is supplied at runtime through
// Arguments::runtime_data_type_{a,b} and patched into the MMA instruction descriptor in mma_init.
static constexpr bool IsRuntimeDataTypeA = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4<ElementA>();
static constexpr bool IsRuntimeDataTypeB = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4<ElementB>();
static_assert((IsRuntimeDataTypeA && IsRuntimeDataTypeB) ||
(!IsRuntimeDataTypeA && !IsRuntimeDataTypeB),
"ElementA and ElementB should be both runtime or both static.");
static constexpr bool IsRuntimeDataType = IsRuntimeDataTypeA && IsRuntimeDataTypeB;
// Accumulator element type as declared by the TiledMma (C operand value type).
using ElementAccumulator = typename TiledMma::ValTypeC;
// Re-exported template parameters (copy atoms, smem layout atoms, operand transforms).
using GmemTiledCopyA = GmemTiledCopyA_;
using GmemTiledCopyB = GmemTiledCopyB_;
using SmemLayoutAtomA = SmemLayoutAtomA_;
using SmemLayoutAtomB = SmemLayoutAtomB_;
using SmemCopyAtomA = SmemCopyAtomA_;
using SmemCopyAtomB = SmemCopyAtomB_;
using TransformA = TransformA_;
using TransformB = TransformB_;
using ArchTag = typename DispatchPolicy::ArchTag;
// Producer/consumer pipeline between the TMA loads (load) and the UMMA consumer (mma), with
// DispatchPolicy::Stages smem stages, parameterized on the cluster shape and CTAs per MMA atom.
using MainloopPipeline = cutlass::PipelineTmaUmmaAsync<
DispatchPolicy::Stages,
ClusterShape,
AtomThrShapeMNK>;
using MainloopPipelineState = typename MainloopPipeline::PipelineState;
// Smem layout atoms must be rank 2 and must tile the operand footprint of one MMA partition:
//   M (or N) extent = MMA_TILE_{M,N} * MMA_{M,N}  must be divisible by the atom's mode 0,
//   K extent        = MMA_TILE_K * MMA_K          must be divisible by the atom's mode 1.
// UMMA reads operands straight from smem, so no smem->register copy atom may be supplied.
static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtomA must be rank 2 (M,K)");
static_assert(((size<0,0>(MmaShapeA_MK{}) * size<1>(MmaShapeA_MK{})) % size<0>(SmemLayoutAtomA{})) == 0,
"SmemLayoutAtom must evenly divide tile shape.");
static_assert(((size<0,1>(MmaShapeA_MK{}) * size<2>(MmaShapeA_MK{})) % size<1>(SmemLayoutAtomA{})) == 0,
"SmemLayoutAtom must evenly divide tile shape.");
static_assert(cute::is_void_v<SmemCopyAtomA>,
"SM100 UMMA cannot have a non-void copy atom for smem sourced instructions.");
static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtomB must be rank 2 (N,K)");
static_assert(((size<0,0>(MmaShapeB_NK{}) * size<1>(MmaShapeB_NK{})) % size<0>(SmemLayoutAtomB{})) == 0,
"SmemLayoutAtom must evenly divide tile shape.");
static_assert(((size<0,1>(MmaShapeB_NK{}) * size<2>(MmaShapeB_NK{})) % size<1>(SmemLayoutAtomB{})) == 0,
"SmemLayoutAtom must evenly divide tile shape.");
static_assert(cute::is_void_v<SmemCopyAtomB>,
"SM100 UMMA cannot have a non-void copy atom for smem sourced instructions.");
// Tile along K mode first before tiling over MN. PIPE mode last as usual.
// This maximizes TMA boxes due to better smem-K vectorization, reducing total issued TMAs.
// The Step<> tiling order passed to tile_to_mma_shape is chosen from the operand's majorness
// (is_mn_major on the internal stride): Step<_2,_1,_3> for MN-major, Step<_1,_2,_3> otherwise.
// ((MMA_TILE_M,MMA_TILE_K),MMA_M,MMA_K,PIPE)
using SmemLayoutA = decltype(UMMA::tile_to_mma_shape(
SmemLayoutAtomA{},
append(MmaShapeA_MK{}, Int<DispatchPolicy::Stages>{}),
cute::conditional_t<cutlass::gemm::detail::is_mn_major<InternalStrideA>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}));
// ((MMA_TILE_N,MMA_TILE_K),MMA_N,MMA_K,PIPE)
using SmemLayoutB = decltype(UMMA::tile_to_mma_shape(
SmemLayoutAtomB{},
append(MmaShapeB_NK{}, Int<DispatchPolicy::Stages>{}),
cute::conditional_t<cutlass::gemm::detail::is_mn_major<InternalStrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}));
static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 1 or more.");
static_assert(cute::is_base_of<cute::UMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value &&
cute::is_base_of<cute::UMMA::DescriptorIterator, typename TiledMma::FrgTypeB>::value,
"MMA atom must source both A and B operand from smem_desc for this mainloop.");
static_assert(
(size(AtomThrShapeMNK{}) == 1 &&
(cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD> || cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>)) ||
(size(AtomThrShapeMNK{}) == 2 &&
(cute::is_same_v<GmemTiledCopyA, SM100_TMA_2SM_LOAD> || cute::is_same_v<GmemTiledCopyA, SM100_TMA_2SM_LOAD_MULTICAST>)),
"GmemTiledCopy - invalid TMA copy atom specified.");
static_assert(
(size(AtomThrShapeMNK{}) == 1 &&
(cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD> || cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>)) ||
(size(AtomThrShapeMNK{}) == 2 &&
(cute::is_same_v<GmemTiledCopyB, SM100_TMA_2SM_LOAD> || cute::is_same_v<GmemTiledCopyB, SM100_TMA_2SM_LOAD_MULTICAST>)),
"GmemTiledCopy - invalid TMA copy atom specified.");
// Element type used to build the TMA descriptors: float operands are described as tfloat32_t,
// otherwise the MMA operand type is used.
using TmaInternalElementA = cute::conditional_t<cute::is_same_v<ElementA, float>, cutlass::tfloat32_t, ElementAMma>;
using TmaInternalElementB = cute::conditional_t<cute::is_same_v<ElementB, float>, cutlass::tfloat32_t, ElementBMma>;
// Storage type for the smem operand buffers: sub-byte MMA types are allocated as uint8_t.
using SmemAllocTypeA = cute::conditional_t<cute::sizeof_bits_v<ElementAMma> < 8, uint8_t, ElementAMma>;
using SmemAllocTypeB = cute::conditional_t<cute::sizeof_bits_v<ElementBMma> < 8, uint8_t, ElementBMma>;
// Unsigned integer type with the same bit width as the user-facing element type.
using BitTypeElementA = uint_bit_t<cute::sizeof_bits_v<ElementA>>;
using BitTypeElementB = uint_bit_t<cute::sizeof_bits_v<ElementB>>;
// Pointee type of Arguments::ptr_A/ptr_B: raw bit containers when the data type is only known
// at runtime, otherwise the element type itself.
using ArrayElementA = cute::conditional_t<IsRuntimeDataTypeA, BitTypeElementA, ElementA>;
using ArrayElementB = cute::conditional_t<IsRuntimeDataTypeB, BitTypeElementB, ElementB>;
// Type of the runtime format arguments; void* is only a placeholder when formats are static.
using RuntimeDataTypeA = cute::conditional_t<IsRuntimeDataTypeA, cute::UMMA::MXF8F6F4Format, void*>;
using RuntimeDataTypeB = cute::conditional_t<IsRuntimeDataTypeB, cute::UMMA::MXF8F6F4Format, void*>;
// Shared memory owned by this collective. Member order defines the smem layout.
struct SharedStorage {
// Operand buffers holding every pipeline stage of SmemLayoutA / SmemLayoutB
// (aligned via cute::aligned_struct<128, _0>).
struct TensorStorage : cute::aligned_struct<128, _0> {
cute::ArrayEngine<SmemAllocTypeA, cute::cosize_v<SmemLayoutA>> smem_A;
cute::ArrayEngine<SmemAllocTypeB, cute::cosize_v<SmemLayoutB>> smem_B;
} tensors;
// Smem copies of the A/B TMA descriptors. tensormaps_init fills them from Params, and
// tensormaps_replace_global_address patches their global addresses.
struct TensorMapStorage : cute::aligned_struct<128, _0> {
cute::TmaDescriptor smem_tensormap_A;
cute::TmaDescriptor smem_tensormap_B;
} tensormaps;
// Shared state of the mainloop pipeline.
using PipelineStorage = typename MainloopPipeline::SharedStorage;
PipelineStorage pipeline;
};
// Expose shared storage for tensors/pipelines separately to allow kernel layer to reorder them.
using TensorStorage = typename SharedStorage::TensorStorage;
using TensorMapStorage = typename SharedStorage::TensorMapStorage;
using PipelineStorage = typename SharedStorage::PipelineStorage;
// Bytes of one pipeline stage of A plus one stage of B (take<0,3> drops the PIPE mode), in bytes of
// the user-facing element types.
// Only one thread issues the TMA and updates the barriers in a 2SM MMA, adjust bytes accordingly:
// the count is scaled by size(AtomThrShapeMNK), the number of CTAs per MMA atom.
static constexpr uint32_t TmaTransactionBytes =
cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutA{})) * cute::sizeof_bits_v<ElementA>) +
cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutB{})) * cute::sizeof_bits_v<ElementB>);
// Grouped GEMM is signalled by passing StrideA as a pointer type (per-group strides); for plain
// ptr-array GEMM StrideA == InternalStrideA and a single stride applies to every batch.
static constexpr bool IsGroupedGemmKernel = !cute::is_same_v<InternalStrideA, StrideA>;
// Host side kernel arguments
struct Arguments {
// Array of per-batch/group A base pointers. It is indexed on device by
// tensormaps_replace_global_address, so the array must be device-accessible.
ArrayElementA const** ptr_A{nullptr};
// Stride of A; for grouped GEMM this is a pointer type (see IsGroupedGemmKernel).
StrideA dA{};
// Array of per-batch/group B base pointers (same requirements as ptr_A).
ArrayElementB const** ptr_B{nullptr};
StrideB dB{};
// Runtime operand formats; only consumed when IsRuntimeDataType (see mma_init).
RuntimeDataTypeA runtime_data_type_a{};
RuntimeDataTypeB runtime_data_type_b{};
};
// Device side kernel params
// Member order matters: to_underlying_arguments returns a braced initializer in this order.
struct Params {
// Cluster layout tiled by the MMA atom's CTA group. It is used only to name the TMA atom types;
// dynamic clusters use a (uint32_t, uint32_t, _1) placeholder shape here.
using ClusterLayout_VMNK = decltype(tiled_divide(make_layout(conditional_return<IsDynamicCluster>(make_shape(uint32_t(0), uint32_t(0), Int<1>{}), ClusterShape{})),
make_tile(typename TiledMma::AtomThrID{})));
// Type of the A TMA atom, deduced from a placeholder tensor with the real stride type.
using TMA_A = decltype(make_tma_atom_A_sm100<TmaInternalElementA>(
GmemTiledCopyA{},
make_tensor(recast_ptr<TmaInternalElementA>(nullptr), repeat_like(InternalStrideA{}, int32_t(0)), InternalStrideA{}),
SmemLayoutA{}(_,_,_,cute::Int<0>{}),
TileShape{},
TiledMma{},
ClusterLayout_VMNK{})
);
// Type of the B TMA atom (same construction as TMA_A).
using TMA_B = decltype(make_tma_atom_B_sm100<TmaInternalElementB>(
GmemTiledCopyB{},
make_tensor(recast_ptr<TmaInternalElementB>(nullptr), repeat_like(InternalStrideB{}, int32_t(0)), InternalStrideB{}),
SmemLayoutB{}(_,_,_,cute::Int<0>{}),
TileShape{},
TiledMma{},
ClusterLayout_VMNK{})
);
// TMA atoms built for the preferred cluster shape (hw_info.cluster_shape) ...
TMA_A tma_load_a;
TMA_B tma_load_b;
// ... and for the fallback cluster shape; the constructor picks one pair.
TMA_A tma_load_a_fallback;
TMA_B tma_load_b_fallback;
// Compared (x, y) against the runtime cluster shape to detect a fallback launch.
dim3 cluster_shape_fallback;
RuntimeDataTypeA runtime_data_type_a;
RuntimeDataTypeB runtime_data_type_b;
// Gmem workspace of per-SM descriptor slots: A at [sm_idx], B at [sm_idx + sm_count]
// (see tensormaps_init / get_workspace_size).
cute::TmaDescriptor* tensormaps;
ArrayElementA const** ptr_A;
StrideA dA;
ArrayElementB const** ptr_B;
StrideB dB;
};
/// Device-side constructor.
///
/// Records the runtime cluster shape and this CTA's rank in the cluster, and selects which prebuilt
/// TMA atoms (primary or fallback) the load path will use. Only pointers into `params` are stored,
/// so `params` must outlive this object.
///
/// params                - kernel Params produced by to_underlying_arguments
/// cluster_shape         - cluster shape this CTA was launched with
/// block_rank_in_cluster - linear rank of this CTA within its cluster
CUTLASS_DEVICE
CollectiveMma(Params const& params, ClusterShape cluster_shape, uint32_t block_rank_in_cluster)
  : cluster_shape_(cluster_shape)
  , block_rank_in_cluster_(block_rank_in_cluster) {
  if constexpr (IsDynamicCluster) {
    // TMA atoms exist for both the preferred and the fallback cluster shape; use the fallback pair
    // when the actual M/N cluster extents match the fallback shape.
    const bool is_fallback_cluster = (cute::size<0>(cluster_shape_) == params.cluster_shape_fallback.x &&
                                      cute::size<1>(cluster_shape_) == params.cluster_shape_fallback.y);
    observed_tma_load_a_ = is_fallback_cluster ? &params.tma_load_a_fallback : &params.tma_load_a;
    observed_tma_load_b_ = is_fallback_cluster ? &params.tma_load_b_fallback : &params.tma_load_b;
  }
  else {
    // Static cluster shape: the fallback atoms are never consulted.
    observed_tma_load_a_ = &params.tma_load_a;
    observed_tma_load_b_ = &params.tma_load_b;
  }
}
/// Translate host-side Arguments into device-side Params.
///
/// The TMA atoms are built from tensors with null base pointers; for grouped GEMM the extents and
/// strides are placeholders as well. These values are replaced with the real per-batch/per-group
/// ones before the initial TMA load. Atoms are built for both the preferred (hw_info.cluster_shape)
/// and the fallback (hw_info.cluster_shape_fallback) cluster shapes; the device constructor picks one.
///
/// problem_shapes - ptr-array / grouped problem shape wrapper
/// args           - host arguments (pointer arrays, strides, runtime operand formats)
/// workspace      - gmem buffer of get_workspace_size() bytes; becomes Params::tensormaps
/// hw_info        - source of the preferred and fallback cluster shapes
template <class ProblemShape>
static constexpr Params
to_underlying_arguments(
    ProblemShape problem_shapes,
    Arguments const& args,
    void* workspace,
    cutlass::KernelHardwareInfo const& hw_info = cutlass::KernelHardwareInfo{}) {
  // Unit (M, N, K, L) extents. Only the ptr-array path overwrites M/N/K; L stays 1 because
  // batches/groups are selected by swapping base pointers.
  auto unit_shape = repeat_like(append<4>(typename ProblemShape::UnderlyingProblemShape{}, 1), int32_t(1));
  auto extent_M = get<0>(unit_shape);
  auto extent_N = get<1>(unit_shape);
  auto extent_K = get<2>(unit_shape);
  auto extent_L = get<3>(unit_shape);

  InternalStrideA host_stride_a;
  InternalStrideB host_stride_b;
  if constexpr (not IsGroupedGemmKernel) {
    // Ptr-array GEMM: a single problem shape and a single stride per operand cover all batches.
    auto shape_MNK = problem_shapes.get_host_problem_shape(0);
    extent_M = get<0>(shape_MNK);
    extent_N = get<1>(shape_MNK);
    extent_K = get<2>(shape_MNK);
    host_stride_a = args.dA;
    host_stride_b = args.dB;
  }
  else {
    // Grouped GEMM: per-group strides are substituted before the first access.
    host_stride_a = InternalStrideA{};
    host_stride_b = InternalStrideB{};
  }

  // Null base pointers; the actual addresses are patched into the tensormaps on device.
  TmaInternalElementA const* null_ptr_A = nullptr;
  TmaInternalElementB const* null_ptr_B = nullptr;
  Tensor tensor_a = make_tensor(null_ptr_A, make_layout(make_shape(extent_M, extent_K, extent_L), host_stride_a));
  Tensor tensor_b = make_tensor(null_ptr_B, make_layout(make_shape(extent_N, extent_K, extent_L), host_stride_b));

  // Cluster layout for TMA construction: the cluster tiled by the MMA atom's CTA group.
  auto to_vmnk = [](auto const& cluster_shape) {
    return tiled_divide(make_layout(cluster_shape), make_tile(typename TiledMma::AtomThrID{}));
  };
  auto cluster_layout_vmnk =
      to_vmnk(cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape));
  auto cluster_layout_vmnk_fallback =
      to_vmnk(cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape_fallback));

  // Every TMA atom shares its tensor/smem/tile/MMA inputs and differs only in the cluster layout.
  auto build_tma_a = [&](auto const& cluster_vmnk) {
    return make_tma_atom_A_sm100<TmaInternalElementA>(
        GmemTiledCopyA{}, tensor_a, SmemLayoutA{}(_,_,_,cute::Int<0>{}), TileShape{}, TiledMma{}, cluster_vmnk);
  };
  auto build_tma_b = [&](auto const& cluster_vmnk) {
    return make_tma_atom_B_sm100<TmaInternalElementB>(
        GmemTiledCopyB{}, tensor_b, SmemLayoutB{}(_,_,_,cute::Int<0>{}), TileShape{}, TiledMma{}, cluster_vmnk);
  };

  typename Params::TMA_A tma_load_a          = build_tma_a(cluster_layout_vmnk);
  typename Params::TMA_B tma_load_b          = build_tma_b(cluster_layout_vmnk);
  typename Params::TMA_A tma_load_a_fallback = build_tma_a(cluster_layout_vmnk_fallback);
  typename Params::TMA_B tma_load_b_fallback = build_tma_b(cluster_layout_vmnk_fallback);

  // Field order must match Params.
  return {
    tma_load_a,
    tma_load_b,
    tma_load_a_fallback,
    tma_load_b_fallback,
    hw_info.cluster_shape_fallback,
    args.runtime_data_type_a,
    args.runtime_data_type_b,
    reinterpret_cast<cute::TmaDescriptor*>(workspace),
    reinterpret_cast<ArrayElementA const**>(args.ptr_A),
    args.dA,
    reinterpret_cast<ArrayElementB const**>(args.ptr_B),
    args.dB
  };
}
/// Bytes of gmem workspace needed for the per-SM tensormap copies.
/// Layout: sm_count A descriptors, followed by sm_count B descriptors.
template <class ProblemShape>
static size_t
get_workspace_size(ProblemShape const& problem_shape, Arguments const& args, int sm_count) {
  constexpr size_t DescriptorsPerSm = 2; // one for A, one for B
  return static_cast<size_t>(sm_count) * DescriptorsPerSm * sizeof(cute::TmaDescriptor);
}
/// Host-side workspace initialization. This collective needs none; it always reports success.
template <class ProblemShape>
static cutlass::Status
initialize_workspace(
    [[maybe_unused]] ProblemShape const& problem_shape,
    [[maybe_unused]] Arguments const& args,
    [[maybe_unused]] void* workspace,
    [[maybe_unused]] cudaStream_t stream,
    [[maybe_unused]] CudaHostAdapter* cuda_adapter = nullptr) {
  // Intentionally a no-op.
  cutlass::Status const status = cutlass::Status::kSuccess;
  return status;
}
template<class ProblemShape>
static bool
can_implement(
ProblemShape problem_shapes,
[[maybe_unused]] Arguments const& args) {
static constexpr bool IsF8F6F4 = detail::is_sm100_mma_f8f6f4<TiledMma, ElementA, ElementB>();
constexpr int tma_alignment_bits_A = cutlass::detail::get_input_alignment_bits<ElementA, IsF8F6F4>();
constexpr int tma_alignment_bits_B = cutlass::detail::get_input_alignment_bits<ElementB, IsF8F6F4>();
constexpr int min_tma_aligned_elements_A = tma_alignment_bits_A / cute::sizeof_bits<ElementA>::value;
constexpr int min_tma_aligned_elements_B = tma_alignment_bits_B / cute::sizeof_bits<ElementB>::value;
bool implementable = true;
if (problem_shapes.is_host_problem_shape_available()) {
// Check alignment for all problem sizes
for (int i = 0; i < problem_shapes.groups(); i++) {
auto problem_shape_MNKL = append<4>(problem_shapes.get_host_problem_shape(i), 1);
auto [M,N,K,L] = problem_shape_MNKL;
implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_A>(cute::make_shape(M,K,L), InternalStrideA{});
implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_B>(cute::make_shape(N,K,L), InternalStrideB{});
}
}
if (!implementable) {
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n");
}
return implementable;
}
/// Shape of a single accumulator stage: the TiledMma partition of the (M, N) extent of TileShape,
/// i.e. ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N).
CUTLASS_DEVICE auto
partition_accumulator_shape() {
  auto const tile_mn = take<0,2>(TileShape{});
  return partition_shape_C(TiledMma{}, tile_mn);
}
/// View of accumulator stage `stage`: keeps the three MMA modes and fixes the trailing stage mode.
template <class FrgEngine, class FrgLayout>
CUTLASS_DEVICE auto
slice_accumulator(cute::Tensor<FrgEngine, FrgLayout> const& accumulators, int stage) {
  auto stage_view = accumulators(_,_,_,stage);
  return stage_view;
}
/// Set up the data needed by this collective for load.
/// Returns a tuple containing:
///   gA_mkl           - tiled TMA tensor for A, (BLK_M, BLK_K, m, k, l); exposed for the scheduler
///   gB_nkl           - tiled TMA tensor for B, (BLK_N, BLK_K, n, k, l); exposed for the scheduler
///   tAgA_mkl         - TMA-partitioned gmem tensor for A
///   tBgB_nkl         - TMA-partitioned gmem tensor for B
///   tAsA             - TMA-partitioned smem tensor for A
///   tBsB             - TMA-partitioned smem tensor for B
///   mcast_mask_a     - TMA multicast mask for A
///   mcast_mask_b     - TMA multicast mask for B
///   input_tensormaps - (A, B) pointers to this SM's gmem tensormap slots, from tensormaps_init
/// NOTE(review): tensormaps_init uses elect_one_sync()/__syncwarp(), so this presumably must be
/// called by all threads of a warp -- confirm against the kernel layer.
template <class ProblemShape_MNKL>
CUTLASS_DEVICE auto
load_init(
ProblemShape_MNKL const& problem_shape_MNKL,
Params const& params,
TensorStorage& shared_tensors,
TensorMapStorage& shared_tensormaps,
int32_t const sm_count, int32_t const sm_idx,
[[maybe_unused]] int32_t init_group) const {
using X = Underscore;
// Separate out problem shape for convenience
auto [M,N,K,L] = problem_shape_MNKL;
// Problem Shape and therefore strides that we construct are [M,N,K,L], but since here for the TMA loads
// we are managing TMA descriptors to change batches, we need to neglect the L mode
const int32_t mock_L = 1;
// Represent the full tensors -- get these from TMA
Tensor mA_mkl = observed_tma_load_a_->get_tma_tensor(make_shape(M,K,mock_L));
Tensor mB_nkl = observed_tma_load_b_->get_tma_tensor(make_shape(N,K,mock_L));
// Tile the tensors and defer the slice
Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M, BLK_K, m, k, l)
Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N, BLK_K, n, k, l)
// Partition for this CTA: slice the TiledMma by this CTA's index within the MMA atom's CTA group,
// taken as blockIdx.x modulo the size of AtomThrID.
ThrMMA cta_mma = TiledMma{}.get_slice(blockIdx.x % size(typename TiledMma::AtomThrID{}));
Tensor tCgA_mkl = cta_mma.partition_A(gA_mkl); // (MMA, MMA_M, MMA_K, m, k, l)
Tensor tCgB_nkl = cta_mma.partition_B(gB_nkl); // (MMA, MMA_N, MMA_K, n, k, l)
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (MMA,MMA_M,MMA_K,PIPE)
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (MMA,MMA_N,MMA_K,PIPE)
// Define the CTA-in-Cluster Layout and Coord
Layout cta_layout_mnk = make_layout(cluster_shape_);
Layout cta_layout_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMma::AtomThrID{}));
auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(block_rank_in_cluster_);
// Project the cta_layout for tma_a along the n-modes
auto [tAgA_mkl, tAsA] = tma_partition(*observed_tma_load_a_,
get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)),
group_modes<0,3>(sA), group_modes<0,3>(tCgA_mkl));
// Project the cta_layout for tma_b along the m-modes
auto [tBgB_nkl, tBsB] = tma_partition(*observed_tma_load_b_,
get<1>(cta_coord_vmnk), make_layout(size<1>(cta_layout_vmnk)),
group_modes<0,3>(sB), group_modes<0,3>(tCgB_nkl));
// TMA Multicast Masks (A over VMNK mode 2 / N, B over VMNK mode 1 / M)
uint16_t mcast_mask_a = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk);
uint16_t mcast_mask_b = create_tma_multicast_mask<1>(cta_layout_vmnk, cta_coord_vmnk);
// Fetch a copy of tensormaps for the CTA from Params
auto input_tensormaps = tensormaps_init(params, shared_tensormaps, sm_count, sm_idx);
return cute::make_tuple(
gA_mkl, gB_nkl, // for scheduler
tAgA_mkl, tBgB_nkl, tAsA, tBsB, // for input tensor values
mcast_mask_a, mcast_mask_b, // multicast masks
input_tensormaps); // for tma descriptor modification (per-CTA tensormap copy)
}
/// Set up the data needed by this collective for MMA compute.
/// Returns (tiled_mma, tCrA, tCrB): a TiledMma, whose A/B instruction-descriptor formats are taken
/// from Params when IsRuntimeDataType, and UMMA smem-descriptor fragments covering every pipeline
/// stage of A and B.
template <class FrgEngine, class FrgLayout>
CUTLASS_DEVICE auto
mma_init(
    Params const& params,
    [[maybe_unused]] cute::Tensor<FrgEngine, FrgLayout> const& accumulators,
    TensorStorage& shared_tensors,
    [[maybe_unused]] uint32_t const tmem_nonaccum_offset) const {
  // Views over the full multi-stage smem operand buffers.
  Tensor smem_a = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (MMA,MMA_M,MMA_K,PIPE)
  Tensor smem_b = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (MMA,MMA_N,MMA_K,PIPE)
  CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<3>(smem_a)); // PIPE
  CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<3>(smem_b));

  // "Fragments" here are UMMA descriptors referencing the smem operands.
  Tensor tCrA = TiledMma::make_fragment_A(smem_a); // (MMA,MMA_M,MMA_K,PIPE)
  Tensor tCrB = TiledMma::make_fragment_B(smem_b); // (MMA,MMA_N,MMA_K,PIPE)

  TiledMma mma_op;
  if constexpr (IsRuntimeDataType) {
    // Patch the instruction descriptor with the runtime operand formats. The 0b111 mask helps the
    // compiler see that the conversion and assignment are safe.
    mma_op.idesc_.a_format_ = uint8_t(params.runtime_data_type_a) & 0b111;
    mma_op.idesc_.b_format_ = uint8_t(params.runtime_data_type_b) & 0b111;
  }
  return cute::make_tuple(mma_op, tCrA, tCrB);
}
/// Producer side of the mainloop: issue TMA loads of A and B k-tiles into the smem pipeline.
/// For each of k_tile_count k-tiles starting at k_tile_iter: acquire a free stage, then one elected
/// thread issues the A and B TMA copies through the tensormaps in `input_tensormaps`, arriving on
/// that stage's producer barrier.
///   did_batch_change - when true, fence-acquire the tensormaps before use (they were replaced in gmem)
///   cta_coord_mnkl   - tile coordinate; its M coordinate is divided by size(AtomThrID) to index tAgA_mkl
/// Returns (advanced producer pipeline state, advanced k_tile_iter).
/// `params` is not read in this body.
template <
class GTensorA, class GTensorB,
class GTensorPartitionedA, class GTensorPartitionedB,
class STensorA, class STensorB,
class TensorMapA, class TensorMapB,
class TileCoordMNKL,
class KTileIterator
>
CUTLASS_DEVICE auto
load(
Params const& params,
MainloopPipeline mainloop_pipeline,
MainloopPipelineState mainloop_pipe_producer_state,
cute::tuple<GTensorA, GTensorB,
GTensorPartitionedA, GTensorPartitionedB,
STensorA, STensorB,
uint16_t, uint16_t,
cute::tuple<TensorMapA, TensorMapB>> const& load_inputs,
TileCoordMNKL const& cta_coord_mnkl,
KTileIterator k_tile_iter, int k_tile_count,
bool did_batch_change) {
auto [unused_gA, unused_gB,
tAgA_mkl, tBgB_nkl, tAsA, tBsB,
mcast_mask_a, mcast_mask_b,
input_tensormaps] = load_inputs;
// Check to see if tensormaps have been replaced in gmem
if (did_batch_change) {
tensormaps_fence_acquire(input_tensormaps);
}
// slice out the work coord from partitioned tensors
tensor tAgA = tAgA_mkl(_, get<0>(cta_coord_mnkl) / size(typename TiledMma::AtomThrID{}), _, get<3>(cta_coord_mnkl));
Tensor tBgB = tBgB_nkl(_, get<1>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl));
// Non-blocking peek at the first stage; the result is handed to producer_acquire below.
auto barrier_token = mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state);
// Issue the Mainloop loads
CUTLASS_PRAGMA_NO_UNROLL
while (k_tile_count > 0) {
// LOCK mainloop_pipe_producer_state for _writing_
mainloop_pipeline.producer_acquire(mainloop_pipe_producer_state, barrier_token);
using BarrierType = typename MainloopPipeline::ProducerBarrierType;
BarrierType* tma_barrier = mainloop_pipeline.producer_get_barrier(mainloop_pipe_producer_state);
int write_stage = mainloop_pipe_producer_state.index();
// Advance and peek at the next stage before issuing this stage's copies.
++mainloop_pipe_producer_state;
barrier_token = mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state);
if (cute::elect_one_sync()) {
copy(observed_tma_load_a_->with(get<0>(input_tensormaps), *tma_barrier, mcast_mask_a), tAgA(_,*k_tile_iter), tAsA(_,write_stage));
copy(observed_tma_load_b_->with(get<1>(input_tensormaps), *tma_barrier, mcast_mask_b), tBgB(_,*k_tile_iter), tBsB(_,write_stage));
}
--k_tile_count;
++k_tile_iter;
}
return cute::make_tuple(mainloop_pipe_producer_state, k_tile_iter);
}
/// Perform a Producer Epilogue to prevent early exit of ctas in a Cluster
/// Delegates to MainloopPipeline::producer_tail starting from the producer state returned by load().
CUTLASS_DEVICE void
load_tail(MainloopPipeline mainloop_pipeline, MainloopPipelineState mainloop_pipe_producer_state) {
// Issue the epilogue waits
// This helps avoid early exit of ctas in Cluster
// Waits for every stage either to be released (all consumer UNLOCKs) or, if the stage was never
// used, to be acquired immediately because its phase is still inverted from
// make_producer_start_state.
mainloop_pipeline.producer_tail(mainloop_pipe_producer_state);
}
/// Perform a collective-scoped matrix multiply-accumulate
/// Consumer Perspective
/// Issues UMMAs for k_tile_count k-tiles into the tmem-resident `accumulators`. The first MMA of the
/// call uses ScaleOut::Zero (overwrite) and every later one uses ScaleOut::One (accumulate). Each smem
/// stage is released back to the producer after its MMAs are issued. If k_tile_count <= 0, no MMA is
/// issued and `accumulators` is left untouched.
/// Returns the advanced mainloop consumer state. The accumulator pipeline, its state, and
/// cta_tile_coord are not used in this body.
template <
class AccumulatorPipeline,
class FrgEngine, class FrgLayout,
class FragmentA, class FragmentB,
class CtaTileCoord
>
CUTLASS_DEVICE auto
mma(cute::tuple<MainloopPipeline,
AccumulatorPipeline> pipelines,
cute::tuple<MainloopPipelineState,
typename AccumulatorPipeline::PipelineState> pipeline_states,
cute::Tensor<FrgEngine, FrgLayout>& accumulators,
cute::tuple<TiledMma, FragmentA, FragmentB> const& mma_inputs,
CtaTileCoord cta_tile_coord,
int k_tile_count
) {
static_assert(is_tmem<FrgEngine>::value, "Accumulator must be tmem resident.");
static_assert(rank(FrgLayout{}) == 3, "Accumulator must be MMA-partitioned: (MMA, MMA_M, MMA_N)");
auto [tiled_mma, tCrA, tCrB] = mma_inputs;
auto [mainloop_pipeline, accumulator_pipeline] = pipelines;
auto [mainloop_pipe_consumer_state, accumulator_pipe_producer_state] = pipeline_states;
// Skip the barrier peek entirely when there is no work.
uint32_t skip_wait = k_tile_count <= 0;
auto barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait);
//
// PIPELINED MAIN LOOP
//
// First MMA overwrites the accumulator; switched to One after the first k_block below.
tiled_mma.accumulate_ = UMMA::ScaleOut::Zero;
CUTLASS_PRAGMA_NO_UNROLL
while (k_tile_count > 0) {
// WAIT on mainloop_pipe_consumer_state until its data are available
// (phase bit flips from mainloop_pipe_consumer_state.phase() value)
mainloop_pipeline.consumer_wait(mainloop_pipe_consumer_state, barrier_token);
// Compute on k_tile
int read_stage = mainloop_pipe_consumer_state.index();
// Save current mainloop pipeline read state
auto curr_mainloop_pipe_consumer_state = mainloop_pipe_consumer_state;
// Advance mainloop_pipe
++mainloop_pipe_consumer_state;
--k_tile_count;
skip_wait = k_tile_count <= 0;
// Peek at next iteration
barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait);
// Unroll the K mode manually so we can set scale C to 1
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
// (V,M) x (V,N) => (V,M,N)
cute::gemm(tiled_mma,
tCrA(_,_,k_block,read_stage),
tCrB(_,_,k_block,read_stage),
accumulators);
tiled_mma.accumulate_ = UMMA::ScaleOut::One;
}
// Hand the consumed smem stage back to the producer.
mainloop_pipeline.consumer_release(curr_mainloop_pipe_consumer_state);
}
return mainloop_pipe_consumer_state;
}
//
// Methods to perform different parts of TMA/Tensormap modifications
//
// Bind this SM to its pair of global-memory TMA descriptor slots, and stage the descriptors of
// the observed TMA atoms (A and B) into shared memory so later calls can patch them in place.
//
// mainloop_params   : provides the global descriptor workspace (mainloop_params.tensormaps).
// shared_tensormaps : receives copies of the A and B descriptors (smem_tensormap_A / smem_tensormap_B).
// sm_count, sm_idx  : this SM's A slot is workspace[sm_idx]; its B slot is workspace[sm_idx + sm_count].
// Returns a tuple (pointer to A slot, pointer to B slot).
// Intended to be called by a full warp: it uses elect_one_sync() and a full-mask __syncwarp().
CUTLASS_DEVICE auto
tensormaps_init(
    Params const& mainloop_params,
    TensorMapStorage& shared_tensormaps,
    int32_t const sm_count,
    int32_t const sm_idx) const {
  cute::TmaDescriptor* workspace = mainloop_params.tensormaps;
  cute::TmaDescriptor* slot_a = workspace + sm_idx;
  cute::TmaDescriptor* slot_b = workspace + (sm_idx + sm_count);

  if (cute::elect_one_sync()) {
    // One lane copies each descriptor from the TMA atom into shared memory, moving it in
    // uint128_t-sized chunks (via recast) through single-element tensor views.
    Tensor params_desc_a = make_tensor(observed_tma_load_a_->get_tma_descriptor(), Int<1>{}, Int<1>{});
    Tensor smem_desc_a   = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_A), Int<1>{}, Int<1>{});
    copy(recast<uint128_t>(params_desc_a), recast<uint128_t>(smem_desc_a));

    Tensor params_desc_b = make_tensor(observed_tma_load_b_->get_tma_descriptor(), Int<1>{}, Int<1>{});
    Tensor smem_desc_b   = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_B), Int<1>{}, Int<1>{});
    copy(recast<uint128_t>(params_desc_b), recast<uint128_t>(smem_desc_b));
  }
  // Order the elected lane's shared-memory writes with respect to the rest of the warp.
  __syncwarp();

  return cute::make_tuple(slot_a, slot_b);
}
// Point the shared-memory A and B descriptors at the operand buffers of batch `next_batch`,
// i.e. mainloop_params.ptr_A[next_batch] and mainloop_params.ptr_B[next_batch].
// Intended to be executed by a single thread.
CUTLASS_DEVICE
void
tensormaps_replace_global_address(
    TensorMapStorage& shared_tensormaps,
    Params const& mainloop_params,
    int32_t next_batch) {
  auto batch_base_a = mainloop_params.ptr_A[next_batch];
  cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_A, batch_base_a);

  auto batch_base_b = mainloop_params.ptr_B[next_batch];
  cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_B, batch_base_b);
}
// Grouped GEMM only: rewrite the global dims and byte strides stored in the shared-memory A and B
// descriptors so they describe group `next_group`. A is viewed as (M, K, 1) with stride
// mainloop_params.dA[next_group], and B as (N, K, 1) with stride mainloop_params.dB[next_group].
// M, N, K are read from modes 0, 1, 2 of problem_shape_mnkl.
// Intended to be executed by a single thread.
template <class ProblemShape_MNKL>
CUTLASS_DEVICE
void
tensormaps_replace_global_tensor_properties(
    TensorMapStorage& shared_tensormaps,
    Params const& mainloop_params,
    int32_t next_group,
    ProblemShape_MNKL problem_shape_mnkl) {
  uint32_t const extent_m = get<0>(problem_shape_mnkl);
  uint32_t const extent_n = get<1>(problem_shape_mnkl);
  uint32_t const extent_k = get<2>(problem_shape_mnkl);

  // All MaxTensorRank modes are rewritten for consistency. Each mode starts as extent 1 / stride 0
  // before fill_tma_gmem_shape_stride fills in the modes the TMA atom uses.
  constexpr int MaxTensorRank = 5;

  // ---- Operand A ----
  cute::array<uint32_t, MaxTensorRank> gmem_shape_a  = {1,1,1,1,1};
  cute::array<uint64_t, MaxTensorRank> gmem_stride_a = {0,0,0,0,0};
  {
    // The view is built over a null pointer; it is only used to derive shape/stride.
    TmaInternalElementA const* no_data_a = nullptr;
    Tensor view_a = make_tensor(no_data_a, make_shape(extent_m, extent_k, Int<1>{}), mainloop_params.dA[next_group]);
    cute::detail::fill_tma_gmem_shape_stride(*observed_tma_load_a_, view_a, gmem_shape_a, gmem_stride_a);
  }
  // Element strides -> byte strides. The multiply by the bit width happens before the divide by 8,
  // so sub-byte element widths are handled.
  for (int mode = 0; mode < MaxTensorRank; ++mode) {
    gmem_stride_a[mode] = (gmem_stride_a[mode] * sizeof_bits_v<TmaInternalElementA>) / 8;
  }
  cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_A,
                                                          gmem_shape_a,
                                                          gmem_stride_a);

  // ---- Operand B ----
  cute::array<uint32_t, MaxTensorRank> gmem_shape_b  = {1,1,1,1,1};
  cute::array<uint64_t, MaxTensorRank> gmem_stride_b = {0,0,0,0,0};
  {
    TmaInternalElementB const* no_data_b = nullptr;
    Tensor view_b = make_tensor(no_data_b, make_shape(extent_n, extent_k, Int<1>{}), mainloop_params.dB[next_group]);
    cute::detail::fill_tma_gmem_shape_stride(*observed_tma_load_b_, view_b, gmem_shape_b, gmem_stride_b);
  }
  for (int mode = 0; mode < MaxTensorRank; ++mode) {
    gmem_stride_b[mode] = (gmem_stride_b[mode] * sizeof_bits_v<TmaInternalElementB>) / 8;
  }
  cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_B,
                                                          gmem_shape_b,
                                                          gmem_stride_b);
}
// Prepare the A/B descriptors for batch (or group) `next_batch` and release them to the global
// slots in input_tensormaps.
// Warp-collective: the entire warp must call this together, because it ends in aligned instructions.
//   1. One elected lane patches the shared-memory descriptors: always the global address, and for
//      grouped GEMM also the dims/strides of problem_shape.get_problem_shape(next_batch),
//      extended with a unit L mode.
//   2. __syncwarp() reconverges the warp.
//   3. The whole warp calls tensormaps_cp_fence_release with the patched shared-memory
//      descriptors and input_tensormaps.
template <class TensorMapA, class TensorMapB, class ProblemShape>
CUTLASS_DEVICE
void
tensormaps_perform_update(
    TensorMapStorage& shared_tensormaps,
    Params const& mainloop_params,
    cute::tuple<TensorMapA, TensorMapB> const& input_tensormaps,
    ProblemShape problem_shape,
    int32_t next_batch) {
  bool const is_elected_lane = cute::elect_one_sync();
  if (is_elected_lane) {
    tensormaps_replace_global_address(shared_tensormaps, mainloop_params, next_batch);

    if constexpr (IsGroupedGemmKernel) {
      // Append a unit L mode to the group's problem shape before rewriting dims/strides.
      auto group_shape_mnkl = append<4>(problem_shape.get_problem_shape(next_batch), 1);
      tensormaps_replace_global_tensor_properties(
          shared_tensormaps, mainloop_params, next_batch, group_shape_mnkl);
    }
  }

  // Reconverge the warp before the aligned fence-release that follows.
  __syncwarp();

  // Every lane participates here (aligned instructions).
  tensormaps_cp_fence_release(shared_tensormaps, input_tensormaps);
}
// Release the patched shared-memory A/B descriptors to the global descriptor slots held in
// input_tensormaps.
// One elected lane first issues tma_desc_commit_group() followed by tma_desc_wait_group().
// The whole warp then executes tma_descriptor_cp_fence_release for A and then for B. These are
// aligned instructions, so every lane of the warp must reach them.
template <class TensorMapA, class TensorMapB>
CUTLASS_DEVICE
void
tensormaps_cp_fence_release (
    TensorMapStorage& shared_tensormaps,
    cute::tuple<TensorMapA, TensorMapB> const& input_tensormaps) {
  auto const& gmem_desc_a = get<0>(input_tensormaps);
  auto const& gmem_desc_b = get<1>(input_tensormaps);

  bool const is_leader_lane = cute::elect_one_sync();
  if (is_leader_lane) {
    cute::tma_desc_commit_group();
    cute::tma_desc_wait_group();
  }

  // Warp-wide (aligned): all lanes execute both releases.
  tma_descriptor_cp_fence_release(gmem_desc_a, shared_tensormaps.smem_tensormap_A);
  tma_descriptor_cp_fence_release(gmem_desc_b, shared_tensormaps.smem_tensormap_B);
}
// Issue an acquire fence on the global A descriptor and then on the global B descriptor held in
// input_tensormaps.
// Warp-collective: the underlying instructions are aligned, so the entire warp must call this together.
template <class TensorMapA, class TensorMapB>
CUTLASS_DEVICE
void
tensormaps_fence_acquire(cute::tuple<TensorMapA, TensorMapB> const& input_tensormaps) {
  auto const& gmem_desc_a = get<0>(input_tensormaps);
  auto const& gmem_desc_b = get<1>(input_tensormaps);
  cute::tma_descriptor_fence_acquire(gmem_desc_a);
  cute::tma_descriptor_fence_acquire(gmem_desc_b);
}
private:
typename Params::TMA_A const* observed_tma_load_a_{nullptr};
typename Params::TMA_B const* observed_tma_load_b_{nullptr};
ClusterShape cluster_shape_;
uint32_t block_rank_in_cluster_;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::gemm::collective
/////////////////////////////////////////////////////////////////////////////////////////////////