@@ -71,7 +71,6 @@ TEST(AutoInline, SingleLoopInline) {
nullptr,
target,
true);

VLOG(6) << "Expr after lowering:";
VLOG(6) << funcs[0]->body;

@@ -170,7 +169,9 @@ TEST(AutoInline, AddReluInline) {

EXPECT_EQ(graph->fusion_groups.size(), 1UL);
std::vector<ir::LoweredFunc> funcs =
-      op_lowerer->LowerWithoutSchedule(graph->fusion_groups[0]);
+      op_lowerer->Lower(graph->fusion_groups[0],
+                        /*apply_op_schedule = */ false,
+                        /*apply_group_schedule=*/false);

VLOG(6) << "Expr before auto inline: " << funcs[0]->body;

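Every call site in this PR follows the same migration: the removed LowerWithoutSchedule is replaced by the unified Lower entry point, which takes two explicit schedule toggles. Below is a minimal, self-contained toy sketch of that design; the Group and LoweredFunc stubs are invented for illustration, and CINN's real OpLowerer::Lower signature lives in the framework headers and may carry different defaults.

#include <vector>

// Toy model of the refactor, not CINN's real API: one Lower() with two
// independent toggles replaces the old Lower()/LowerWithoutSchedule() pair.
struct Group { int id; };
struct LoweredFunc { int id; };

struct OpLowerer {
  std::vector<LoweredFunc> Lower(const Group& group,
                                 bool apply_op_schedule = true,
                                 bool apply_group_schedule = true) {
    // A real implementation would run per-op schedules only when
    // apply_op_schedule is set, and group-level schedules only when
    // apply_group_schedule is set; here we just return a stub.
    (void)apply_op_schedule;
    (void)apply_group_schedule;
    return {LoweredFunc{group.id}};
  }
};

int main() {
  OpLowerer lowerer;
  Group g{0};
  auto scheduled = lowerer.Lower(g);                  // old Lower(group)
  auto unscheduled = lowerer.Lower(g, false, false);  // old LowerWithoutSchedule(group)
  return static_cast<int>(scheduled.size() + unscheduled.size()) - 2;  // 0 on success
}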
@@ -388,9 +388,9 @@ TEST_F(TestMultiLevelTiling, ReduceSum) {

TEST_F(TestMultiLevelTiling, Pool2d) {
default_input_names = {"input"};
default_output_names = {"var_0"};
std::vector<int32_t> input_shape{2, 8, 16, 16};
std::vector<int32_t> output_shape{2, 8, 8, 8};
default_output_names = {"var_0", "pad_temp_0"};
std::vector<std::vector<int32_t>> input_shapes{{2, 8, 16, 16}};
std::vector<std::vector<int32_t>> output_shapes{{2, 8, 8, 8}, {2, 8, 18, 18}};
std::string pooling_type = "max";
std::vector<int> ksize{3, 3};
std::vector<int> strides{2, 2};
@@ -402,7 +402,7 @@ TEST_F(TestMultiLevelTiling, Pool2d) {
bool adaptive = false;
std::string padding_algorithm = "EXPLICIT";
frontend::Program pool2d_program = tests::OpBuilder("pool2d").Build(
{{"input", input_shape}},
{{"input", input_shapes[0]}},
{{"pool_type", pooling_type},
{"kernel_size", ksize},
{"stride_size", strides},
@@ -440,107 +440,104 @@ TEST_F(TestMultiLevelTiling, Pool2d) {
{
ScheduleBlock(root)
{
serial for (i, 0, 2)
{
serial for (j, 0, 8)
serial for (i, 0, 2)
{
serial for (k, 0, 18)
serial for (j, 0, 8)
{
serial for (a, 0, 18)
serial for (k, 0, 18)
{
ScheduleBlock(pad_temp_0)
serial for (a, 0, 18)
{
i0, i1, i2, i3 = axis.bind(i, j, k, a)
pad_temp_0[i, j, k, a] = select(((a < 17) and ((a >= 1) and ((k < 17) and (k >= 1)))), input[i, j, (-1 + k), (-1 + a)], -3.40282347e+38f)
ScheduleBlock(pad_temp_0)
{
i0, i1, i2, i3 = axis.bind(i, j, k, a)
{
pad_temp_0[i, j, k, a] = select(((a < 17) and ((a >= 1) and ((k < 17) and (k >= 1)))), input[i, j, (-1 + k), (-1 + a)], -3.40282347e+38f)
}
}
}
}
}
}
}
}
}
} // end Expr 0
Expr 1 {
{
ScheduleBlock(root_0)
{
{
thread_bind[blockIdx.x] for (i_j_k_a_fused, 0, 16)
{
thread_bind[threadIdx.x] for (i_0_j_0_k_0_a_0_fused, 0, 4)
thread_bind[blockIdx.x] for (i_j_k_a_fused, 0, 16)
{
serial for (i_1, 0, 1)
thread_bind[threadIdx.x] for (i_0_j_0_k_0_a_0_fused, 0, 4)
{
serial for (j_1, 0, 4)
serial for (i_1, 0, 1)
{
serial for (k_1, 0, 1)
serial for (j_1, 0, 4)
{
serial for (a_1, 0, 4)
serial for (k_1, 0, 1)
{
ScheduleBlock(var_0__reduce_init)
serial for (a_1, 0, 4)
{
i0_0, i1_0, i2_0, i3_0 = axis.bind(((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((i_0_j_0_k_0_a_0_fused % 4) + ((4 * ((i_j_k_a_fused / 2) % 2)) + k_1)), ((4 * (i_j_k_a_fused % 2)) + a_1))
ScheduleBlock(var_0__reduce_init)
{
var_0__reduce_init[((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((4 * ((i_j_k_a_fused / 2) % 2)) + ((i_0_j_0_k_0_a_0_fused % 4) + k_1)), ((4 * (i_j_k_a_fused % 2)) + a_1)] = -3.40282347e+38f
i0_0, i1_0, i2_0, i3_0 = axis.bind(((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((i_0_j_0_k_0_a_0_fused % 4) + ((4 * ((i_j_k_a_fused / 2) % 2)) + k_1)), ((4 * (i_j_k_a_fused % 2)) + a_1))
{
var_0__reduce_init[((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((4 * ((i_j_k_a_fused / 2) % 2)) + ((i_0_j_0_k_0_a_0_fused % 4) + k_1)), ((4 * (i_j_k_a_fused % 2)) + a_1)] = -3.40282347e+38f
}
}
}
}
}
}
}
{
serial for (kernel_idx, 0, 3)
{
serial for (kernel_idx_0, 0, 3)
serial for (kernel_idx, 0, 3)
{
serial for (ax0_ax1_ax2_ax3_fused, 0, 28)
serial for (kernel_idx_0, 0, 3)
{
ScheduleBlock(pad_temp_0_shared_temp_buffer)
serial for (ax0_ax1_ax2_ax3_fused, 0, 28)
{
v0, v1, v2, v3 = axis.bind(((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + ((ax0_ax1_ax2_ax3_fused / 7) / 4))), (((ax0_ax1_ax2_ax3_fused / 7) % 4) + (4 * (((i_j_k_a_fused / 2) / 2) % 2))), ((8 * ((i_j_k_a_fused / 2) % 2)) + ((2 * (i_0_j_0_k_0_a_0_fused % 4)) + kernel_idx)), ((ax0_ax1_ax2_ax3_fused % 7) + ((8 * (i_j_k_a_fused % 2)) + kernel_idx_0)))
attrs(compute_at_extra_var:ax0,ax1,ax2,ax3, cooperative_process:0)
ScheduleBlock(pad_temp_0_shared_temp_buffer)
{
pad_temp_0_shared_temp_buffer[v0, v1, v2, v3] = pad_temp_0[v0, v1, v2, v3]
v0, v1, v2, v3 = axis.bind(((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + ((ax0_ax1_ax2_ax3_fused / 7) / 4))), (((ax0_ax1_ax2_ax3_fused / 7) % 4) + (4 * (((i_j_k_a_fused / 2) / 2) % 2))), ((8 * ((i_j_k_a_fused / 2) % 2)) + ((2 * (i_0_j_0_k_0_a_0_fused % 4)) + kernel_idx)), ((ax0_ax1_ax2_ax3_fused % 7) + ((8 * (i_j_k_a_fused % 2)) + kernel_idx_0)))
attrs(compute_at_extra_var:ax0,ax1,ax2,ax3, cooperative_process:0)
{
pad_temp_0_shared_temp_buffer[v0, v1, v2, v3] = pad_temp_0[v0, v1, v2, v3]
}
}
}
}
serial for (i_1, 0, 1)
{
serial for (j_1, 0, 4)
serial for (i_1, 0, 1)
{
serial for (k_1, 0, 1)
serial for (j_1, 0, 4)
{
serial for (a_1, 0, 4)
serial for (k_1, 0, 1)
{
ScheduleBlock(var_0_local_temp_buffer)
serial for (a_1, 0, 4)
{
i0_1, i1_1, i2_1, i3_1, i4, i5 = axis.bind(((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((i_0_j_0_k_0_a_0_fused % 4) + ((4 * ((i_j_k_a_fused / 2) % 2)) + k_1)), ((4 * (i_j_k_a_fused % 2)) + a_1), kernel_idx, kernel_idx_0)
read_buffers(_var_0[i(undefined:undefined), j(undefined:undefined), k(undefined:undefined), a(undefined:undefined)], _pad_temp_0[i(undefined:undefined), j(undefined:undefined)])
write_buffers(_var_0[i(undefined:undefined), j(undefined:undefined), k(undefined:undefined), a(undefined:undefined)])
ScheduleBlock(var_0_local_temp_buffer)
{
var_0_local_temp_buffer[((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((4 * ((i_j_k_a_fused / 2) % 2)) + ((i_0_j_0_k_0_a_0_fused % 4) + k_1)), ((4 * (i_j_k_a_fused % 2)) + a_1)] = cinn_max(var_0_local_temp_buffer[((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((i_0_j_0_k_0_a_0_fused % 4) + ((4 * ((i_j_k_a_fused / 2) % 2)) + k_1)), ((4 * (i_j_k_a_fused % 2)) + a_1)], pad_temp_0_shared_temp_buffer[((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((8 * ((i_j_k_a_fused / 2) % 2)) + ((2 * (i_0_j_0_k_0_a_0_fused % 4)) + ((2 * k_1) + kernel_idx))), ((8 * (i_j_k_a_fused % 2)) + ((2 * a_1) + kernel_idx_0))])
i0_1, i1_1, i2_1, i3_1, i4, i5 = axis.bind(((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((i_0_j_0_k_0_a_0_fused % 4) + ((4 * ((i_j_k_a_fused / 2) % 2)) + k_1)), ((4 * (i_j_k_a_fused % 2)) + a_1), kernel_idx, kernel_idx_0)
read_buffers(_var_0[i(undefined:undefined), j(undefined:undefined), k(undefined:undefined), a(undefined:undefined)], _pad_temp_0[i(undefined:undefined), j(undefined:undefined)])
write_buffers(_var_0[i(undefined:undefined), j(undefined:undefined), k(undefined:undefined), a(undefined:undefined)])
{
var_0_local_temp_buffer[((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((4 * ((i_j_k_a_fused / 2) % 2)) + ((i_0_j_0_k_0_a_0_fused % 4) + k_1)), ((4 * (i_j_k_a_fused % 2)) + a_1)] = cinn_max(var_0_local_temp_buffer[((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((i_0_j_0_k_0_a_0_fused % 4) + ((4 * ((i_j_k_a_fused / 2) % 2)) + k_1)), ((4 * (i_j_k_a_fused % 2)) + a_1)], pad_temp_0_shared_temp_buffer[((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((8 * ((i_j_k_a_fused / 2) % 2)) + ((2 * (i_0_j_0_k_0_a_0_fused % 4)) + ((2 * k_1) + kernel_idx))), ((8 * (i_j_k_a_fused % 2)) + ((2 * a_1) + kernel_idx_0))])
}
}
}
}
}
}
}
}
}
serial for (ax0_0, 0, 1)
{
serial for (ax1_0, 0, 4)
serial for (ax0_0, 0, 1)
{
serial for (ax2_0, 0, 1)
serial for (ax1_0, 0, 4)
{
serial for (ax3_0, 0, 4)
serial for (ax2_0, 0, 1)
{
ScheduleBlock(var_0)
serial for (ax3_0, 0, 4)
{
v0, v1, v2, v3 = axis.bind((((((i_j_k_a_fused / 2) / 2) / 2) + (i_0_j_0_k_0_a_0_fused / 4)) + ax0_0), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + ax1_0), (((4 * ((i_j_k_a_fused / 2) % 2)) + (i_0_j_0_k_0_a_0_fused % 4)) + ax2_0), ((4 * (i_j_k_a_fused % 2)) + ax3_0))
attrs(reverse_compute_at_extra_var:ax0_0,ax1_0,ax2_0,ax3_0)
ScheduleBlock(var_0)
{
var_0[v0, v1, v2, v3] = var_0_local_temp_buffer[v0, v1, v2, v3]
v0, v1, v2, v3 = axis.bind((((((i_j_k_a_fused / 2) / 2) / 2) + (i_0_j_0_k_0_a_0_fused / 4)) + ax0_0), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + ax1_0), (((4 * ((i_j_k_a_fused / 2) % 2)) + (i_0_j_0_k_0_a_0_fused % 4)) + ax2_0), ((4 * (i_j_k_a_fused % 2)) + ax3_0))
attrs(reverse_compute_at_extra_var:ax0_0,ax1_0,ax2_0,ax3_0)
{
var_0[v0, v1, v2, v3] = var_0_local_temp_buffer[v0, v1, v2, v3]
}
}
}
}
@@ -553,7 +550,7 @@ Expr 1 {
}
}
}
} // end Expr 1
} // end Expr 0
)ROC";
ASSERT_EQ(ir, expected_ir);

@@ -569,8 +566,8 @@ Expr 1 {
pool2d_program, fixed_rand_seed, /* apply_manual_schedule*/ true))),
default_input_names,
default_output_names,
-      {input_shape},
-      {output_shape},
+      input_shapes,
+      output_shapes,
target_);
}
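The new shapes can be sanity-checked with the usual pooling arithmetic (a standalone sketch, not part of the PR): the select bounds in the expected IR (k >= 1, k < 17, and likewise for a) imply one row/column of padding per side, which reproduces both the {2, 8, 18, 18} pad_temp_0 shape and the {2, 8, 8, 8} pooling output.

#include <cstdio>

// Standalone check of the Pool2d shapes asserted above: a 16x16 spatial input
// with padding 1 per side gives the 18x18 pad_temp_0 buffer; max pooling with
// kernel 3 and stride 2 over the padded input gives the 8x8 output.
int main() {
  const int in = 16, pad = 1, ksize = 3, stride = 2;
  const int padded = in + 2 * pad;                // 18
  const int out = (padded - ksize) / stride + 1;  // 8
  std::printf("padded=%d, out=%d\n", padded, out);
  return 0;
}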

@@ -63,12 +63,10 @@ ir::IRSchedule TestAutoGenRuleBase::MakeIRSchedule(
absl::flat_hash_map<std::string, hlir::framework::shape_t>>("infershape");
hlir::framework::OpLowerer op_lowerer(dtype_dict, shape_dict, target_);

-  if (apply_manual_schedule) {
-    lowered_funcs_ = op_lowerer.Lower(graph->fusion_groups.front());
-  } else {
-    lowered_funcs_ =
-        op_lowerer.LowerWithoutSchedule(graph->fusion_groups.front());
-  }
+  lowered_funcs_ =
+      op_lowerer.Lower(graph->fusion_groups.front(),
+                       /*apply_op_schedule = */ apply_manual_schedule,
+                       /*apply_group_schedule = */ apply_manual_schedule);
CHECK(!lowered_funcs_.empty()) << "lowered_funcs_ is empty";

std::vector<Expr> bodys;
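Note that this is the one call site where both schedule toggles are driven by a single flag: apply_manual_schedule set to true reproduces the old Lower() path, and false the old LowerWithoutSchedule() path, which is what lets the removed if/else collapse into a single call.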
paddle/cinn/auto_schedule/task/tune_task.cc (2 additions, 1 deletion)
@@ -39,7 +39,8 @@ void TuneTask::Initialize(
op_lowerer = lower_handler;

// Set lowered_funcs and analyze output names.
-  this->lowered_funcs = op_lowerer->LowerWithoutSchedule(subgraph);
+  this->lowered_funcs = op_lowerer->Lower(
+      subgraph, /*apply_op_schedule = */ false, /*apply_group_schedule=*/false);
this->output_names = GetOutputNamesFromLoweredFunc(this->lowered_funcs);
this->serialized_key = SerializeToString(shape_dict, dtype_dict);
}
@@ -157,7 +157,9 @@ class PerformanceTester : public ::testing::Test {

for (auto group : graph->fusion_groups) {
compile_options.lowered_funcs.push_back(
-        op_lowerer->LowerWithoutSchedule(group));
+        op_lowerer->Lower(group,
+                          /*apply_op_schedule = */ false,
+                          /*apply_group_schedule=*/false));
}

VLOG(3) << "===========================No Schedule LoweredFunc "