@@ -75,30 +75,37 @@ ParseResult VoteBallotOp::parse(OpAsmParser &parser, OperationState &result) {
75
75
76
76
/// Prints a VoteBallotOp using the shared NVVM intrinsic-op printer helper.
void VoteBallotOp::print(OpAsmPrinter &p) { printNVVMIntrinsicOp(p, *this); }
77
77
78
- // This verifier is shared across:
79
- // CpAsyncBulkTensorGlobalToSharedClusterOp (TMA Load) and
80
- // CpAsyncBulkTensorPrefetchOp (TMA Prefetch) Ops.
78
+ // This verifier is shared among the following Ops:
79
+ // CpAsyncBulkTensorGlobalToSharedClusterOp (TMA Load)
80
+ // CpAsyncBulkTensorPrefetchOp (TMA Prefetch)
81
+ // CpAsyncBulkTensorReduceOp (TMA Store-Reduce)
81
82
static LogicalResult CpAsyncBulkTensorCommonVerifier (size_t tensorDims,
83
+ bool isIm2Col,
82
84
size_t numIm2ColOffsets,
83
85
Location loc) {
84
86
if (tensorDims < 1 || tensorDims > 5 )
85
87
return emitError (loc, " expects coordinates between 1 to 5 dimension" );
86
88
87
- if (numIm2ColOffsets) {
89
+ // For Im2Col mode, there are two constraints:
90
+ if (isIm2Col) {
91
+ // 1. Tensor must always be at least 3-d.
88
92
if (tensorDims < 3 )
89
93
return emitError (
90
94
loc,
91
95
" to use im2col mode, the tensor has to be at least 3-dimensional" );
92
- if (tensorDims != (numIm2ColOffsets + 2 ))
96
+ // 2. When there are Im2ColOffsets, they must be (Dims - 2) in number.
97
+ if (numIm2ColOffsets && (tensorDims != (numIm2ColOffsets + 2 )))
93
98
return emitError (
94
99
loc, " im2col offsets must be 2 less than number of coordinates" );
95
100
}
96
101
return success ();
97
102
}
98
103
99
104
LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() {
  // For the TMA load op, im2col mode is implied by the presence of offsets.
  size_t numIm2ColOffsets = getIm2colOffsets().size();
  bool isIm2Col = numIm2ColOffsets > 0;
  return CpAsyncBulkTensorCommonVerifier(getCoordinates().size(), isIm2Col,
                                         numIm2ColOffsets, getLoc());
}
103
110
104
111
LogicalResult CpAsyncBulkTensorSharedCTAToGlobalOp::verify () {
@@ -119,8 +126,16 @@ LogicalResult CpAsyncOp::verify() {
119
126
}
120
127
121
128
LogicalResult CpAsyncBulkTensorPrefetchOp::verify() {
  // For the TMA prefetch op, im2col mode is implied by the presence of
  // offsets.
  size_t numIm2ColOffsets = getIm2colOffsets().size();
  bool isIm2Col = numIm2ColOffsets > 0;
  return CpAsyncBulkTensorCommonVerifier(getCoordinates().size(), isIm2Col,
                                         numIm2ColOffsets, getLoc());
}
134
+
135
+ LogicalResult CpAsyncBulkTensorReduceOp::verify () {
136
+ bool isIm2Col = (getMode () == TMAStoreMode::IM2COL);
137
+ return CpAsyncBulkTensorCommonVerifier (getCoordinates ().size (), isIm2Col, 0 ,
138
+ getLoc ());
124
139
}
125
140
126
141
// Given the element type of an operand and whether or not it is an accumulator,
@@ -1094,6 +1109,55 @@ llvm::Intrinsic::ID CpAsyncBulkTensorPrefetchOp::getIntrinsicID(int tensorDims,
1094
1109
}
1095
1110
}
1096
1111
1112
+ #define CP_ASYNC_BULK_TENSOR_REDUCE_MODE (op, dim, mode ) \
1113
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_##op##_##mode##_##dim##d
1114
+
1115
+ #define CP_ASYNC_BULK_TENSOR_REDUCE (op, dim, is_im2col ) \
1116
+ is_im2col ? CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, im2col) \
1117
+ : CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, tile)
1118
+
1119
+ #define GET_CP_ASYNC_BULK_TENSOR_ID (op, dims, is_im2col ) \
1120
+ [&]() -> auto { \
1121
+ switch (dims) { \
1122
+ case 1 : \
1123
+ return CP_ASYNC_BULK_TENSOR_REDUCE_MODE (op, 1 , tile); \
1124
+ case 2 : \
1125
+ return CP_ASYNC_BULK_TENSOR_REDUCE_MODE (op, 2 , tile); \
1126
+ case 3 : \
1127
+ return CP_ASYNC_BULK_TENSOR_REDUCE (op, 3 , is_im2col); \
1128
+ case 4 : \
1129
+ return CP_ASYNC_BULK_TENSOR_REDUCE (op, 4 , is_im2col); \
1130
+ case 5 : \
1131
+ return CP_ASYNC_BULK_TENSOR_REDUCE (op, 5 , is_im2col); \
1132
+ default : \
1133
+ llvm_unreachable (" Invalid TensorDim in CpAsyncBulkTensorReduceOp." ); \
1134
+ } \
1135
+ }()
1136
+
1137
+ llvm::Intrinsic::ID CpAsyncBulkTensorReduceOp::getIntrinsicID (
1138
+ int tensorDims, NVVM::TMAReduxKind kind, bool isIm2Col) {
1139
+ using RedTy = NVVM::TMAReduxKind;
1140
+ switch (kind) {
1141
+ case RedTy::ADD:
1142
+ return GET_CP_ASYNC_BULK_TENSOR_ID (reduce_add, tensorDims, isIm2Col);
1143
+ case RedTy::MIN:
1144
+ return GET_CP_ASYNC_BULK_TENSOR_ID (reduce_min, tensorDims, isIm2Col);
1145
+ case RedTy::MAX:
1146
+ return GET_CP_ASYNC_BULK_TENSOR_ID (reduce_max, tensorDims, isIm2Col);
1147
+ case RedTy::INC:
1148
+ return GET_CP_ASYNC_BULK_TENSOR_ID (reduce_inc, tensorDims, isIm2Col);
1149
+ case RedTy::DEC:
1150
+ return GET_CP_ASYNC_BULK_TENSOR_ID (reduce_dec, tensorDims, isIm2Col);
1151
+ case RedTy::AND:
1152
+ return GET_CP_ASYNC_BULK_TENSOR_ID (reduce_and, tensorDims, isIm2Col);
1153
+ case RedTy::OR:
1154
+ return GET_CP_ASYNC_BULK_TENSOR_ID (reduce_or, tensorDims, isIm2Col);
1155
+ case RedTy::XOR:
1156
+ return GET_CP_ASYNC_BULK_TENSOR_ID (reduce_xor, tensorDims, isIm2Col);
1157
+ }
1158
+ llvm_unreachable (" Invalid Reduction Op for CpAsyncBulkTensorReduceOp" );
1159
+ }
1160
+
1097
1161
// ===----------------------------------------------------------------------===//
1098
1162
// NVVMDialect initialization, type parsing, and registration.
1099
1163
// ===----------------------------------------------------------------------===//
0 commit comments