Skip to content

Commit

Permalink
fix pool bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
fsx950223 committed Jan 3, 2024
1 parent 1829cbb commit 8e2293e
Showing 1 changed file with 19 additions and 12 deletions.
31 changes: 19 additions & 12 deletions python/aitemplate/backend/rocm/pool2d/pool2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@

INSTANCE_TEMPLATE = jinja2.Template(
"""
using {{name}} = ck::tensor_operation::device::DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C<
ck::half_t, ck::half_t, float, {{reduce_func}}, false, 64, 64, 1, 4, 1, 4>;
using {{name}} = ck::tensor_operation::device::DevicePool2dFwd_NHWC_NHWC<
ck::half_t, ck::half_t, ck::index_t, float, {{reduce_func}}, false, 64, 64, 1, 4, 1, 4>;
"""
)

Expand All @@ -35,14 +35,17 @@
{{indent}}auto argument_ptr = op.MakeArgumentPointer(static_cast<ck::half_t *>(in_ptr),
{{indent}} static_cast<ck::half_t *>(out_ptr),
{{indent}} nullptr,
{{indent}} *batch,
{{indent}} *in_ch,
{{indent}} input_shape,
{{indent}} kernel_shape,
{{indent}} output_shape,
{{indent}} input_stride,
{{indent}} output_stride,
{{indent}} indices_stride,
{{indent}} conv_filter_strides,
{{indent}} dilations,
{{indent}} input_left_pads,
{{indent}} input_right_pads);
{{indent}} input_right_pads,
{{indent}} {2, 3});
{{indent}}if(!op.IsSupportedArgument(argument_ptr.get())) {
{{indent}} LOG(FATAL) << "wrong! " << op.GetTypeString() << " with the specified compilation parameters does not support this Pool problem.";
{{indent}}}
Expand Down Expand Up @@ -87,19 +90,23 @@
) {
{{shape_function}}
const std::array<ck::index_t, 2> conv_filter_strides{static_cast<ck::index_t>(stride),
const std::vector<ck::index_t> conv_filter_strides{static_cast<ck::index_t>(stride),
static_cast<ck::index_t>(stride)};
const std::array<ck::index_t, 2> input_left_pads{static_cast<ck::index_t>(pad),
const std::vector<ck::index_t> input_left_pads{static_cast<ck::index_t>(pad),
static_cast<ck::index_t>(pad)};
const std::array<ck::index_t, 2> input_right_pads{static_cast<ck::index_t>(pad),
const std::vector<ck::index_t> input_right_pads{static_cast<ck::index_t>(pad),
static_cast<ck::index_t>(pad)};
const std::array<ck::index_t, 2> input_shape{static_cast<ck::index_t>(*in_h),
const std::vector<ck::index_t> input_shape{static_cast<ck::index_t>(*batch), static_cast<ck::index_t>(*in_ch), static_cast<ck::index_t>(*in_h),
static_cast<ck::index_t>(*in_w)};
const std::array<ck::index_t, 2> kernel_shape{static_cast<ck::index_t>(kernel_h),
static_cast<ck::index_t>(kernel_w)};
const std::array<ck::index_t, 2> output_shape{static_cast<ck::index_t>(*out_h),
const std::vector<ck::index_t> kernel_shape{static_cast<ck::index_t>(kernel_h), static_cast<ck::index_t>(kernel_w)};
const std::vector<ck::index_t> output_shape{static_cast<ck::index_t>(*batch), static_cast<ck::index_t>(*in_ch), static_cast<ck::index_t>(*out_h),
static_cast<ck::index_t>(*out_w)};
const std::vector<ck::index_t> input_stride{static_cast<ck::index_t>(CI*HI*WI), 1, static_cast<ck::index_t>(WI*CI), static_cast<ck::index_t>(CI)};
const std::vector<ck::index_t> output_stride{static_cast<ck::index_t>(CI*HO*WO), 1, static_cast<ck::index_t>(WO*CI), static_cast<ck::index_t>(CI)};
const std::vector<ck::index_t> indices_stride{static_cast<ck::index_t>(CI*HO*WO), 1, static_cast<ck::index_t>(WO*CI), static_cast<ck::index_t>(CI)};
const std::vector<ck::index_t> dilations{1, 1};
{{exec_paths}}
throw std::runtime_error(
Expand Down

0 comments on commit 8e2293e

Please sign in to comment.