Skip to content

Commit 5cb64c6

Browse files
author
Siyuan Feng
committed
add int support and fix lint
1 parent 3db8d64 commit 5cb64c6

File tree

3 files changed

+14
-10
lines changed

3 files changed

+14
-10
lines changed

src/codegen/codegen_cuda.cc

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -389,11 +389,11 @@ void CodeGenCUDA::VisitStmt_(const Allocate* op) {
389389
std::string scope = alloc_storage_scope_.at(buffer);
390390
if (scope.find("wmma.") == 0) {
391391
if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") {
392-
CHECK(op->type.is_float() && op->type.bits() == 16)
393-
<< "Matrix_a and matrix_b only support half type for now";
392+
CHECK(op->type == Float(16) || op->type == Int(8) || op->type == UInt(8))
393+
<< "Matrix_a and matrix_b only support half or char or unsigned char type for now";
394394
} else {
395-
CHECK(op->type.is_float() && (op->type.bits() == 16 || op->type.bits() == 32))
396-
<< "Accumulator only support half and float type for now";
395+
CHECK(op->type == Float(16) || op->type == Float(32) || op->type == Int(32))
396+
<< "Accumulator only support half, float and int type for now";
397397
}
398398
constant_size /= 256;
399399
PrintWmmaScope(scope, op->type, buffer, stream);
@@ -511,7 +511,8 @@ void CodeGenCUDA::VisitExpr_(const FloatImm *op, std::ostream& os) { // NOLINT(*
511511
PrintConst(op, os, this);
512512
}
513513

514-
void CodeGenCUDA::PrintWmmaScope(const std::string &scope, Type t, const Variable* variable, std::ostream &os) {
514+
void CodeGenCUDA::PrintWmmaScope(const std::string &scope, Type t,
515+
const Variable* variable, std::ostream &os) {
515516
std::stringstream type;
516517
PrintType(t, type);
517518
std::string shape_str = fragment_shapes[variable];
@@ -527,7 +528,8 @@ void CodeGenCUDA::PrintWmmaScope(const std::string &scope, Type t, const Variabl
527528
<< shape_str << ", " << type.str() << ", nvcuda::wmma::" << layout_str <<">";
528529
} else if (scope == "wmma.accumulator") {
529530
need_mma_h_ = true;
530-
os << "nvcuda::wmma::fragment<nvcuda::wmma::accumulator, " << shape_str << ", "<< type.str() << ">";
531+
os << "nvcuda::wmma::fragment<nvcuda::wmma::accumulator, "
532+
<< shape_str << ", "<< type.str() << ">";
531533
}
532534
}
533535

src/codegen/codegen_cuda.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
#include <tvm/codegen.h>
2929
#include <tvm/packed_func_ext.h>
3030
#include <string>
31+
#include <unordered_map>
3132
#include "codegen_c.h"
3233

3334
namespace tvm {

src/pass/infer_fragment.cc

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ class FragmentGetter : public IRVisitor {
119119

120120
class FragmentChecker : public IRVisitor {
121121
public:
122-
FragmentChecker(const FragmentGetter &getter) : fragment_getter(getter) {}
122+
explicit FragmentChecker(const FragmentGetter &getter) : fragment_getter(getter) {}
123123

124124
void Visit_(const Call* op) final {
125125
if (op->is_intrinsic(intrinsic::tvm_mma_sync)) {
@@ -137,22 +137,22 @@ class FragmentChecker : public IRVisitor {
137137
CHECK(CheckShape(buffer_var_d, buffer_var_c));
138138
}
139139
}
140+
140141
private:
141142
bool CheckShape(const Variable* buffer1, const Variable* buffer2) {
142143
CHECK(fragment_getter.fragments.count(buffer1));
143144
CHECK(fragment_getter.fragments.count(buffer2));
144145
FragmentGetter::FragmentInfo info1 = fragment_getter.fragments.at(buffer1);
145146
FragmentGetter::FragmentInfo info2 = fragment_getter.fragments.at(buffer2);
146147
return info1.m == info2.m && info1.n == info2.n && info1.k == info2.k;
147-
148148
}
149-
const FragmentGetter &fragment_getter;
150149

150+
const FragmentGetter &fragment_getter;
151151
};
152152

153153
class InferFragmenter : public IRMutator {
154154
public:
155-
InferFragmenter(const FragmentGetter &getter) : fragment_getter(getter) {}
155+
explicit InferFragmenter(const FragmentGetter &getter) : fragment_getter(getter) {}
156156

157157
Stmt Mutate_(const Allocate* op, const Stmt& s) final {
158158
Stmt stmt = IRMutator::Mutate_(op, s);
@@ -174,6 +174,7 @@ class InferFragmenter : public IRMutator {
174174
}
175175
return stmt;
176176
}
177+
177178
private:
178179
const FragmentGetter &fragment_getter;
179180
};

0 commit comments

Comments
 (0)