Commit 73de30c

yiming0416 authored and facebook-github-bot committed
[nativert] Move PrimKernelRegistry to PyTorch core (pytorch#156506)
Summary:
Pull Request resolved: pytorch#156506

Torch Native Runtime RFC: pytorch/rfcs#72

PrimKernelRegistry manages a small subset of the kernel registry in NativeRT, covering ListPack, ListUnpack, Input, Output, VarConcat, and VarStack.

Test Plan: Internal unittests

Reviewed By: zhxchen17

Differential Revision: D77034945
1 parent 070e580 commit 73de30c
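
For context, the files below wire these kernels into c10's generic registry machinery (TORCH_DECLARE_REGISTRY in the header, C10_DEFINE_REGISTRY in the .cpp), so a kernel can be instantiated by its string name. A minimal sketch of a consumer, assuming only the standard c10::Registry API (Has/Create); the makePrimKernel helper and its call site are hypothetical and not part of this commit:

#include <memory>
#include <string>
#include <torch/nativert/kernels/PrimKernelRegistry.h>

namespace torch::nativert {

// Hypothetical helper (illustration only, not in this diff): look up a
// prim kernel such as "prim.ListPack" by name and construct it for `node`.
inline std::unique_ptr<OpKernel> makePrimKernel(
    const std::string& target,
    const Node* node) {
  if (!PrimKernelRegistry()->Has(target)) {
    return nullptr; // not a prim op; caller falls back to other registries
  }
  // Create() forwards `node` to the registered kernel's constructor.
  return PrimKernelRegistry()->Create(target, node);
}

} // namespace torch::nativert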

File tree

3 files changed (+213, −0 lines)

build_variables.bzl

Lines changed: 1 addition & 0 deletions

@@ -611,6 +611,7 @@ libtorch_nativert_sources = [
     "torch/nativert/kernels/C10Kernel.cpp",
     "torch/nativert/kernels/AutoFunctionalizeKernel.cpp",
     "torch/nativert/kernels/HigherOrderKernel.cpp",
+    "torch/nativert/kernels/PrimKernelRegistry.cpp",
 ]

 torch_mobile_tracer_sources = [
torch/nativert/kernels/PrimKernelRegistry.cpp (new file)

Lines changed: 172 additions & 0 deletions
@@ -0,0 +1,172 @@
#include <ATen/record_function.h>

#include <ATen/CPUFunctions.h>
#include <c10/core/ScalarType.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/runtime/static/ops.h>

#include <c10/util/Enumerate.h>
#include <torch/nativert/kernels/PrimKernelRegistry.h>

namespace torch::nativert {

C10_DEFINE_REGISTRY(PrimKernelRegistry, OpKernel, const Node*);

namespace {

class OpKernel_prim_listpack : public OpKernel {
 public:
  explicit OpKernel_prim_listpack(const Node* node)
      : OpKernel(
            node,
            std::nullopt,
            torch::nativert::OpKernelKind::kPrimKernel) {
    auto listType = node->outputs()[0]->type();
    switch (listType.kind()) {
      case Type::Kind::TensorList:
        type_ = c10::TensorType::get();
        break;
      case Type::Kind::SymIntList:
        type_ = c10::IntType::get();
        break;
      case Type::Kind::OptionalTensorList:
        type_ = c10::OptionalType::create(c10::TensorType::get());
        break;
      default:
        TORCH_CHECK(false, "Unsupported list type: ", listType);
    }
  }

  void computeInternal(ExecutionFrame& executionFrame) const override final {
    RECORD_USER_SCOPE("sigmoid::OpKernel_prim_listpack");
    c10::List<c10::IValue> list(type_);
    list.reserve(numInputs());
    for (size_t i = 0; i < numInputs(); ++i) {
      if (KernelInput(i).isNone()) {
        list.emplace_back();
      } else {
        list.push_back(KernelInput(i));
      }
    }
    KernelOutput(0) = std::move(list);
  }

 private:
  c10::TypePtr type_;
};

} // namespace

C10_REGISTER_TYPED_CLASS(
    PrimKernelRegistry,
    "prim.ListPack",
    OpKernel_prim_listpack);

REGISTER_PRIM_KERNEL("prim.ListUnpack", prim_listunpack, {
  RECORD_USER_SCOPE("sigmoid::OpKernel_prim_listunpack");
  auto inputListRef = KernelInput(0).toListRef();
  for (const auto& [i, ivalue] : c10::enumerate(inputListRef)) {
    KernelOutput(i) = ivalue;
  }
});

// Noop for input and output
REGISTER_PRIM_KERNEL("prim.Input", prim_input, {});
REGISTER_PRIM_KERNEL("prim.Output", prim_output, {});

namespace {

class OpKernel_variadic_concat : public OpKernel {
 public:
  explicit OpKernel_variadic_concat(const Node* node)
      : OpKernel(
            node,
            std::nullopt,
            torch::nativert::OpKernelKind::kPrimKernel) {
    dim_ = node_->attributes().size() > 0
        ? constantToIValue(node_->getAttribute("dim").value).toInt()
        : 0;
  }
  void computeInternal(ExecutionFrame& executionFrame) const override final {
    {
      const size_t numNodeInps = numInputs();
      auto numCatInps = numNodeInps;
      auto dim = dim_;
      if (KernelInput(numCatInps - 1).isInt()) {
        dim = KernelInput(numCatInps - 1).toInt();
        numCatInps--;
      }
      std::vector<at::Tensor> inputs(numCatInps);
      for (const auto i : c10::irange(numCatInps)) {
        inputs[i] = KernelInput(i).toTensor();
      }

      if (KernelOutput(0).isNone()) {
        KernelOutput(0) = at::cpu::cat(inputs, dim);
        return;
      }
      auto& out_t = KernelOutput(0).toTensor();
      fastResizeToZero(out_t);
      at::cpu::cat_outf(inputs, dim, out_t);
    }
  }

 private:
  int dim_;
};

} // namespace

C10_REGISTER_TYPED_CLASS(
    PrimKernelRegistry,
    "prim.VarConcat",
    OpKernel_variadic_concat);

namespace {

class OpKernel_variadic_stack : public OpKernel {
 public:
  explicit OpKernel_variadic_stack(const Node* node)
      : OpKernel(
            node,
            std::nullopt,
            torch::nativert::OpKernelKind::kPrimKernel) {
    dim_ = node_->attributes().size() > 0
        ? constantToIValue(node_->getAttribute("dim").value).toInt()
        : 0;
  }
  void computeInternal(ExecutionFrame& executionFrame) const override final {
    {
      const size_t numNodeInps = numInputs();
      auto numStackInps = numNodeInps;
      auto dim = dim_;
      if (KernelInput(numStackInps - 1).isInt()) {
        dim = KernelInput(numStackInps - 1).toInt();
        numStackInps--;
      }
      std::vector<at::Tensor> inputs(numStackInps);
      for (const auto i : c10::irange(numStackInps)) {
        inputs[i] = KernelInput(i).toTensor();
      }
      auto& out = KernelOutput(0);
      if (out.isNone()) {
        out = at::native::_stack_cpu(inputs, dim);
        return;
      }
      auto& out_t = out.toTensor();
      fastResizeToZero(out_t);
      at::native::_stack_out_cpu(inputs, dim, out_t);
    }
  }

 private:
  int64_t dim_;
};
} // namespace

C10_REGISTER_TYPED_CLASS(
    PrimKernelRegistry,
    "prim.VarStack",
    OpKernel_variadic_stack);

} // namespace torch::nativert
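
A note on the design: prim.VarConcat and prim.VarStack take their inputs as a single variadic node, with an optional trailing int input overriding the "dim" attribute. Both follow the static-runtime output-reuse idiom: the first call allocates the output via the functional op; later calls shrink the cached tensor to zero size without freeing its storage (fastResizeToZero, defined in the header below) and write through the out= variant. A standalone sketch of that idiom using the public ATen API, with at::cat/at::cat_out standing in for the internal at::cpu::cat/at::cpu::cat_outf:

#include <ATen/ATen.h>
#include <vector>

void reuseOutputAcrossIterations() {
  at::Tensor out; // plays the role of the cached KernelOutput(0)
  std::vector<at::Tensor> inputs = {at::ones({2, 3}), at::zeros({2, 3})};
  for (int iter = 0; iter < 2; ++iter) {
    if (!out.defined()) {
      out = at::cat(inputs, /*dim=*/0); // first call: allocate
    } else {
      out.resize_({0}); // drop sizes, keep the existing allocation
      at::cat_out(out, inputs, /*dim=*/0); // write into reused storage
    }
  }
}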
torch/nativert/kernels/PrimKernelRegistry.h (new file)

Lines changed: 40 additions & 0 deletions

@@ -0,0 +1,40 @@
#pragma once

#include <torch/nativert/executor/OpKernel.h>
#include <torch/nativert/graph/Graph.h>
#include <torch/nativert/kernels/C10Kernel.h>

namespace torch::nativert {

#define KernelInput(id) input(id, executionFrame)
#define KernelOutput(id) output(id, executionFrame)

TORCH_DECLARE_REGISTRY(PrimKernelRegistry, OpKernel, const Node*);

#define REGISTER_PRIM_KERNEL(name, id, ...)                      \
  class OpKernel_##id : public OpKernel {                        \
   public:                                                       \
    OpKernel_##id(const Node* node)                              \
        : OpKernel(                                              \
              node,                                              \
              std::nullopt,                                      \
              torch::nativert::OpKernelKind::kPrimKernel) {}     \
    void computeInternal(                                        \
        ExecutionFrame& executionFrame) const override final {   \
      __VA_ARGS__;                                               \
    }                                                            \
  };                                                             \
  C10_REGISTER_TYPED_CLASS(PrimKernelRegistry, name, OpKernel_##id);

inline bool checkResizedDataPtr(at::Tensor& t) {
  auto const prev_data_ptr = t.data_ptr();
  t.resize_({0});
  return prev_data_ptr == t.data_ptr();
}

inline void fastResizeToZero(at::Tensor& t) {
  t.unsafeGetTensorImpl()->set_sizes_contiguous({0});
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(checkResizedDataPtr(t));
}

} // namespace torch::nativert
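
For reference, the REGISTER_PRIM_KERNEL invocation for "prim.ListUnpack" in the .cpp above expands to roughly the following (formatting aside):

class OpKernel_prim_listunpack : public OpKernel {
 public:
  OpKernel_prim_listunpack(const Node* node)
      : OpKernel(
            node,
            std::nullopt,
            torch::nativert::OpKernelKind::kPrimKernel) {}
  void computeInternal(ExecutionFrame& executionFrame) const override final {
    RECORD_USER_SCOPE("sigmoid::OpKernel_prim_listunpack");
    auto inputListRef = KernelInput(0).toListRef();
    for (const auto& [i, ivalue] : c10::enumerate(inputListRef)) {
      KernelOutput(i) = ivalue;
    }
  }
};
C10_REGISTER_TYPED_CLASS(
    PrimKernelRegistry, "prim.ListUnpack", OpKernel_prim_listunpack);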
