@@ -389,11 +389,11 @@ void CodeGenCUDA::VisitStmt_(const Allocate* op) {
389389 std::string scope = alloc_storage_scope_.at (buffer);
390390 if (scope.find (" wmma." ) == 0 ) {
391391 if (scope == " wmma.matrix_a" || scope == " wmma.matrix_b" ) {
392- CHECK (op->type . is_float () && op->type . bits () == 16 )
393- << " Matrix_a and matrix_b only support half type for now" ;
392+ CHECK (op->type == Float ( 16 ) || op->type == Int ( 8 ) || op-> type == UInt ( 8 ) )
393+ << " Matrix_a and matrix_b only support half or char or unsigned char type for now" ;
394394 } else {
395- CHECK (op->type . is_float () && ( op->type . bits () == 16 || op->type . bits () == 32 ))
396- << " Accumulator only support half and float type for now" ;
395+ CHECK (op->type == Float ( 16 ) || op->type == Float ( 32 ) || op->type == Int ( 32 ))
396+ << " Accumulator only support half, float and int type for now" ;
397397 }
398398 constant_size /= 256 ;
399399 PrintWmmaScope (scope, op->type , buffer, stream);
@@ -511,7 +511,8 @@ void CodeGenCUDA::VisitExpr_(const FloatImm *op, std::ostream& os) { // NOLINT(*
511511 PrintConst (op, os, this );
512512}
513513
514- void CodeGenCUDA::PrintWmmaScope (const std::string &scope, Type t, const Variable* variable, std::ostream &os) {
514+ void CodeGenCUDA::PrintWmmaScope (const std::string &scope, Type t,
515+ const Variable* variable, std::ostream &os) {
515516 std::stringstream type;
516517 PrintType (t, type);
517518 std::string shape_str = fragment_shapes[variable];
@@ -527,7 +528,8 @@ void CodeGenCUDA::PrintWmmaScope(const std::string &scope, Type t, const Variabl
527528 << shape_str << " , " << type.str () << " , nvcuda::wmma::" << layout_str <<" >" ;
528529 } else if (scope == " wmma.accumulator" ) {
529530 need_mma_h_ = true ;
530- os << " nvcuda::wmma::fragment<nvcuda::wmma::accumulator, " << shape_str << " , " << type.str () << " >" ;
531+ os << " nvcuda::wmma::fragment<nvcuda::wmma::accumulator, "
532+ << shape_str << " , " << type.str () << " >" ;
531533 }
532534}
533535
0 commit comments