trace s_copy_ (pytorch#15690)
Summary:
s_copy_ was previously special-cased for out-of-place tracing.
This adds support for in-place tracing, which fixes tracing of
inception_v3.

Fixes pytorch#15216
Pull Request resolved: pytorch#15690

Differential Revision: D13572011

Pulled By: zdevito

fbshipit-source-id: 1d565dec039a4b8c59179254285e61d2517ef9a9
zdevito authored and facebook-github-bot committed Jan 3, 2019
1 parent 78442f0 commit d42e909
Showing 6 changed files with 33 additions and 17 deletions.
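
To see what this enables, here is a minimal sketch of the user-visible fix, modeled on the test_trace_indexed_assignment test added below (assign_row and the example tensors are illustrative names, not part of the change): an indexed assignment dispatches to the in-place copy_, which the tracer can now record directly instead of rewriting it as an out-of-place expand_as.

import torch

def assign_row(x, y):
    x = x.clone()
    x[0] = y  # indexed assignment lowers to an in-place copy_ (s_copy_)
    return x

x = torch.rand(3, 4)
y = x[0] + 1
traced = torch.jit.trace(assign_row, (x, y))
print(traced.graph)  # the trace now contains aten::copy_ rather than aten::expand_as
assert torch.equal(traced(x, y), assign_row(x, y))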
test/expect/TestJit.test_inplace_copy.expect (1 addition, 1 deletion)
@@ -12,6 +12,6 @@ graph(%0 : Double(4, 4)) {
   %11 : int = prim::Constant[value=0]()
   %12 : Device = prim::Constant[value="cpu"]()
   %13 : Double(4, 4) = aten::zeros(%9, %10, %11, %12)
-  %14 : Double(4, 4) = aten::expand_as(%0, %13)
+  %14 : Double(4, 4) = aten::copy_(%13, %0)
   return (%14);
 }
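
For context, a hedged reconstruction of the function behind this expect file, inferred only from the traced graph (zeros allocated, then the input copied into them in place); the actual test body in test_jit.py may differ:

import torch

def inplace_copy(x):
    out = torch.zeros(4, 4, dtype=torch.double)  # corresponds to the aten::zeros node
    out.copy_(x)  # in-place copy, now traced as aten::copy_(%13, %0)
    return out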
test/test_jit.py (9 additions, 3 deletions)
@@ -1160,6 +1160,14 @@ def fn(x):
         if RUN_CUDA_MULTI_GPU:
             run(device="cuda:1")
 
+    def test_trace_indexed_assignment(self):
+        def stuff(x, y):
+            x = x.clone()
+            x[0] = y
+            return x
+        example = torch.rand(3, 4)
+        self.checkTrace(stuff, (example, example[0] + 1))
+
     # TODO: implement
     @unittest.expectedFailure
     def test_output_unflatten(self):
@@ -8098,9 +8106,7 @@ def foo(x):
             x[i, :] = torch.zeros(4)
             return x
 
-        self.assertWarnsRegex(lambda: torch.jit.trace(foo, torch.rand(3, 4), check_inputs=[torch.rand(3, 4)]),
-                              'Output nr 1. of the traced function does not match the '
-                              'corresponding output of the Python function')
+        self.checkTrace(foo, (torch.rand(3, 4),))
 
     def test_trace_checker_inplace_on_view(self):
         def foo(x):
tools/git-pre-commit (0 additions, 2 deletions)
@@ -1,5 +1,4 @@
 #!/bin/bash
-set -e
 echo "Running pre-commit flake8"
 FLAKE8_OUT=$(python tools/flake8_hook.py)
 if [[ ${FLAKE8_OUT} ]]
@@ -8,7 +7,6 @@ then
   exit 1
 fi
 
-
 if [ $(which clang-tidy) ]
 then
   echo "Running pre-commit clang-tidy"
torch/csrc/autograd/VariableTypeManual.cpp (18 additions, 10 deletions)
@@ -239,16 +239,24 @@ void VariableType::set_data(Tensor & self, Tensor new_data) const {
   as_variable_ref(self).set_data(new_data);
 }
 Tensor & VariableType::s_copy_(Tensor & self, const Tensor & src, bool non_blocking) const {
-  jit::Node* node = nullptr;
+  jit::Value* output = nullptr;
   if(torch::jit::tracer::isTracing()) {
-    auto& graph = jit::tracer::getTracingState()->graph;
-    // if you have no views of self, then an in place copy is equivalent to
-    // making sure we expand src to the same size as self
-    node = graph->create(jit::aten::expand_as, /*num_outputs=*/0);
-    jit::tracer::addInputs(node, "src", src);
-    jit::tracer::addInputs(node, "self", self);
-    graph->appendNode(node);
-    jit::tracer::ensureUniqueIfOutOfPlaced("copy_ (possibly due to an assignment)", self);
+    const jit::tracer::TracingState& state = *jit::tracer::getTracingState();
+    auto& graph = state.graph;
+    if (state.force_outplace) {
+      // if you have no views of self, then an in place copy is equivalent to
+      // making sure we expand src to the same size as self
+      jit::Node* node = graph->create(jit::aten::expand_as, /*num_outputs=*/1);
+      jit::tracer::addInputs(node, "src", src);
+      jit::tracer::addInputs(node, "self", self);
+      graph->appendNode(node);
+      jit::tracer::ensureUniqueIfOutOfPlaced("copy_ (possibly due to an assignment)", self);
+      output = node->output();
+    } else {
+      output = graph->insert(
+          jit::aten::copy_,
+          {jit::tracer::getValueTrace(self), jit::tracer::getValueTrace(src)});
+    }
   }
   // TODO: once copy is exposed in Declarations.yaml we may be able to bind
   // it automatically
@@ -270,7 +278,7 @@ Tensor & VariableType::s_copy_(Tensor & self, const Tensor & src, bool non_blocking) const {
   increment_version(self);
   rebase_history(as_variable_ref( self ), std::move(grad_fn));
   if(torch::jit::tracer::isTracing()) {
-    jit::tracer::addOutput(node, self);
+    jit::tracer::setOutput(output, self);
   }
   return self;
 }
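
The substance of the change: the tracer now branches on TracingState::force_outplace. When the flag is set (for trace consumers that cannot represent mutation; this reading is inferred from the flag's name and its use here), the old expand_as rewrite is kept, but the node is now created with one output so it can be bound below. Otherwise an aten::copy_ node is inserted directly, taking the traced Values of self and src. Either way, the resulting Value is attached to self via setOutput once the copy has actually run.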
torch/csrc/jit/tracer.cpp (4 additions, 1 deletion)
@@ -192,7 +192,10 @@ void addInputs(Node* n, const char* name, const ArrayRef<double>& value) {
 }
 
 void addOutput(Node* node, const at::Tensor& output) {
-  Value* value = node->addOutput();
+  setOutput(node->addOutput(), output);
+}
+
+void setOutput(Value* value, const at::Tensor& output) {
   if (output.defined()) {
     value->inferTypeFrom(output);
     setValueTrace(autograd::as_variable_ref(output), value);
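
setOutput is split out of addOutput so that callers which already hold a Value (as in the new copy_ path above, where it comes from node->output() or graph->insert) can attach the concrete tensor's inferred type and value trace to it, while addOutput keeps its old behavior by delegating.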
torch/csrc/jit/tracer.h (1 addition, 0 deletions)
@@ -287,6 +287,7 @@ void addOutput(Node* node, T&&) {
       " in the JIT tracer. File a bug report.");
 }
 TORCH_API void addOutput(Node* node, const at::Tensor& tensor);
+TORCH_API void setOutput(Value* value, const at::Tensor& output);
 TORCH_API void addOutput(Node* node, const std::vector<at::Tensor>& list);
 
 TORCH_API autograd::Variable getSizeOf(
