Skip to content

Commit

Permalink
add set_grad_enabled to TorchScript and fix data attribute
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: pytorch#25350

Test Plan: Imported from OSS

Differential Revision: D17100829

fbshipit-source-id: d85d6f3b03218b9c77e144365940eeaa5b4cce9a
  • Loading branch information
Wanchao Liang authored and facebook-github-bot committed Sep 10, 2019
1 parent 387d5a4 commit a7eaec6
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 4 deletions.
14 changes: 12 additions & 2 deletions test/test_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -3861,12 +3861,22 @@ def f_grad(x):
self.checkScript(f_grad, (y,))

def test_tensor_data(self):
x = torch.randn(3, 4)
x = torch.randn(3, 4, requires_grad=True)
y = torch.randn(4, 5)

def f_data(x):
return x.data

self.checkScript(f_data, (x,))
scripted_f_data = torch.jit.script(f_data)

scripted_x = scripted_f_data(x)
self.assertEqual(scripted_x, f_data(x))
self.assertEqual(scripted_x.requires_grad, False)

scripted_y = scripted_f_data(y)
self.assertEqual(scripted_y, f_data(y))
self.assertEqual(scripted_x.requires_grad, False)


def test_tensor_dtype(self):
x_byte = torch.empty(34, 56, 78, dtype=torch.uint8)
Expand Down
9 changes: 7 additions & 2 deletions torch/csrc/jit/register_prim_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -584,8 +584,13 @@ RegisterOperators reg(
},
aliasAnalysisFromSchema()),
Operator(
"prim::data(Tensor(b) a) -> Tensor(b)",
noop,
"prim::data(Tensor(a) a) -> Tensor(a)",
[](Stack& stack) {
at::Tensor a;
pop(stack, a);
push(stack, autograd::Variable(a).variable_data());
return 0;
},
aliasAnalysisFromSchema()),
Operator(
"prim::is_cuda(Tensor a) -> bool",
Expand Down

0 comments on commit a7eaec6

Please sign in to comment.