@@ -36,44 +36,40 @@ DEFINE_ARITHMETIC_FN(pow, POW);
36
36
37
37
ValueRef add_arithmetic_node (
38
38
ComputeGraph& graph,
39
- const ValueRef t1 ,
40
- const ValueRef t2 ,
39
+ const ValueRef in1 ,
40
+ const ValueRef in2 ,
41
41
const float alpha,
42
42
const arithmetic::OpType optype,
43
43
const int64_t shared_object_idx) {
44
- std::vector<int64_t > t1_sizes = graph.get_val_sizes (t1 );
45
- api::ScalarType t1_dtype = graph.get_val_dtype (t1 );
44
+ std::vector<int64_t > in1_sizes = graph.get_val_sizes (in1 );
45
+ api::ScalarType in1_dtype = graph.get_val_dtype (in1 );
46
46
47
- ValueRef out = graph.add_tensor (t1_sizes, t1_dtype , shared_object_idx);
48
- add_arithmetic_node (graph, t1, t2 , out, alpha, optype);
47
+ ValueRef out = graph.add_tensor (in1_sizes, in1_dtype , shared_object_idx);
48
+ add_arithmetic_node (graph, in1, in2 , out, alpha, optype);
49
49
return out;
50
50
}
51
51
52
+ // TODO(T181006464): Move to Utils when we remove ArithmeticPrepack.
53
+ ValueRef prepack_if_tensor_ref (ComputeGraph& graph, const ValueRef v) {
54
+ if (graph.get_val (v).isTensor ()) {
55
+ return v;
56
+ } else {
57
+ TensorRef& tRef = graph.get_val (v).toTensorRef ();
58
+ ValueRef vTen = graph.add_tensor (tRef.sizes , tRef.dtype );
59
+ graph.prepack_nodes ().emplace_back (new ArithmeticPrepack (v, vTen));
60
+ return vTen;
61
+ }
62
+ }
63
+
52
64
void add_arithmetic_node (
53
65
ComputeGraph& graph,
54
- const ValueRef t1 ,
55
- const ValueRef t2 ,
66
+ const ValueRef in1 ,
67
+ const ValueRef in2 ,
56
68
const ValueRef out,
57
69
const float alpha,
58
70
const arithmetic::OpType optype) {
59
- // Prepacking first arg (if needed)
60
- ValueRef arg1 = t1;
61
- if (graph.get_val (t1).isTensorRef ()) {
62
- TensorRef& t1_asref = graph.get_val (t1).toTensorRef ();
63
- ValueRef t1_vten = graph.add_tensor (t1_asref.sizes , t1_asref.dtype );
64
- graph.prepack_nodes ().emplace_back (new ArithmeticPrepack (t1, t1_vten));
65
- arg1 = t1_vten;
66
- }
67
- VK_CHECK_COND (graph.get_val (arg1).isTensor ());
68
- // Prepacking second arg (if needed)
69
- ValueRef arg2 = t2;
70
- if (graph.get_val (t2).isTensorRef ()) {
71
- TensorRef& t2_asref = graph.get_val (t2).toTensorRef ();
72
- ValueRef t2_vten = graph.add_tensor (t2_asref.sizes , t2_asref.dtype );
73
- graph.prepack_nodes ().emplace_back (new ArithmeticPrepack (t2, t2_vten));
74
- arg2 = t2_vten;
75
- }
76
- VK_CHECK_COND (graph.get_val (arg2).isTensor ());
71
+ ValueRef arg1 = prepack_if_tensor_ref (graph, in1);
72
+ ValueRef arg2 = prepack_if_tensor_ref (graph, in2);
77
73
78
74
graph.execute_nodes ().emplace_back (
79
75
new ArithmeticNode (arg1, arg2, out, alpha, optype));
@@ -97,12 +93,12 @@ void ArithmeticPrepack::encode(ComputeGraph* graph) const {
97
93
}
98
94
99
95
ArithmeticNode::ArithmeticNode (
100
- const ValueRef t1 ,
101
- const ValueRef t2 ,
96
+ const ValueRef in1 ,
97
+ const ValueRef in2 ,
102
98
const ValueRef out,
103
99
const float alpha,
104
100
const arithmetic::OpType optype)
105
- : ExecuteNode({t1, t2 }, {out}), alpha_(alpha), optype_(optype) {}
101
+ : ExecuteNode({in1, in2 }, {out}), alpha_(alpha), optype_(optype) {}
106
102
107
103
void ArithmeticNode::encode (ComputeGraph* graph) const {
108
104
vTensor& in1 = graph->get_val (inputs_[0 ]).toTensor ();
0 commit comments