@@ -41,10 +41,18 @@ static std::string GetFP8Type(DataType type) {
4141    stream << " fp8_e4" " _t" 
4242  } else  if  (type.code () == DataType::kFloat8_e4m3fnuz ) {
4343    stream << " fp8_e4" " _t" 
44+   } else  if  (type.code () == DataType::kFloat8_e4m3 ) {
45+     stream << " fp8_e4" " _t" 
46+   } else  if  (type.code () == DataType::kFloat8_e4m3b11fnuz ) {
47+     stream << " fp8_e4" " _t" 
4448  } else  if  (type.code () == DataType::kFloat8_e5m2 ) {
4549    stream << " fp8_e5" " _t" 
50+   } else  if  (type.code () == DataType::kFloat8_e5m2fnuz ) {
51+     stream << " fp8_e5" " _t" 
52+   } else  if  (type.code () == DataType::kFloat8_e8m0fnu ) {
53+     stream << " fp8_e8" " _t" 
4654  } else  {
47-     LOG (FATAL) << " Unsupported FP8 type in HIP codegen" 
55+     LOG (FATAL) << " Unsupported FP8 type in HIP codegen:  "  << type ;
4856  }
4957  return  stream.str ();
5058}
@@ -926,10 +934,10 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
926934        {" float8_e4m3fnuzx8" " long" 
927935        {" float32x16" " float32x16" 
928936    std::string call_mfma_code = R"( {
929-     *((({C_dtype}*){c_ref}) + {c_bias}) = {mfma_buildin}(*((({A_dtype}*){a_ref}) + {a_bias}), 
930-                   *((({B_dtype}*){b_ref}) + {b_bias}), 
931-                   *((({C_dtype}*){c_ref}) + {c_bias}), 0, 0, 0); 
932-   })"  ;
937+        *((({C_dtype}*){c_ref}) + {c_bias}) = {mfma_buildin}(*((({A_dtype}*){a_ref}) + {a_bias}), 
938+                      *((({B_dtype}*){b_ref}) + {b_bias}), 
939+                      *((({C_dtype}*){c_ref}) + {c_bias}), 0, 0, 0); 
940+      })"  ;
933941    std::string mfma_buildin = " __builtin_amdgcn_mfma_" 
934942    Replacer replacer;
935943
@@ -955,6 +963,13 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
955963                          op->args , true , os);
956964  } else  if  (op->op .same_as (tl::tl_gemm_sp ())) {
957965    LOG (FATAL) << " tl_gemm_sp is not supported on HIP" 
966+   } else  if  (op->op .same_as (tl::loop_break ())) {
967+     this ->PrintIndent ();
968+     this ->stream  << " break;\n " 
969+   } else  if  (op->op .same_as (tl::no_set_max_nreg ())) {
970+     //  HIP doesn't need explicit register management like CUDA
971+     //  This is a no-op for HIP
972+     return ;
958973  } else  {
959974    CodeGenC::VisitExpr_ (op, os);
960975  }
@@ -1160,7 +1175,8 @@ inline void PrintConst(const FloatImmNode *op, std::ostream &os,
11601175    os << " bfloat16_t" 
11611176    os << ' (' value  << ' f' ' )' 
11621177    return ;
1163-   } else  if  (op->dtype .is_float8_e4m3fnuz ()) {
1178+   } else  if  (op->dtype .is_float8_e4m3fnuz () || op->dtype .is_float8_e4m3 () ||
1179+              op->dtype .is_float8_e4m3fn ()) {
11641180    os << " fp8_e4_t" 
11651181    os << ' (' value  << ' f' ' )' 
11661182    return ;
0 commit comments