/***************************************************************************************************
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/conv/convolution.h"
#include "cutlass/epilogue/thread/activation.h"
#include "cutlass/arch/arch.h"
#include "cute/layout.hpp"
#include "cute/numeric/integral_constant.hpp"
#include "cutlass/gemm/dispatch_policy.hpp"
//////////////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////////////
namespace cutlass::conv {
//////////////////////////////////////////////////////////////////////////////
//
// Policies for categorical dispatch of mainloop against kernel grid schedules
//
struct KernelImplicitTmaWarpSpecializedSm90 : cutlass::gemm::KernelTmaWarpSpecialized { };
struct KernelImplicitTmaWarpSpecializedSm90Cooperative { };
struct KernelImplicitTmaWarpSpecializedSm90Pingpong { };
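// Illustrative compile-time check (a sketch, not part of the library): the
// SM90 implicit-GEMM tag refines cutlass::gemm::KernelTmaWarpSpecialized, so
// dispatch logic keyed on the GEMM schedule tag also accepts the conv tag.
// std::is_base_of_v is assumed to be reachable through the cute includes.
static_assert(std::is_base_of_v<cutlass::gemm::KernelTmaWarpSpecialized,
                                KernelImplicitTmaWarpSpecializedSm90>);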
//
// Collective Mainloop Policies
//
// n-buffer in smem (Hopper TMA), pipelined with Hopper GMMA and TMA, static schedule between TMA and GMMA
// for fprop
template<
  conv::Operator ConvOp_,
  int Stages_,
  int NumSpatialDimensions_,
  class ClusterShape_ = cute::Shape<cute::C<1>,cute::C<1>,cute::C<1>>,
  class KernelSchedule = KernelImplicitTmaWarpSpecializedSm90,
  int PipelineAsyncMmaStages_ = 1
>
struct MainloopSm90TmaGmmaWarpSpecializedImplicitGemm {
  static constexpr int Stages = Stages_;
  static constexpr int NumSpatialDimensions = NumSpatialDimensions_;
  static constexpr Operator ConvOp = ConvOp_;
  static constexpr int PipelineAsyncMmaStages = PipelineAsyncMmaStages_;
  using ClusterShape = ClusterShape_;
  using ArchTag = arch::Sm90;
  using Schedule = KernelSchedule;

  static_assert(NumSpatialDimensions >= 1);
  static_assert(!(cute::is_same_v<KernelSchedule, KernelImplicitTmaWarpSpecializedSm90Cooperative> ||
                  cute::is_same_v<KernelSchedule, KernelImplicitTmaWarpSpecializedSm90Pingpong>),
    "Persistent schedules are not supported for conv yet.");
};
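// Usage sketch (the alias name and parameter values are hypothetical and
// untuned, for illustration only): a 7-stage conv2d fprop mainloop with the
// defaulted 1x1x1 cluster shape and default SM90 schedule.
using ExampleSm90FpropMainloop = MainloopSm90TmaGmmaWarpSpecializedImplicitGemm<
    Operator::kFprop,
    /* Stages_               = */ 7,
    /* NumSpatialDimensions_ = */ 2>;
static_assert(ExampleSm90FpropMainloop::Stages == 7);
static_assert(cute::is_same_v<ExampleSm90FpropMainloop::Schedule,
                              KernelImplicitTmaWarpSpecializedSm90>);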
// SM100 tensor op kernel schedule
struct KernelImplicitTmaWarpSpecializedSm100 { };
// Pseudo-policies for builder auto-override; these dispatch to the
// KernelImplicitTmaWarpSpecializedSm100 kernel layer while opting into 1-SM or 2-SM atoms
struct KernelImplicitTmaWarpSpecialized1SmSm100 : KernelImplicitTmaWarpSpecializedSm100 { };
struct KernelImplicitTmaWarpSpecialized2SmSm100 : KernelImplicitTmaWarpSpecializedSm100 { };
struct KernelStridedDgradTmaWs1SmSm100 { };
struct KernelStridedDgradTmaWs2SmSm100 { };
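// Illustrative compile-time checks (a sketch, not builder logic): both
// atom-width overrides refine the generic SM100 schedule, so dispatch keyed
// on the base tag accepts either override, while the exact type still
// selects between 1-SM and 2-SM atoms.
static_assert(std::is_base_of_v<KernelImplicitTmaWarpSpecializedSm100,
                                KernelImplicitTmaWarpSpecialized1SmSm100>);
static_assert(std::is_base_of_v<KernelImplicitTmaWarpSpecializedSm100,
                                KernelImplicitTmaWarpSpecialized2SmSm100>);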
// n-buffer in smem (Blackwell TMA), pipelined with Blackwell UMMA and TMA, fprop
template<
  conv::Operator ConvOp_,
  int Stages_,
  int NumSpatialDimensions_,
  class ClusterShape_ = cute::Shape<cute::C<1>,cute::C<1>,cute::C<1>>
>
struct MainloopSm100TmaUmmaWarpSpecializedImplicitGemm {
  static constexpr int Stages = Stages_;
  static constexpr int NumSpatialDimensions = NumSpatialDimensions_;
  static constexpr Operator ConvOp = ConvOp_;
  using ClusterShape = ClusterShape_;
  using ArchTag = arch::Sm100;
  using Schedule = KernelImplicitTmaWarpSpecializedSm100;

  static_assert(NumSpatialDimensions >= 1);
};
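// Usage sketch (the alias name and parameter values are hypothetical and
// untuned, for illustration only): a 4-stage conv3d fprop mainloop with a
// 2x1x1 cluster, as one might use with 2-SM UMMA atoms.
using ExampleSm100FpropMainloop = MainloopSm100TmaUmmaWarpSpecializedImplicitGemm<
    Operator::kFprop,
    /* Stages_               = */ 4,
    /* NumSpatialDimensions_ = */ 3,
    cute::Shape<cute::C<2>, cute::C<1>, cute::C<1>>>;
static_assert(ExampleSm100FpropMainloop::NumSpatialDimensions == 3);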
//////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::conv
//////////////////////////////////////////////////////////////////////////////