@@ -1147,7 +1147,8 @@ Tensor TensorExprKernel::convertOutputToCorrectStrides(
11471147 ExprHandle axis = axes[i];
11481148 absolute_position = absolute_position + (stride * axis);
11491149 }
1150- std::vector<ExprHandle> new_axes (sorted_stride_indices_descending.size ());
1150+ std::vector<ExprHandle> new_axes (
1151+ sorted_stride_indices_descending.size ());
11511152 for (size_t stride_index : sorted_stride_indices_descending) {
11521153 auto size = sizes[stride_index];
11531154 auto stride = strides[stride_index];
@@ -1156,25 +1157,31 @@ Tensor TensorExprKernel::convertOutputToCorrectStrides(
11561157 // if the size is one, we don't advance the absolute position
11571158 // which would give 0
11581159 auto non_one_position = absolute_position % ExprHandle (stride);
1159- absolute_position = CompareSelect::make (size, one, absolute_position, non_one_position, kEQ );
1160+ absolute_position = CompareSelect::make (
1161+ size, one, absolute_position, non_one_position, kEQ );
11601162 new_axes[stride_index] = index;
11611163 }
11621164 return BufHandle (buf).load (new_axes);
11631165 });
11641166}
11651167
1166- Tensor TensorExprKernel::convertSymbolicOutputToCorrectStrides (torch::jit::Value* v) {
1168+ Tensor TensorExprKernel::convertSymbolicOutputToCorrectStrides (
1169+ torch::jit::Value* v) {
11671170 const TensorTypePtr& tt = v->type ()->expect <TensorType>();
11681171 TORCH_INTERNAL_ASSERT (
11691172 bufs_.count (v),
11701173 buildErrorMessage (
11711174 " Ouput tensor has no corresponding bufs in the fuser." ));
11721175 BufPtr buf = bufs_.at (v);
11731176 // output is contiguous, no work to do
1174- if (tensorOutputStrideDesc_[v->offset ()] == torch::jit::StrideInput::TENSOR_CONT) {
1175- return Tensor (buf, nullptr );;
1177+ if (tensorOutputStrideDesc_[v->offset ()] ==
1178+ torch::jit::StrideInput::TENSOR_CONT) {
1179+ return Tensor (buf, nullptr );
1180+ ;
11761181 }
1177- TORCH_INTERNAL_ASSERT (tensorOutputStrideDesc_[v->offset ()] == torch::jit::StrideInput::TENSOR_CONT_CHANNELS_LAST);
1182+ TORCH_INTERNAL_ASSERT (
1183+ tensorOutputStrideDesc_[v->offset ()] ==
1184+ torch::jit::StrideInput::TENSOR_CONT_CHANNELS_LAST);
11781185 auto sizes = sizesFromSymbolicShape (tt->symbolic_sizes ());
11791186 auto dims = c10::fmap<DimArg>(sizes);
11801187 auto strides = make_channels_last_strides (sizes);
@@ -1185,11 +1192,12 @@ Tensor TensorExprKernel::convertSymbolicOutputToCorrectStrides(torch::jit::Value
11851192 auto zero = LongImm::make (0 );
11861193 std::vector<ExprPtr> default_strides = make_contiguous_strides (sizes);
11871194 // See explanation in convertOutputToCorrectStrides
1188- return convertOutputToCorrectStrides (sizes, sorted_stride_indices, strides, buf);
1195+ return convertOutputToCorrectStrides (
1196+ sizes, sorted_stride_indices, strides, buf);
11891197}
11901198
1191-
1192- Tensor TensorExprKernel::convertStaticShapeOutputToCorrectStrides ( torch::jit::Value* v) {
1199+ Tensor TensorExprKernel::convertStaticShapeOutputToCorrectStrides (
1200+ torch::jit::Value* v) {
11931201 const TensorTypePtr& tt = v->type ()->expect <TensorType>();
11941202 TORCH_INTERNAL_ASSERT (
11951203 bufs_.count (v),
@@ -1231,9 +1239,9 @@ Tensor TensorExprKernel::convertStaticShapeOutputToCorrectStrides(torch::jit::Va
12311239 auto zero = LongImm::make (0 );
12321240 std::vector<size_t > sorted_stride_indices = reverse_sort_indices (strides);
12331241
1234- // TODO: call into `convertOutputToCorrectStrides`. Currently this causes a bug
1235- // in IRSimplifier to occur.
1236- // See explanation in `convertOutputToCorrectStrides`
1242+ // TODO: call into `convertOutputToCorrectStrides`. Doing so currently
1243+ // triggers a bug in IRSimplifier. See the explanation in
1244+ // `convertOutputToCorrectStrides`.
12371245 return Compute (
12381246 " output_1" , dims, [&](const std::vector<VarHandle>& axes_input) {
12391247 std::vector<ExprHandle> axes (axes_input.begin (), axes_input.end ());
@@ -1467,7 +1475,8 @@ void TensorExprKernel::compile() {
14671475 auto stride_desc = symbolic_strides_[output];
14681476 TORCH_INTERNAL_ASSERT (stride_desc.size () == 1 );
14691477 tensorOutputStrideDesc_.push_back (stride_desc[0 ]);
1470- Tensor properly_strided_output = convertSymbolicOutputToCorrectStrides (output);
1478+ Tensor properly_strided_output =
1479+ convertSymbolicOutputToCorrectStrides (output);
14711480 if (properly_strided_output.stmt ()) {
14721481 block->append_stmt (properly_strided_output.stmt ());
14731482 }
@@ -1476,7 +1485,8 @@ void TensorExprKernel::compile() {
14761485 // The "strided" tensor will be incorrect if used in NNC,
14771486 // since NNC views it as contiguous. Only convert it to the right
14781487 // strides at the end of the kernel (if already contiguous it's a no-op)
1479- Tensor properly_strided_output = convertStaticShapeOutputToCorrectStrides (output);
1488+ Tensor properly_strided_output =
1489+ convertStaticShapeOutputToCorrectStrides (output);
14801490 if (properly_strided_output.stmt ()) {
14811491 block->append_stmt (properly_strided_output.stmt ());
14821492 }
@@ -1601,9 +1611,13 @@ void TensorExprKernel::updateOutputSizesAndStrides(
16011611 }
16021612
16031613 if (tensorOutputStrideDesc_[i] == torch::jit::StrideInput::TENSOR_CONT) {
1604- tensorOutputStrides_[i] = TensorType::contiguousStridesOf (tensorOutputSizes_[i]);
1605- } else if (tensorOutputStrideDesc_[i] == torch::jit::StrideInput::TENSOR_CONT_CHANNELS_LAST) {
1606- tensorOutputStrides_[i] = at::get_channels_last_strides_2d (tensorOutputSizes_[i]);
1614+ tensorOutputStrides_[i] =
1615+ TensorType::contiguousStridesOf (tensorOutputSizes_[i]);
1616+ } else if (
1617+ tensorOutputStrideDesc_[i] ==
1618+ torch::jit::StrideInput::TENSOR_CONT_CHANNELS_LAST) {
1619+ tensorOutputStrides_[i] =
1620+ at::get_channels_last_strides_2d (tensorOutputSizes_[i]);
16071621 } else {
16081622 std::string output_desc = toString (tensorOutputStrideDesc_[i]);
16091623 TORCH_INTERNAL_ASSERT (
0 commit comments