Skip to content

Commit 67e4113

Browse files
committed
Access JIT decomposition utils through an interface
1 parent 4ca594e commit 67e4113

File tree

5 files changed

+96
-3
lines changed

5 files changed

+96
-3
lines changed

build_variables.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,7 @@ core_trainer_sources = [
166166
"torch/csrc/autograd/saved_variable.cpp",
167167
"torch/csrc/autograd/variable.cpp",
168168
"torch/csrc/autograd/utils/warnings.cpp",
169+
"torch/csrc/jit_decomp_interface.cpp",
169170
"torch/csrc/jit/frontend/name_mangler.cpp",
170171
"torch/csrc/jit/ir/type_hashing.cpp",
171172
"torch/csrc/jit/serialization/pickler.cpp",

torch/csrc/autograd/VariableTypeUtilsDependOnOps.h

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
#pragma once
22

3-
#include <torch/csrc/jit/runtime/decomposition_registry.h>
3+
#include <ATen/core/boxing/KernelFunction.h>
4+
#include <ATen/core/dispatch/Dispatcher.h>
5+
#include <torch/csrc/jit_decomp_interface.h>
46

57
// This is the set of helpers in VariableTypeUtils have a dependency on
68
// native_functions.yaml meaning the file will need to be re-compiled every time
@@ -12,14 +14,28 @@ namespace torch {
1214
namespace autograd {
1315
namespace impl {
1416

17+
class MyFunctor final : public c10::OperatorKernel {
18+
public:
19+
MyFunctor(JitDecompInterface* fns) : fns_(fns){};
20+
21+
void operator()(
22+
const c10::OperatorHandle& op,
23+
c10::DispatchKeySet ks,
24+
torch::jit::Stack* stack) {
25+
fns_->run_jit_decomposition_(op, stack);
26+
}
27+
JitDecompInterface* fns_;
28+
};
29+
1530
// Depends on torch/csrc/jit/ir/ir.h -> aten/src/ATen/core/interned_strings.h
1631
template <class Return, class... Args>
1732
Return run_jit_decomposition_with_args_for_jvp(
1833
c10::string_view name,
1934
const c10::OperatorHandle& opHandle,
2035
c10::DispatchKeySet dispatchKeySet,
2136
Args&&... args) {
22-
bool has_decomp = jit::has_jit_decomposition(opHandle.schema());
37+
JitDecompInterface* fns = getJitDecomp();
38+
bool has_decomp = fns->has_jit_decomposition_(opHandle.schema());
2339

2440
TORCH_CHECK_NOT_IMPLEMENTED(
2541
has_decomp,
@@ -33,7 +49,8 @@ Return run_jit_decomposition_with_args_for_jvp(
3349
"PYTORCH_JIT=0 is set, some operators may no longer be used with forward AD.");
3450

3551
return c10::KernelFunction::makeFromBoxedKernel(
36-
c10::BoxedKernel::makeFromFunction<&jit::run_jit_decomposition>())
52+
c10::BoxedKernel::makeFromFunctor(
53+
std::make_unique<MyFunctor>(fns)))
3754
.call<Return, Args...>(
3855
opHandle, dispatchKeySet, std::forward<Args>(args)...);
3956
}

torch/csrc/jit/runtime/decomposition_registry.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include <torch/csrc/jit/passes/inliner.h>
1414
#include <torch/csrc/jit/passes/peephole.h>
1515
#include <torch/csrc/jit/runtime/graph_executor.h>
16+
#include <torch/csrc/jit_decomp_interface.h>
1617
#include <memory>
1718
#include <unordered_map>
1819

@@ -160,6 +161,26 @@ void RegisterDecomposition(
160161
schema_to_decomposition[&schema] = g;
161162
}
162163

164+
struct JitDecomp final : torch::autograd::impl::JitDecompInterface {
165+
bool has_jit_decomposition_(const c10::FunctionSchema& schema) const override;
166+
void run_jit_decomposition_(
167+
const c10::OperatorHandle& op,
168+
torch::jit::Stack* stack) const override;
169+
};
170+
171+
JitDecomp jitDecomp;
172+
torch::autograd::impl::JitDecompRegisterer registerJitDecomp(&jitDecomp);
173+
174+
void JitDecomp::run_jit_decomposition_(
175+
const c10::OperatorHandle& op,
176+
torch::jit::Stack* stack) const {
177+
run_jit_decomposition(op, stack);
178+
}
179+
180+
bool JitDecomp::has_jit_decomposition_(const FunctionSchema& schema) const {
181+
return has_jit_decomposition(schema);
182+
}
183+
163184
void run_jit_decomposition(
164185
const c10::OperatorHandle& op,
165186
torch::jit::Stack* stack) {

torch/csrc/jit_decomp_interface.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
#include <torch/csrc/jit_decomp_interface.h>
2+
3+
namespace torch {
4+
namespace autograd {
5+
namespace impl {
6+
7+
namespace {
8+
JitDecompInterface* fns = nullptr;
9+
}
10+
11+
void setJitDecompInterface(JitDecompInterface* f) {
12+
fns = f;
13+
}
14+
JitDecompInterface* getJitDecomp() {
15+
TORCH_CHECK(
16+
fns,
17+
"Support for JIT decompositions has not been loaded; have you linked against TBD?")
18+
return fns;
19+
}
20+
21+
} // namespace impl
22+
} // namespace autograd
23+
} // namespace torch

torch/csrc/jit_decomp_interface.h

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
#pragma once
2+
3+
#include <ATen/core/Tensor.h>
4+
#include <ATen/core/function_schema.h>
5+
#include <c10/macros/Export.h>
6+
7+
namespace torch {
8+
namespace autograd {
9+
namespace impl {
10+
11+
struct TORCH_API JitDecompInterface {
12+
virtual ~JitDecompInterface() = default;
13+
virtual bool has_jit_decomposition_(
14+
const c10::FunctionSchema& schema) const = 0;
15+
virtual void run_jit_decomposition_(
16+
const c10::OperatorHandle& op,
17+
jit::Stack* stack) const = 0;
18+
};
19+
20+
TORCH_API void setJitDecompInterface(JitDecompInterface* fns);
21+
TORCH_API JitDecompInterface* getJitDecomp();
22+
23+
struct TORCH_API JitDecompRegisterer {
24+
explicit JitDecompRegisterer(JitDecompInterface* fns) {
25+
setJitDecompInterface(fns);
26+
}
27+
};
28+
29+
} // namespace impl
30+
} // namespace autograd
31+
} // namespace torch

0 commit comments

Comments
 (0)