Skip to content

Commit 7fd0bd7

Browse files
committed
Update on "Add OpInfo test to check that floating point inputs in OpInfos have requires_grad set to True"
This test detected a number of sampling methods that were not generating the samples as expected, e.g. `index_put`, `cosine_embedding`, `stft`, but perhaps most notably the generator for `BinOps`. It also detected that `remainder` and `fmod` did not have the backward formula for the second input implemented. I added this in the previous PR. [ghstack-poisoned]
2 parents bebda4a + c13e6bf commit 7fd0bd7

File tree

5 files changed

+42
-27
lines changed

5 files changed

+42
-27
lines changed

c10/macros/Macros.h

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -484,12 +484,13 @@ __host__ __device__
484484
#endif // HAS_DEMANGLE
485485

486486
#ifdef __clang__
487-
#define _C10_PRAGMA__(string) _Pragma( #string )
488-
#define _C10_PRAGMA_(string) _C10_PRAGMA__( string )
487+
#define _C10_PRAGMA__(string) _Pragma(#string)
488+
#define _C10_PRAGMA_(string) _C10_PRAGMA__(string)
489489
#define C10_CLANG_DIAGNOSTIC_PUSH() _Pragma("clang diagnostic push")
490490
#define C10_CLANG_DIAGNOSTIC_POP() _Pragma("clang diagnostic pop")
491-
#define C10_CLANG_DIAGNOSTIC_IGNORE(flag) _C10_PRAGMA_(clang diagnostic ignored flag)
492-
#define C10_CLANG_HAS_WARNING(flag) __has_warning( flag )
491+
#define C10_CLANG_DIAGNOSTIC_IGNORE(flag) \
492+
_C10_PRAGMA_(clang diagnostic ignored flag)
493+
#define C10_CLANG_HAS_WARNING(flag) __has_warning(flag)
493494
#else
494495
#define C10_CLANG_DIAGNOSTIC_PUSH()
495496
#define C10_CLANG_DIAGNOSTIC_POP()

