diff --git a/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/cutlass_heuristic.h b/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/cutlass_heuristic.h index 0f6a3b73c8db93..ff878f896a74de 100644 --- a/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/cutlass_heuristic.h +++ b/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/cutlass_heuristic.h @@ -208,7 +208,7 @@ static CutlassGemmConfig estimate_best_config_from_occupancies( CutlassGemmConfig best_config; - if (m >= 256 && sm == 86 && + if (m >= 256 && sm == 86 && group_size > 0 && std::find_if( candidate_configs.begin(), candidate_configs.end(), @@ -221,13 +221,20 @@ static CutlassGemmConfig estimate_best_config_from_occupancies( SplitKStyle::NO_SPLIT_K, 1, 2}; - } else if (m >= 256 && sm >= 80 && group_size > 0) { + } else if (m >= 256 && sm == 80 && group_size > 0 && + std::find_if(candidate_configs.begin(), + candidate_configs.end(), + [](const CutlassGemmConfig& gemm_config) { + return gemm_config.tile_config == + CutlassTileConfig:: + CtaShape256x128x64_WarpShape64x64x64; + }) != candidate_configs.end()) { best_config = CutlassGemmConfig{ CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64, SplitKStyle::NO_SPLIT_K, 1, 4}; - } else if (m >= 256 && + } else if (m >= 256 && sm == 80 && group_size <= 0 && std::find_if(candidate_configs.begin(), candidate_configs.end(), [](const CutlassGemmConfig& gemm_config) {