Skip to content

Commit dfb5f51

Browse files
jorgep31415facebook-github-bot
authored andcommitted
Remove Functions.h/cpp (#2245)
Summary: bypass-github-export-checks Pull Request resolved: #2245 In D53982443, OperatorRegistry.h/cpp and Functions.h/cpp were both introduced, as they were split across the PT and ET repos, but now both are in ET. ## OperatorRegistry.cpp Here, we see all our operators. OPERATOR_ENTRY maps from Vulkan Dialect op name to the OpFunction, which have an op-specific name. Note that all OpFunction carry the same function signature. ## Functions.h/cpp -> Arithmetic.h/cpp We don't need another place to see all our operators. They will each reference one ops/impl file, so we group them accordingly in their ops/impl file. ## Nit Also, sort `add_arithmetic_node()` declarations according to their execution order. ghstack-source-id: 217394062 exported-using-ghexport Reviewed By: SS-JIA Differential Revision: D54370467 fbshipit-source-id: d9e82896577610d5dcee3e0bf7f662e69d59e1db
1 parent 4603613 commit dfb5f51

File tree

5 files changed

+53
-71
lines changed

5 files changed

+53
-71
lines changed

backends/vulkan/runtime/graph/ops/Functions.cpp

Lines changed: 0 additions & 40 deletions
This file was deleted.

backends/vulkan/runtime/graph/ops/OperatorRegistry.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
1010

11-
#include <executorch/backends/vulkan/runtime/graph/ops/Functions.h>
11+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Arithmetic.h>
1212

1313
namespace at {
1414
namespace native {

backends/vulkan/runtime/graph/ops/Functions.h renamed to backends/vulkan/runtime/graph/ops/Utils.h

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,21 +10,16 @@
1010

1111
#ifdef USE_VULKAN_API
1212

13+
#include <ATen/native/vulkan/impl/Arithmetic.h>
14+
1315
#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>
1416

1517
namespace at {
1618
namespace native {
1719
namespace vulkan {
1820

19-
#define DEFINE_OP_FN(name) \
20-
ValueRef name(ComputeGraph& graph, const std::vector<ValueRef>& args);
21-
22-
DEFINE_OP_FN(add);
23-
DEFINE_OP_FN(sub);
24-
DEFINE_OP_FN(mul);
25-
DEFINE_OP_FN(div);
26-
DEFINE_OP_FN(floor_div);
27-
DEFINE_OP_FN(pow);
21+
#define DECLARE_OP_FN(function) \
22+
ValueRef function(ComputeGraph& graph, const std::vector<ValueRef>& args);
2823

2924
} // namespace vulkan
3025
} // namespace native

backends/vulkan/runtime/graph/ops/impl/Arithmetic.cpp

Lines changed: 33 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,39 @@ namespace at {
1616
namespace native {
1717
namespace vulkan {
1818

19+
#define DEFINE_ARITHMETIC_FN(function, op_type) \
20+
ValueRef function(ComputeGraph& graph, const std::vector<ValueRef>& args) { \
21+
return add_arithmetic_node( \
22+
graph, \
23+
args[0], \
24+
args[1], \
25+
args[2], \
26+
arithmetic::OpType::op_type, \
27+
args[3]); \
28+
}
29+
30+
DEFINE_ARITHMETIC_FN(add, ADD);
31+
DEFINE_ARITHMETIC_FN(sub, SUB);
32+
DEFINE_ARITHMETIC_FN(mul, MUL);
33+
DEFINE_ARITHMETIC_FN(div, DIV);
34+
DEFINE_ARITHMETIC_FN(floor_div, FLOOR_DIV);
35+
DEFINE_ARITHMETIC_FN(pow, POW);
36+
37+
ValueRef add_arithmetic_node(
38+
ComputeGraph& graph,
39+
const ValueRef t1,
40+
const ValueRef t2,
41+
const float alpha,
42+
const arithmetic::OpType optype,
43+
const int64_t shared_object_idx) {
44+
std::vector<int64_t> t1_sizes = graph.get_val_sizes(t1);
45+
api::ScalarType t1_dtype = graph.get_val_dtype(t1);
46+
47+
ValueRef out = graph.add_tensor(t1_sizes, t1_dtype, shared_object_idx);
48+
add_arithmetic_node(graph, t1, t2, out, alpha, optype);
49+
return out;
50+
}
51+
1952
void add_arithmetic_node(
2053
ComputeGraph& graph,
2154
const ValueRef t1,
@@ -46,21 +79,6 @@ void add_arithmetic_node(
4679
new ArithmeticNode(arg1, arg2, out, alpha, optype));
4780
}
4881

49-
ValueRef add_arithmetic_node(
50-
ComputeGraph& graph,
51-
const ValueRef t1,
52-
const ValueRef t2,
53-
const float alpha,
54-
const arithmetic::OpType optype,
55-
const int64_t shared_object_idx) {
56-
std::vector<int64_t> t1_sizes = graph.get_val_sizes(t1);
57-
api::ScalarType t1_dtype = graph.get_val_dtype(t1);
58-
59-
ValueRef out = graph.add_tensor(t1_sizes, t1_dtype, shared_object_idx);
60-
add_arithmetic_node(graph, t1, t2, out, alpha, optype);
61-
return out;
62-
}
63-
6482
ArithmeticPrepack::ArithmeticPrepack(const ValueRef tref, const ValueRef packed)
6583
: PrepackNode(tref, packed) {}
6684

backends/vulkan/runtime/graph/ops/impl/Arithmetic.h

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,25 +14,34 @@
1414

1515
#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>
1616

17+
#include <executorch/backends/vulkan/runtime/graph/ops/Utils.h>
18+
1719
namespace at {
1820
namespace native {
1921
namespace vulkan {
2022

21-
void add_arithmetic_node(
23+
DECLARE_OP_FN(add);
24+
DECLARE_OP_FN(sub);
25+
DECLARE_OP_FN(mul);
26+
DECLARE_OP_FN(div);
27+
DECLARE_OP_FN(floor_div);
28+
DECLARE_OP_FN(pow);
29+
30+
ValueRef add_arithmetic_node(
2231
ComputeGraph& graph,
2332
const ValueRef t1,
2433
const ValueRef t2,
25-
const ValueRef out,
2634
const float alpha,
27-
const arithmetic::OpType optype);
35+
const arithmetic::OpType optype,
36+
const int64_t shared_object_idx = -1);
2837

29-
ValueRef add_arithmetic_node(
38+
void add_arithmetic_node(
3039
ComputeGraph& graph,
3140
const ValueRef t1,
3241
const ValueRef t2,
42+
const ValueRef out,
3343
const float alpha,
34-
const arithmetic::OpType optype,
35-
const int64_t shared_object_idx = -1);
44+
const arithmetic::OpType optype);
3645

3746
class ArithmeticPrepack : public virtual PrepackNode {
3847
public:

0 commit comments

Comments
 (0)