Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions paddle/cinn/hlir/pe/ir_schedule_pe.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,15 @@ void IRElementwiseSchedule(ir::IRSchedule &ir_sch, // NOLINT
<< ir_sch.GetModule().GetExprs().at(0);
if (target == common::DefaultNVGPUTarget()) {
auto blocks = ir_sch.GetAllBlocks();
ir_sch.FlattenLoops(ir_sch.GetLoops(blocks[0]), true);
std::vector<ir::Expr> loops = ir_sch.GetLoops(blocks[0]);
ir::Expr loop = ir_sch.Fuse(loops);

auto loops = ir_sch.GetLoops(blocks[0]);
auto size = std::accumulate(
output_shape.begin(), output_shape.end(), 1, std::multiplies<int>());
if (size <= target.max_num_threads()) {
ir_sch.Bind(loops[0], "threadIdx.x");
ir_sch.Bind(loop, "threadIdx.x");
} else {
auto splited = ir_sch.Split(loops[0], {-1, target.max_num_threads()});
auto splited = ir_sch.Split(loop, {-1, target.max_num_threads()});
ir_sch.Bind(splited[0], "blockIdx.x");
ir_sch.Bind(splited[1], "threadIdx.x");
}
Expand All @@ -74,15 +74,15 @@ void IRInjectiveSchedule(ir::IRSchedule &ir_sch, // NOLINT
<< ir_sch.GetModule().GetExprs().at(0);
if (target == common::DefaultNVGPUTarget()) {
auto blocks = ir_sch.GetAllBlocks();
ir_sch.FlattenLoops(ir_sch.GetLoops(blocks[0]), false);
std::vector<ir::Expr> loops = ir_sch.GetLoops(blocks[0]);
ir::Expr loop = ir_sch.Fuse(loops);

auto loops = ir_sch.GetLoops(blocks[0]);
auto size = std::accumulate(
output_shape.begin(), output_shape.end(), 1, std::multiplies<int>());
if (size <= target.max_num_threads()) {
ir_sch.Bind(loops[0], "threadIdx.x");
ir_sch.Bind(loop, "threadIdx.x");
} else {
auto splited = ir_sch.Split(loops[0], {-1, target.max_num_threads()});
auto splited = ir_sch.Split(loop, {-1, target.max_num_threads()});
ir_sch.Bind(splited[0], "blockIdx.x");
ir_sch.Bind(splited[1], "threadIdx.x");
}
Expand Down
2 changes: 1 addition & 1 deletion test/legacy_test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -1057,7 +1057,7 @@ set_tests_properties(
PROPERTIES TIMEOUT 120)
set_tests_properties(test_conv_nn_grad PROPERTIES TIMEOUT 120)
set_tests_properties(test_program_prune_backward PROPERTIES TIMEOUT 120)
set_tests_properties(test_group_norm_op PROPERTIES TIMEOUT 300)
set_tests_properties(test_group_norm_op PROPERTIES TIMEOUT 1000)
set_tests_properties(test_imperative_optimizer PROPERTIES TIMEOUT 250)
set_tests_properties(test_imperative_optimizer_v2 PROPERTIES TIMEOUT 250)
set_tests_properties(test_pool2d_op PROPERTIES TIMEOUT 120)
Expand Down