Skip to content

Commit 3bf93d7

Browse files
Elias Ellisonfacebook-github-bot
authored andcommitted
[JIT] Add gradient check in constants (pytorch#64613)
Summary: fixes internal issue Pull Request resolved: pytorch#64613 Reviewed By: Gamrix Differential Revision: D30799016 Pulled By: eellison fbshipit-source-id: 48ef52d1cac627919e6cd232216d24878a2a8b58
1 parent d4b1016 commit 3bf93d7

File tree

2 files changed

+11
-3
lines changed

2 files changed

+11
-3
lines changed

test/cpp/jit/test_misc.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2685,6 +2685,13 @@ TEST(ComputeFlopsTest, Basic) {
26852685
ASSERT_EQ(flops, 360);
26862686
}
26872687

2688+
TEST(TestConstant, TensorGrad) {
2689+
auto graph = std::make_shared<Graph>();
2690+
IValue ten = torch::randn({3, 5}).requires_grad_(true);
2691+
auto con = tryInsertConstant(*graph, ten);
2692+
ASSERT_TRUE(con == c10::nullopt);
2693+
}
2694+
26882695
TEST(TestMutation, Basic) {
26892696
auto graph = std::make_shared<Graph>();
26902697
std::unordered_map<std::string, Value*> vmap;

torch/csrc/jit/ir/constants.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@ namespace torch {
1010
namespace jit {
1111

1212
bool insertableTensor(const at::Tensor& ten) {
13-
return !ten.requires_grad();
13+
// bail if tensor has no storage i.e. opaque tensor used in MKLdnn.
14+
// or gradients because we have no way of serializing them & are mutable
15+
return !ten.requires_grad() && ten.has_storage();
1416
}
1517

1618
bool insertableIValue(const IValue& ivalue) {
@@ -65,8 +67,7 @@ c10::optional<Value*> tryInsertConstant(
6567
Node* n = g.create(prim::Constant);
6668
if (val.isTensor()) {
6769
at::Tensor ref = val.toTensor();
68-
if (!ref.has_storage()) {
69-
// bail if tensor has no storage i.e. opaque tensor used in MKLdnn.
70+
if (!insertableTensor(val.toTensor())) {
7071
n->destroy();
7172
return c10::nullopt;
7273
}

0 commit comments

Comments
 (0)