test/cpp/jit/test_misc.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2870,7 +2870,7 @@ TEST_F(Composed, ComposedOp) {
28702870
bool fusable_on_device = torch::jit::tensorexpr::getTEMustUseLLVMOnCPU();
28712871
torch::jit::tensorexpr::getTEMustUseLLVMOnCPU() = false;
28722872
setTensorExprDynamicShapeFusionEnabled(true);
2873-
FuseTensorExprs(graph, /*min_group_size*/2, /*add_composed_op*/true);
2873+
FuseTensorExprs(graph, /*min_group_size*/ 2, /*add_composed_op*/ true);
28742874
Code code(graph, "");
28752875
InterpreterState interpreter{code};
28762876
std::vector<IValue> stack = {a, b};

test/cpp/tensorexpr/test_dynamic_shapes.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -318,7 +318,6 @@ TEST(DynamicShapes, GraphWithPartiallySymbolicOutput) {
318318
symbolic_strides[y_inp] = input_desc;
319319
symbolic_strides[graph->outputs().at(0)] = input_desc;
320320

321-
322321
TensorExprKernel kernel(
323322
graph, {}, symbolic_shape_inputs, false, symbolic_strides);
324323

@@ -443,7 +442,6 @@ TEST(DynamicShapes, GraphWithCatAndBroadcast) {
443442
symbolic_strides[z_inp] = input_desc;
444443
symbolic_strides[graph->outputs().at(0)] = input_desc;
445444

446-
447445
TensorExprKernel kernel(
448446
graph, {}, symbolic_shape_inputs, false, symbolic_strides);
449447

torch/csrc/jit/passes/tensorexpr_fuser.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include <torch/csrc/jit/passes/tensorexpr_fuser.h>
22

3+
#include <ATen/core/interned_strings.h>
34
#include <ATen/core/symbol.h>
45
#include <ATen/record_function.h>
56
#include <c10/util/FunctionRef.h>
@@ -22,7 +23,6 @@
2223
#include <torch/csrc/jit/runtime/symbolic_shape_registry_util.h>
2324
#include <torch/csrc/jit/tensorexpr/kernel.h>
2425
#include <torch/csrc/utils/memory.h>
25-
#include <ATen/core/interned_strings.h>
2626

2727
// NOLINTNEXTLINE
2828
C10_DEFINE_bool(
@@ -1284,9 +1284,11 @@ Operation createTensorExprOp(const Node* node) {
12841284
stride_map[v] = striding_inputs[index];
12851285
index++;
12861286
}
1287-
std::vector<std::string> output_desc = node->ival(attr::striding_outputs_desc).to<std::vector<std::string>>();
1287+
std::vector<std::string> output_desc =
1288+
node->ival(attr::striding_outputs_desc).to<std::vector<std::string>>();
12881289
for (size_t i = 0; i < subgraph->outputs().size(); ++i) {
1289-
stride_map[subgraph->outputs().at(i)] = {strideInputFromString(output_desc.at(i))};
1290+
stride_map[subgraph->outputs().at(i)] = {
1291+
strideInputFromString(output_desc.at(i))};
12901292
}
12911293

12921294
std::shared_ptr<tensorexpr::TensorExprKernel> kernel =

torch/csrc/jit/tensorexpr/kernel.cpp

Lines changed: 31 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1147,7 +1147,8 @@ Tensor TensorExprKernel::convertOutputToCorrectStrides(
11471147
ExprHandle axis = axes[i];
11481148
absolute_position = absolute_position + (stride * axis);
11491149
}
1150-
std::vector<ExprHandle> new_axes(sorted_stride_indices_descending.size());
1150+
std::vector<ExprHandle> new_axes(
1151+
sorted_stride_indices_descending.size());
11511152
for (size_t stride_index : sorted_stride_indices_descending) {
11521153
auto size = sizes[stride_index];
11531154
auto stride = strides[stride_index];
@@ -1156,25 +1157,31 @@ Tensor TensorExprKernel::convertOutputToCorrectStrides(
11561157
// if the size is one, we don't advance the absolute position
11571158
// which would give 0
11581159
auto non_one_position = absolute_position % ExprHandle(stride);
1159-
absolute_position = CompareSelect::make(size, one, absolute_position, non_one_position, kEQ);
1160+
absolute_position = CompareSelect::make(
1161+
size, one, absolute_position, non_one_position, kEQ);
11601162
new_axes[stride_index] = index;
11611163
}
11621164
return BufHandle(buf).load(new_axes);
11631165
});
11641166
}
11651167

1166-
Tensor TensorExprKernel::convertSymbolicOutputToCorrectStrides(torch::jit::Value* v) {
1168+
Tensor TensorExprKernel::convertSymbolicOutputToCorrectStrides(
1169+
torch::jit::Value* v) {
11671170
const TensorTypePtr& tt = v->type()->expect<TensorType>();
11681171
TORCH_INTERNAL_ASSERT(
11691172
bufs_.count(v),
11701173
buildErrorMessage(
11711174
"Ouput tensor has no corresponding bufs in the fuser."));
11721175
BufPtr buf = bufs_.at(v);
11731176
// output is contiguous, no work to do
1174-
if (tensorOutputStrideDesc_[v->offset()] == torch::jit::StrideInput::TENSOR_CONT) {
1175-
return Tensor(buf, nullptr);;
1177+
if (tensorOutputStrideDesc_[v->offset()] ==
1178+
torch::jit::StrideInput::TENSOR_CONT) {
1179+
return Tensor(buf, nullptr);
1180+
;
11761181
}
1177-
TORCH_INTERNAL_ASSERT(tensorOutputStrideDesc_[v->offset()] == torch::jit::StrideInput::TENSOR_CONT_CHANNELS_LAST);
1182+
TORCH_INTERNAL_ASSERT(
1183+
tensorOutputStrideDesc_[v->offset()] ==
1184+
torch::jit::StrideInput::TENSOR_CONT_CHANNELS_LAST);
11781185
auto sizes = sizesFromSymbolicShape(tt->symbolic_sizes());
11791186
auto dims = c10::fmap<DimArg>(sizes);
11801187
auto strides = make_channels_last_strides(sizes);
@@ -1185,11 +1192,12 @@ Tensor TensorExprKernel::convertSymbolicOutputToCorrectStrides(torch::jit::Value
11851192
auto zero = LongImm::make(0);
11861193
std::vector<ExprPtr> default_strides = make_contiguous_strides(sizes);
11871194
// See explanation in convertOutputToCorrectStrides
1188-
return convertOutputToCorrectStrides(sizes, sorted_stride_indices, strides, buf);
1195+
return convertOutputToCorrectStrides(
1196+
sizes, sorted_stride_indices, strides, buf);
11891197
}
11901198

1191-
1192-
Tensor TensorExprKernel::convertStaticShapeOutputToCorrectStrides(torch::jit::Value* v) {
1199+
Tensor TensorExprKernel::convertStaticShapeOutputToCorrectStrides(
1200+
torch::jit::Value* v) {
11931201
const TensorTypePtr& tt = v->type()->expect<TensorType>();
11941202
TORCH_INTERNAL_ASSERT(
11951203
bufs_.count(v),
@@ -1231,9 +1239,9 @@ Tensor TensorExprKernel::convertStaticShapeOutputToCorrectStrides(torch::jit::Va
12311239
auto zero = LongImm::make(0);
12321240
std::vector<size_t> sorted_stride_indices = reverse_sort_indices(strides);
12331241

1234-
// TODO: call into `convertOutputToCorrectStrides`. Currently this causes a bug
1235-
// in IRSimplifier to occur.
1236-
// See explanation in `convertOutputToCorrectStrides`
1242+
// TODO: call into `convertOutputToCorrectStrides`. Currently this causes a
1243+
// bug in IRSimplifier to occur. See explanation in
1244+
// `convertOutputToCorrectStrides`
12371245
return Compute(
12381246
"output_1", dims, [&](const std::vector<VarHandle>& axes_input) {
12391247
std::vector<ExprHandle> axes(axes_input.begin(), axes_input.end());
@@ -1467,7 +1475,8 @@ void TensorExprKernel::compile() {
14671475
auto stride_desc = symbolic_strides_[output];
14681476
TORCH_INTERNAL_ASSERT(stride_desc.size() == 1);
14691477
tensorOutputStrideDesc_.push_back(stride_desc[0]);
1470-
Tensor properly_strided_output = convertSymbolicOutputToCorrectStrides(output);
1478+
Tensor properly_strided_output =
1479+
convertSymbolicOutputToCorrectStrides(output);
14711480
if (properly_strided_output.stmt()) {
14721481
block->append_stmt(properly_strided_output.stmt());
14731482
}
@@ -1476,7 +1485,8 @@ void TensorExprKernel::compile() {
14761485
// The "strided" tensor will be incorrect if used in NNC,
14771486
// since NNC views it as contiguous. Only convert it to the right
14781487
// strides at the end of the kernel (if already contiguous it's a no-op)
1479-
Tensor properly_strided_output = convertStaticShapeOutputToCorrectStrides(output);
1488+
Tensor properly_strided_output =
1489+
convertStaticShapeOutputToCorrectStrides(output);
14801490
if (properly_strided_output.stmt()) {
14811491
block->append_stmt(properly_strided_output.stmt());
14821492
}
@@ -1601,9 +1611,13 @@ void TensorExprKernel::updateOutputSizesAndStrides(
16011611
}
16021612

16031613
if (tensorOutputStrideDesc_[i] == torch::jit::StrideInput::TENSOR_CONT) {
1604-
tensorOutputStrides_[i] = TensorType::contiguousStridesOf(tensorOutputSizes_[i]);
1605-
} else if (tensorOutputStrideDesc_[i] == torch::jit::StrideInput::TENSOR_CONT_CHANNELS_LAST) {
1606-
tensorOutputStrides_[i] = at::get_channels_last_strides_2d(tensorOutputSizes_[i]);
1614+
tensorOutputStrides_[i] =
1615+
TensorType::contiguousStridesOf(tensorOutputSizes_[i]);
1616+
} else if (
1617+
tensorOutputStrideDesc_[i] ==
1618+
torch::jit::StrideInput::TENSOR_CONT_CHANNELS_LAST) {
1619+
tensorOutputStrides_[i] =
1620+
at::get_channels_last_strides_2d(tensorOutputSizes_[i]);
16071621
} else {
16081622
std::string output_desc = toString(tensorOutputStrideDesc_[i]);
16091623
TORCH_INTERNAL_ASSERT(

0 commit comments

Comments
 (0)