Skip to content

Commit bf88408

Browse files
authored
[CINN] Add the TileBroadcastTactic for NCHW broadcast (#70092)
1 parent bd79716 commit bf88408

File tree

11 files changed

+493
-38
lines changed

11 files changed

+493
-38
lines changed

paddle/cinn/ir/group_schedule/config/group_tile_config.cc

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -150,22 +150,17 @@ std::shared_ptr<ScheduleConfig::BaseInfo> InitBasicInfo(
150150
const std::shared_ptr<FusionGroupInfo>& group_info) {
151151
std::shared_ptr<ScheduleConfig::BaseInfo> base_info =
152152
std::make_shared<ScheduleConfig::BaseInfo>();
153-
base_info->data_rank = group_info->loop_ranges.size();
153+
base_info->reduce_axis = group_info->reduce_axis;
154+
base_info->loop_ranges = group_info->loop_ranges;
154155
base_info->loop_strides = group_info->loop_strides;
155156
base_info->can_apply_grid_reduce = group_info->can_apply_grid_reduce;
156157

157-
std::set<int64_t> reduce_dim_loc;
158-
for (int64_t dim : group_info->reduce_axis) {
159-
if (dim < 0) {
160-
dim += base_info->data_rank;
161-
}
162-
base_info->reduce_axis.push_back(dim);
163-
reduce_dim_loc.insert(dim);
164-
}
158+
std::set<int64_t> reduce_dim_loc(group_info->reduce_axis.begin(),
159+
group_info->reduce_axis.end());
165160

166161
base_info->spatial_numel = 1;
167162
base_info->reduce_numel = 1;
168-
for (int64_t i = 0; i < base_info->data_rank; ++i) {
163+
for (int64_t i = 0; i < base_info->loop_ranges.size(); ++i) {
169164
if (reduce_dim_loc.count(i)) {
170165
if (group_info->loop_ranges[i] == -1)
171166
base_info->has_dynamic_reduce = true;

paddle/cinn/ir/group_schedule/config/group_tile_config.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,8 @@ using IterSpaceType = std::vector<std::pair<std::string, std::string>>;
3333
struct ScheduleConfig {
3434
struct BaseInfo {
3535
std::vector<int64_t> reduce_axis;
36+
std::vector<int64_t> loop_ranges;
3637
std::vector<int64_t> loop_strides;
37-
int64_t data_rank;
3838
int64_t reduce_numel;
3939
int64_t spatial_numel;
4040
bool has_dynamic_spatial{false};

paddle/cinn/ir/group_schedule/dy_shape_group_scheduler.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "paddle/cinn/ir/group_schedule/config/schedule_config_manager.h"
1818
#include "paddle/cinn/ir/group_schedule/tactic/compute_at_reduction_tactic.h"
1919
#include "paddle/cinn/ir/group_schedule/tactic/compute_inline_tactic.h"
20+
#include "paddle/cinn/ir/group_schedule/tactic/tile_broadcast_tactic.h"
2021
#include "paddle/cinn/ir/group_schedule/tactic/tile_first_general_tactic.h"
2122
#include "paddle/cinn/ir/ir_analyzer/ir_analyzer.h"
2223
#include "paddle/cinn/ir/op/ir_operators.h"
@@ -31,12 +32,10 @@ void DynamicShapeGroupScheduler::Init() {
3132
VLOG(4) << "original group func body: \n"
3233
<< ir_sch_->GetModule().GetExprs()[0];
3334
InitBuckets();
35+
tactics_.emplace_back(CreateTileBroadcastTactic());
3436
tactics_.emplace_back(CreateTileFirstGeneralTactic());
35-
VLOG(4) << "CreateTileFirstGeneralTactic End";
3637
tactics_.emplace_back(CreateComputeInlineTactic());
37-
VLOG(4) << "CreateTileCreateComputeInlineTactic End";
3838
tactics_.emplace_back(CreateComputeAtReductionTactic());
39-
VLOG(4) << "CreateComputeAtReductionTactic End";
4039
}
4140

4241
void DynamicShapeGroupScheduler::InitBuckets() {
@@ -156,7 +155,8 @@ void DynamicShapeGroupScheduler::ApplyTactics(BucketContext* bucket_context) {
156155
<< "] on ScheduleBlockNode [" << node->id() << "] func body:\n"
157156
<< bucket_context->ir_sch->GetModule().GetExprs().front();
158157
};
159-
tactic->Init(&(bucket_context->schedule_context));
158+
tactic->Init(&(bucket_context->schedule_context),
159+
bucket_context->ir_sch.get());
160160
bucket_context->schedule_block_graph->DFSTopoWalk(ApplyTacticFunc);
161161
bucket_context->schedule_block_graph->Update(*(bucket_context->ir_sch));
162162
VLOG(5) << "[End " << tactic->TacticName() << "] func body: "

paddle/cinn/ir/group_schedule/tactic/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,5 @@ gather_srcs(cinnapi_src SRCS optimize_reduction_tactic.cc)
77
gather_srcs(cinnapi_src SRCS compute_at_reduction_tactic.cc)
88
gather_srcs(cinnapi_src SRCS bind_cuda_tactic.cc)
99
gather_srcs(cinnapi_src SRCS arrange_storage_tactic.cc)
10+
gather_srcs(cinnapi_src SRCS tile_broadcast_tactic.cc)
1011
gather_srcs(cinnapi_src SRCS tile_first_general_tactic.cc)

paddle/cinn/ir/group_schedule/tactic/schedule_tactic.h

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,17 @@ struct ScheduleContext {
9191

9292
class ScheduleTactic {
9393
public:
94-
virtual void Init(ScheduleContext* context) = 0;
94+
// Attribute key to record which tile tactic has been applied on a graph.
95+
// Exactly one tile tactic is applied on a graph during scheduling.
96+
static constexpr char* kTileMethod = "tile_method";
97+
98+
virtual void Init(ScheduleContext* context) {
99+
PADDLE_THROW(::common::errors::Unimplemented(
100+
"ScheduleTactic subclass must implement one of the Init method."));
101+
}
102+
virtual void Init(ScheduleContext* context, ir::IRSchedule* sch) {
103+
Init(context);
104+
}
95105

96106
virtual void Apply(ir::IRSchedule* sch, const std::string& block_id) = 0;
97107

0 commit comments

Comments
 (0)