Skip to content

Commit 6a7d3a9

Browse files
committed
[ET-VK] Nit Arithmetic cleanup
Facilitate code review before the big refactoring. Create a `maybe_prepack()` helper and improve variable naming. Differential Revision: [D54400674](https://our.internmc.facebook.com/intern/diff/D54400674/) [ghstack-poisoned]
1 parent bc969fc commit 6a7d3a9

File tree

2 files changed

+31
-35
lines changed

2 files changed

+31
-35
lines changed

backends/vulkan/runtime/graph/ops/impl/Arithmetic.cpp

Lines changed: 25 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -36,44 +36,40 @@ DEFINE_ARITHMETIC_FN(pow, POW);
3636

3737
ValueRef add_arithmetic_node(
3838
ComputeGraph& graph,
39-
const ValueRef t1,
40-
const ValueRef t2,
39+
const ValueRef in1,
40+
const ValueRef in2,
4141
const float alpha,
4242
const arithmetic::OpType optype,
4343
const int64_t shared_object_idx) {
44-
std::vector<int64_t> t1_sizes = graph.get_val_sizes(t1);
45-
api::ScalarType t1_dtype = graph.get_val_dtype(t1);
44+
std::vector<int64_t> in1_sizes = graph.get_val_sizes(in1);
45+
api::ScalarType in1_dtype = graph.get_val_dtype(in1);
4646

47-
ValueRef out = graph.add_tensor(t1_sizes, t1_dtype, shared_object_idx);
48-
add_arithmetic_node(graph, t1, t2, out, alpha, optype);
47+
ValueRef out = graph.add_tensor(in1_sizes, in1_dtype, shared_object_idx);
48+
add_arithmetic_node(graph, in1, in2, out, alpha, optype);
4949
return out;
5050
}
5151

52+
// TODO(T181006464): Move to Utils when we remove ArithmeticPrepack.
53+
ValueRef prepack_if_tensor_ref(ComputeGraph& graph, const ValueRef v) {
54+
if (graph.get_val(v).isTensor()) {
55+
return v;
56+
} else {
57+
TensorRef& tRef = graph.get_val(v).toTensorRef();
58+
ValueRef vTen = graph.add_tensor(tRef.sizes, tRef.dtype);
59+
graph.prepack_nodes().emplace_back(new ArithmeticPrepack(v, vTen));
60+
return vTen;
61+
}
62+
}
63+
5264
void add_arithmetic_node(
5365
ComputeGraph& graph,
54-
const ValueRef t1,
55-
const ValueRef t2,
66+
const ValueRef in1,
67+
const ValueRef in2,
5668
const ValueRef out,
5769
const float alpha,
5870
const arithmetic::OpType optype) {
59-
// Prepacking first arg (if needed)
60-
ValueRef arg1 = t1;
61-
if (graph.get_val(t1).isTensorRef()) {
62-
TensorRef& t1_asref = graph.get_val(t1).toTensorRef();
63-
ValueRef t1_vten = graph.add_tensor(t1_asref.sizes, t1_asref.dtype);
64-
graph.prepack_nodes().emplace_back(new ArithmeticPrepack(t1, t1_vten));
65-
arg1 = t1_vten;
66-
}
67-
VK_CHECK_COND(graph.get_val(arg1).isTensor());
68-
// Prepacking second arg (if needed)
69-
ValueRef arg2 = t2;
70-
if (graph.get_val(t2).isTensorRef()) {
71-
TensorRef& t2_asref = graph.get_val(t2).toTensorRef();
72-
ValueRef t2_vten = graph.add_tensor(t2_asref.sizes, t2_asref.dtype);
73-
graph.prepack_nodes().emplace_back(new ArithmeticPrepack(t2, t2_vten));
74-
arg2 = t2_vten;
75-
}
76-
VK_CHECK_COND(graph.get_val(arg2).isTensor());
71+
ValueRef arg1 = prepack_if_tensor_ref(graph, in1);
72+
ValueRef arg2 = prepack_if_tensor_ref(graph, in2);
7773

7874
graph.execute_nodes().emplace_back(
7975
new ArithmeticNode(arg1, arg2, out, alpha, optype));
@@ -97,12 +93,12 @@ void ArithmeticPrepack::encode(ComputeGraph* graph) const {
9793
}
9894

9995
ArithmeticNode::ArithmeticNode(
100-
const ValueRef t1,
101-
const ValueRef t2,
96+
const ValueRef in1,
97+
const ValueRef in2,
10298
const ValueRef out,
10399
const float alpha,
104100
const arithmetic::OpType optype)
105-
: ExecuteNode({t1, t2}, {out}), alpha_(alpha), optype_(optype) {}
101+
: ExecuteNode({in1, in2}, {out}), alpha_(alpha), optype_(optype) {}
106102

107103
void ArithmeticNode::encode(ComputeGraph* graph) const {
108104
vTensor& in1 = graph->get_val(inputs_[0]).toTensor();

backends/vulkan/runtime/graph/ops/impl/Arithmetic.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,16 +29,16 @@ DECLARE_OP_FN(pow);
2929

3030
ValueRef add_arithmetic_node(
3131
ComputeGraph& graph,
32-
const ValueRef t1,
33-
const ValueRef t2,
32+
const ValueRef in1,
33+
const ValueRef in2,
3434
const float alpha,
3535
const arithmetic::OpType optype,
3636
const int64_t shared_object_idx = -1);
3737

3838
void add_arithmetic_node(
3939
ComputeGraph& graph,
40-
const ValueRef t1,
41-
const ValueRef t2,
40+
const ValueRef in1,
41+
const ValueRef in2,
4242
const ValueRef out,
4343
const float alpha,
4444
const arithmetic::OpType optype);
@@ -53,8 +53,8 @@ class ArithmeticPrepack : public virtual PrepackNode {
5353
class ArithmeticNode : public virtual ExecuteNode {
5454
public:
5555
explicit ArithmeticNode(
56-
const ValueRef t1,
57-
const ValueRef t2,
56+
const ValueRef in1,
57+
const ValueRef in2,
5858
const ValueRef out,
5959
const float alpha,
6060
const arithmetic::OpType optype);

0 commit comments

Comments
 (0)