Interpreter support for CallFunction/CallMethod (pytorch#21562)
Summary:
Pull Request resolved: pytorch#21562
ghimport-source-id: 17e5e18

Reviewed By: suo

Differential Revision: D15729500

Pulled By: zdevito

fbshipit-source-id: efa8a133b617b1498810392a8da6b513ce00b5eb
zdevito authored and facebook-github-bot committed Jun 9, 2019
1 parent ad0c08f commit ea822d9
Showing 19 changed files with 390 additions and 217 deletions.
31 changes: 24 additions & 7 deletions test/jit_utils.py
@@ -59,9 +59,18 @@ def disableEmitHook(self):
         yield None
         self.setHooks()
 
+    def _isHookExceptionOk(self, e):
+        se = str(e)
+        allowed = ("Could not export Python function",
+                   "closures are not exportable")
+        for a in allowed:
+            if a in se:
+                return True
+        return False
+
     def emitFunctionHook(self, func):
         # func has invalid names for export, skip the jitter check
-        if func.name == "<lambda>" or "aten::" in func.name:
+        if func.name == "<lambda>" or "aten::" in func.name or _in_first_class_mode:
             return
         # disable the hook while we parse code, otherwise we will re-enter the hook
         with self.disableEmitHook():
@@ -72,9 +81,7 @@ def emitFunctionHook(self, func):
             src2, constants2 = _jit_python_print(func2)
             self.assertMultiLineEqual(src, src2)
         except RuntimeError as e:
-            se = str(e)
-            if "Could not export Python function" not in se and \
-               "closures are not exportable" not in se:
+            if not self._isHookExceptionOk(e):
                 raise
 
     def emitModuleHook(self, module):
@@ -113,9 +120,7 @@ def copy_structure_and_params(m):
                 for line in main_module:
                     main_module_code += line.decode()
             except RuntimeError as e:
-                se = str(e)
-                if "Could not export Python function" not in se and \
-                   "closures are not exportable" not in se:
+                if not self._isHookExceptionOk(e):
                     raise
                 else:
                     return
@@ -428,6 +433,18 @@ def enable_profiling_mode():
     yield
     torch._C._jit_set_profiling_mode(False)
 
+
+_in_first_class_mode = False
+@contextmanager
+def enable_first_class_mode():
+    global _in_first_class_mode
+    torch._C._jit_set_first_class_mode(True)
+    _in_first_class_mode = True
+    yield
+    torch._C._jit_set_first_class_mode(False)
+    _in_first_class_mode = False
+
+
 # note: not re-entrant, use unnested only
 @contextmanager
 def disable_autodiff_subgraph_inlining(enabled=True):
28 changes: 22 additions & 6 deletions test/test_jit.py
@@ -25,7 +25,7 @@
     skipIfRocm, skipIfNoLapack, suppress_warnings, load_tests, IS_SANDCASTLE, \
     freeze_rng_state, set_rng_seed, slowTest, TemporaryFileName
 from jit_utils import JitTestCase, enable_cpu_fuser, disable_autodiff_subgraph_inlining, \
-    _trace, enable_cpu_fuser_if, enable_profiling_mode
+    _trace, enable_cpu_fuser_if, enable_profiling_mode, enable_first_class_mode
 from common_nn import module_tests, new_module_tests, criterion_tests
 from common_methods_invocations import method_tests as autograd_method_tests
 from common_methods_invocations import create_input, unpack_variables, \
@@ -230,11 +230,6 @@ def _sum_of_list(tensorlist):
         s += t.sum()
     return s
 
-@contextmanager
-def enable_first_class_mode():
-    torch._C._jit_set_first_class_mode(True)
-    yield
-    torch._C._jit_set_first_class_mode(False)
 
 # helper function to generate test qparam
 def _helper_generate_qparam(script_module, input_data):
@@ -2991,6 +2986,27 @@ def forward(self, input):
         foo.forward(input)
         self.assertEqual(input, foo.foo)
 
+    def test_first_class_calls(self):
+        with enable_first_class_mode():
+            @torch.jit.script
+            class Foo(object):
+                def __init__(self, x):
+                    self.bar = x
+
+                def stuff(self, x):
+                    return self.bar + x
+
+            @torch.jit.script
+            def foo(x):
+                return x * x + Foo(x).stuff(2 * x)
+
+            @torch.jit.script
+            def bar(x):
+                return foo(x) * foo(x)
+
+            x = torch.rand(3, 4)
+            self.assertEqual(bar(x), (x * x + 3 * x) * (x * x + 3 * x))
+
     def test_invalid_prefix_annotation(self):
         with self.assertRaisesRegex(RuntimeError, "annotation prefix in line"):
             with self.capture_stdout() as captured:
1 change: 0 additions & 1 deletion torch/csrc/jit/argument_spec.cpp
@@ -166,7 +166,6 @@ ArgumentSpec ArgumentSpecCreator::create(bool with_grad, const Stack& input)
         // consume object
         const IValue* iv = stack[stack_top]++;
         AT_ASSERT(iv->isObject());
-        iv->toObject();
         // see [argspec refcounting]
         auto p = *reinterpret_cast<const at::ivalue::Object* const*>(iv);
         auto obj_ptr = &p->slots()[0];
31 changes: 31 additions & 0 deletions torch/csrc/jit/exception_message.h
@@ -0,0 +1,31 @@
+#pragma once
+#include <c10/util/Exception.h>
+#include <stdexcept>
+
+namespace torch {
+namespace jit {
+
+struct ExceptionMessage {
+  ExceptionMessage(const std::exception& e) : e_(e) {}
+
+ private:
+  const std::exception& e_;
+  friend std::ostream& operator<<(
+      std::ostream& out,
+      const ExceptionMessage& msg);
+};
+
+inline std::ostream& operator<<(
+    std::ostream& out,
+    const ExceptionMessage& msg) {
+  auto c10_error = dynamic_cast<const c10::Error*>(&msg.e_);
+  if (c10_error) {
+    out << c10_error->msg_without_backtrace();
+  } else {
+    out << msg.e_.what();
+  }
+  return out;
+}
+
+} // namespace jit
+} // namespace torch
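
ExceptionMessage is only a stream adaptor: when the caught exception is a c10::Error the backtrace is dropped, otherwise what() is passed through verbatim. A minimal sketch of a call site, with a hypothetical wrapper function for illustration (only ExceptionMessage and its operator<< come from the header above):

#include <sstream>
#include <stdexcept>
#include <string>

#include <torch/csrc/jit/exception_message.h>

namespace torch {
namespace jit {

// Hypothetical helper (not part of this commit): wrap a caught exception with
// extra context, letting ExceptionMessage strip the c10::Error backtrace.
[[noreturn]] inline void rethrowWithContext(
    const std::exception& e,
    const std::string& where) {
  std::stringstream ss;
  ss << "The following operation failed in " << where << ":\n"
     << ExceptionMessage(e);
  throw std::runtime_error(ss.str());
}

} // namespace jit
} // namespace torch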
4 changes: 4 additions & 0 deletions torch/csrc/jit/graph_executor.cpp
@@ -679,6 +679,10 @@ void GraphExecutor::run(Stack& inputs) {
   return pImpl->run(inputs);
 }
 
+ExecutionPlan GraphExecutor::getPlanFor(Stack& inputs) {
+  return pImpl->getPlanFor(inputs);
+}
+
 std::shared_ptr<Graph> GraphExecutor::graph() const {
   return pImpl->graph;
 }
1 change: 1 addition & 0 deletions torch/csrc/jit/graph_executor.h
@@ -40,6 +40,7 @@ struct TORCH_API GraphExecutor {
   GraphExecutor() = default;
   GraphExecutor(std::shared_ptr<Graph> graph, bool optimize = true);
   void run(Stack& inputs);
+  ExecutionPlan getPlanFor(Stack& inputs);
   explicit operator bool() const {
     return pImpl != nullptr;
   }
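
The new getPlanFor entry point lets a caller ask a GraphExecutor for code specialized to the inputs already on the stack, which is what executing CallFunction/CallMethod in the interpreter needs instead of inlining the callee's graph. A rough sketch of such a caller, assuming an ExecutionPlan::code member and the InterpreterState(code).run(stack) API from the interpreter header (neither detail is shown in this diff):

#include <torch/csrc/jit/graph_executor.h>
#include <torch/csrc/jit/interpreter.h>

namespace torch {
namespace jit {

// Hypothetical call path (not part of this commit): execute a callee through
// its GraphExecutor without inlining its graph into the caller. Arguments are
// consumed from `stack` and the callee's results are left on it.
inline void runCallee(GraphExecutor& callee, Stack& stack) {
  ExecutionPlan plan = callee.getPlanFor(stack); // specialize/compile lazily
  InterpreterState(plan.code).run(stack);        // assumes ExecutionPlan::code
}

} // namespace jit
} // namespace torch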
2 changes: 1 addition & 1 deletion torch/csrc/jit/init.cpp
@@ -334,7 +334,7 @@ void initJITBindings(PyObject* module) {
           [](bool profiling_flag) { getProfilingMode() = profiling_flag; })
       .def(
           "_jit_set_first_class_mode",
-          [](bool enabled) { script::setRunAsFirstClass(enabled); })
+          [](bool enabled) { script::getFirstClassMode() = enabled; })
       .def(
           "_jit_fuser_get_fused_kernel_code",
           [](Graph& g, std::vector<at::Tensor> inps) {
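
The binding now assigns through script::getFirstClassMode() instead of calling a setter, i.e. the flag is exposed as a mutable reference. A small sketch of a C++ scope guard built on that assumption (the guard type and the shown declaration are hypothetical, not part of this commit):

namespace torch {
namespace jit {
namespace script {

// Assumed declaration, mirroring the binding's use of a mutable flag reference.
bool& getFirstClassMode();

// Hypothetical RAII guard: compile with first-class calls within a scope,
// then restore the previous setting.
struct FirstClassModeGuard {
  FirstClassModeGuard() : prev_(getFirstClassMode()) {
    getFirstClassMode() = true;
  }
  ~FirstClassModeGuard() {
    getFirstClassMode() = prev_;
  }

 private:
  bool prev_;
};

} // namespace script
} // namespace jit
} // namespace torch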