Add oneflow.tensordot interface #7968
@@ -148,6 +148,7 @@ oneflow
tan,
tanh,
tensor,
tensordot,
tile,
transpose,
t,
@@ -264,6 +264,93 @@ class BatchMatMulFunctor {
  std::shared_ptr<OpExpr> batch_matmul_op_;
};

class TensorDotFunctor {
 public:
  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& a, const std::shared_ptr<Tensor>& b,
                           const std::vector<int64_t>& dims_a,
                           const std::vector<int64_t>& dims_b) const {
    CHECK_EQ_OR_RETURN(dims_a.size(), dims_b.size())
        << "dims1 and dims2 must have same size, got " << dims_a.size() << " and " << dims_b.size();

[Review comment] Here, dim_a and dim_b have not been declared yet.
[Author reply] Sorry, that change was rushed. It has been corrected.

    if (dims_a.empty() && dims_b.empty()) {
      std::vector<int64_t> shape_sum;

[Review comment] You could construct a DimVector here and update its values directly; there is no need to fill a std::vector first and then construct a DimVector from it again later.

      shape_sum.reserve(a->shape()->NumAxes() + b->shape()->NumAxes());
      for (int64_t i = 0; i < a->shape()->NumAxes(); i++) {
        shape_sum.emplace_back(a->shape()->At(i));
      }
      for (int64_t i = 0; i < b->shape()->NumAxes(); i++) {
        shape_sum.emplace_back(b->shape()->At(i));
      }
      auto reshape_a = JUST(Reshape(a, Shape(DimVector{-1, 1})));
      auto reshape_b = JUST(Reshape(b, Shape(DimVector{1, -1})));
      return JUST(Reshape(JUST(functional::MatMul(reshape_a, reshape_b, false, false, 1.0)),
                          Shape(DimVector(shape_sum.begin(), shape_sum.end()))));
    }
    std::vector<bool> if_dot_dims_a(a->shape()->NumAxes(), false);
    std::vector<bool> if_dot_dims_b(b->shape()->NumAxes(), false);
    for (auto i : dims_a) if_dot_dims_a[i] = true;

[Review comment] Either make the concrete type behind auto obvious, or use const auto; otherwise do not use auto. Please check the other occurrences as well.

    for (auto i : dims_b) if_dot_dims_b[i] = true;

    std::vector<int32_t> broadcast_dims_a, broadcast_dims_b;
    for (int64_t i = 0; i < dims_a.size(); i++) {
      int64_t size_a = a->shape()->At(dims_a[i]);
      int64_t size_b = b->shape()->At(dims_b[i]);
      if (size_a == 1) {

[Review comment] Should this condition be tightened a bit, e.g. if (size_a == 1 && size_b > 1)? The branch below is similar.

        broadcast_dims_b.emplace_back(dims_b[i]);
      } else if (size_b == 1) {
        broadcast_dims_a.emplace_back(dims_a[i]);
      } else {
        CHECK_EQ_OR_RETURN(size_a, size_b) << "The corresponding dim must be equal, got " << size_a
                                           << " in tensor a and " << size_b << " in tensor b";

[Review comment] Same as above: standardize the wording of this error message and add tests that cover it.

      }
    }
    auto reduced_a = a;
    auto reduced_b = b;
    if (!broadcast_dims_a.empty())
      reduced_a = JUST(functional::ReduceSum(a, broadcast_dims_a, true));
    if (!broadcast_dims_b.empty())
      reduced_b = JUST(functional::ReduceSum(b, broadcast_dims_b, true));

    int64_t rsize_a = 1, rsize_b = 1;

[Review comment] Let's unify these names in camelCase; the naming should not be this casual.

    std::vector<int64_t> rshape_a, rshape_b;
    std::vector<int32_t> pa, pb;

[Review comment] If a vector's length is known in advance, you can construct it with that size directly instead of calling reserve afterwards.

    pa.reserve(a->shape()->NumAxes());
    pb.reserve(b->shape()->NumAxes());
    std::vector<int64_t> non_dot_dims;

[Review comment] The types of the vectors on lines 361 and 357 can be unified.

    for (int64_t i = 0; i < a->shape()->NumAxes(); i++) {
      if (!if_dot_dims_a[i]) {
        non_dot_dims.emplace_back(a->shape()->At(i));
        pa.emplace_back(i);
        rsize_a *= reduced_a->shape()->At(i);
        rshape_a.emplace_back(reduced_a->shape()->At(i));
      }
    }

    for (auto i : dims_a) pa.emplace_back(i);

[Review comment] Similar issue in these places: auto is being overused.

    for (auto i : dims_b) pb.emplace_back(i);

    for (int64_t i = 0; i < b->shape()->NumAxes(); i++) {
      if (!if_dot_dims_b[i]) {
        non_dot_dims.emplace_back(b->shape()->At(i));
        pb.emplace_back(i);
        rsize_b *= reduced_b->shape()->At(i);
        rshape_b.emplace_back(reduced_b->shape()->At(i));
      }
    }
    rshape_a.insert(rshape_a.end(), rshape_b.begin(), rshape_b.end());

    int64_t dshape = 1;
    for (auto i : dims_a) dshape *= reduced_a->shape()->At(i);
    auto permute_a =
        JUST(Reshape(JUST(Permute(reduced_a, pa)), Shape(DimVector({rsize_a, dshape}))));
    auto permute_b =
        JUST(Reshape(JUST(Permute(reduced_b, pb)), Shape(DimVector({dshape, rsize_b}))));

    return Reshape(JUST(functional::MatMul(permute_a, permute_b, false, false, 1.0)),
                   Shape(DimVector({rshape_a.begin(), rshape_a.end()})));
  }
};

class FusedMLPFunctor {
 public:
  FusedMLPFunctor() {

@@ -2721,6 +2808,7 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
  m.add_functor<impl::DeConv3dFunctor>("Deconv3d");
  m.add_functor<impl::MatMulFunctor>("MatMul");
  m.add_functor<impl::BatchMatMulFunctor>("BatchMatMul");
  m.add_functor<impl::TensorDotFunctor>("TensorDot");
  m.add_functor<impl::FusedMLPFunctor>("FusedMLP");
  m.add_functor<impl::LayerNormFunctor>("LayerNorm");
  m.add_functor<impl::LayerNormAffineFunctor>("LayerNormAffine");
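
In short, the functor above reduces tensordot to a single matrix multiplication: size-1 contraction dims are summed out of the other operand, the remaining (free) dims of a are permuted to the front and those of b to the back, both operands are flattened to 2-D, multiplied once with MatMul, and the result is reshaped to the concatenated free dims. A rough NumPy sketch of that strategy (an illustration only, not OneFlow code; the size-1/ReduceSum broadcast handling is omitted and the names here are made up):

```python
import numpy as np

def tensordot_via_matmul(a, b, dims_a, dims_b):
    # Free (non-contracted) dims keep their original order.
    free_a = [d for d in range(a.ndim) if d not in dims_a]
    free_b = [d for d in range(b.ndim) if d not in dims_b]
    # a -> (free_a..., dims_a...), b -> (dims_b..., free_b...), mirroring pa / pb above.
    pa = a.transpose(free_a + list(dims_a))
    pb = b.transpose(list(dims_b) + free_b)
    m = int(np.prod([a.shape[d] for d in free_a], dtype=np.int64))  # rsize_a
    k = int(np.prod([a.shape[d] for d in dims_a], dtype=np.int64))  # dshape
    n = int(np.prod([b.shape[d] for d in free_b], dtype=np.int64))  # rsize_b
    # Flatten to 2-D, contract with one matmul, then restore the free dims.
    out = pa.reshape(m, k) @ pb.reshape(k, n)
    return out.reshape([a.shape[d] for d in free_a] + [b.shape[d] for d in free_b])

# e.g. a: (3, 4, 5), b: (4, 5, 6), dims_a=(1, 2), dims_b=(0, 1)  ->  result shape (3, 6)
```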

@@ -211,7 +211,7 @@ class OneFlow_NormalizationAddReluBaseOp : OneFlow_BaseOp<"normalization_add_rel
#endif // GET_ONEFLOW_BASE_OP_DEFINITIONS

// Group: BINARY
-// bias_add, cast_like, celu_grad, diag_grad, diagonal_grad, dot, dropout_grad, elementwise_maximum, elementwise_minimum, elu_grad, floordiv, gelu_grad, grid_sample, hardsigmoid_grad, hardshrink_grad, hardswish_grad, l1_l2_regularize_gradient, leaky_relu_grad, masked_fill, mish_grad, multiply, narrow_grad, pow, prelu, relu_grad, selu_grad, sigmoid_grad, silu_grad, softshrink_grad, threshold_grad, tf_prelu, unfold_tensor_grad, xdivy, xlogy
+// bias_add, cast_like, celu_grad, diag_grad, diagonal_grad, dot, dropout_grad, elementwise_maximum, elementwise_minimum, elu_grad, floordiv, gelu_grad, grid_sample, hardsigmoid_grad, hardshrink_grad, hardswish_grad, l1_l2_regularize_gradient, leaky_relu_grad, masked_fill, mish_grad, multiply, narrow_grad, pow, prelu, relu_grad, selu_grad, sigmoid_grad, silu_grad, softshrink_grad, tensor_dot, threshold_grad, tf_prelu, unfold_tensor_grad, xdivy, xlogy
// Total: 34

#ifdef GET_ONEFLOW_BINARY_OP_DEFINITIONS

@@ -659,6 +659,24 @@ def OneFlow_SiluGradOp : OneFlow_BaseOp<"silu_grad", [NoSideEffect, DeclareOpInt
  let has_data_type_infer_fn = 1;
}

def OneFlow_TensorDotOp : OneFlow_BaseOp<"tensordot", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {

[Review comment] Ops that are not constructed through OpBuilder do not need to be updated here.
[Author reply] Fixed.

  let input = (ins
    OneFlow_Tensor:$x,
    OneFlow_Tensor:$y
  );
  let output = (outs
    OneFlow_Tensor:$z
  );
  let attrs = (ins
    DefaultValuedAttr<SI32Attr, "0">:$axis_a,
    DefaultValuedAttr<SI32Attr, "0">:$axis_b
  );
  let has_logical_tensor_desc_infer_fn = 1;
  let has_physical_tensor_desc_infer_fn = 1;
  let has_get_sbp_fn = 1;
  let has_data_type_infer_fn = 1;
}

def OneFlow_ThresholdGradOp : OneFlow_BaseOp<"threshold_grad", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
  let input = (ins
    OneFlow_Tensor:$x,

@@ -364,6 +364,7 @@ def atexit_hook(hook):
from oneflow.nn.modules.global_cast import to_local_op as to_local
from oneflow.nn.modules.where import where_op as where
from oneflow.nn.modules.scatter import *
from oneflow.nn.functional import tensordot

[Review comment] This feels odd: a flow.nn.functional.tensordot interface is not needed, so why go through this import chain?

from oneflow.ops.stateful_ops import StatefulOp as stateful_op
from oneflow.ops.initializer_util import constant_initializer
from oneflow.ops.initializer_util import glorot_normal_initializer

@@ -68,3 +68,4 @@
from .argsort import *
from .module import *
from .util_ops import *
from .tensordot import *

@@ -0,0 +1,54 @@
""" | ||
Copyright 2020 The OneFlow Authors. All rights reserved. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
""" | ||
import oneflow | ||
from oneflow.framework.docstr.utils import add_docstr | ||
|
||
add_docstr( | ||
oneflow.tensordot, | ||
r""" | ||
tensordot(a, b, dims=Union[int, Tensor, Tuple]) -> Tensor | ||

[Review comment] This does not follow the docstring conventions either; take a look at how other operators that were written with reference to the PyTorch docs format theirs.

    Compute tensor dot along given dimensions.

    Given two tensors a and b, and dims, which contains two lists of dim indices, `tensordot`
    traverses the two lists and calculates the tensor dot along every dim pair.

    Args:
        a: The input tensor to compute tensordot
        b: The input tensor to compute tensordot
        dims: int or array-like

    Returns:
        Oneflow.Tensor: The result tensor

    For example:

    .. code-block:: python

        >>> import oneflow as flow
        >>> a = flow.randn(3, 4, 5)
        >>> b = flow.randn(4, 5, 6)
        >>> flow.tensordot(a, b, dims=2).shape
        oneflow.Size([3, 6])
        >>> b = flow.randn(5, 6, 7)
        >>> flow.tensordot(a, b, dims=1).shape
        oneflow.Size([12, 42])
        >>> b = flow.randn(3, 4, 7)
        >>> flow.tensordot(a, b, dims=[[0, 1], [0, 1]]).shape
        oneflow.Size([5, 7])

    """,
)
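
For reference, the dims=2 call in the first example contracts the last two axes of a with the first two axes of b; written out (this is the standard tensordot identity, stated here for clarity rather than taken from the PR):

$$\mathrm{out}_{i,j} \;=\; \sum_{k}\sum_{l} a_{i,k,l}\, b_{k,l,j}, \qquad a \in \mathbb{R}^{3\times 4\times 5},\; b \in \mathbb{R}^{4\times 5\times 6},\; \mathrm{out} \in \mathbb{R}^{3\times 6}.$$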

@@ -0,0 +1,40 @@
""" | ||
Copyright 2020 The OneFlow Authors. All rights reserved. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
""" | ||
import oneflow | ||
from typing import Union, List, Tuple | ||
|
||
def tensordot(a, b, dims: Union[int, List[List[int]]] = 2): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里输入的Union好像没有考虑Tensor,而且这个python文件并不是必要的,完全可以把这个Python wrap去掉效率会更好一些。可以参考一下flow.split是如何去掉python wrap |
||
if not isinstance(dims, (oneflow._oneflow_internal.Tensor, int, list, tuple)): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
raise TypeError( | ||
f"oneflow.tensordot expects dims to be one of oneflow.Tensor, int, List[List[int]] or Tuple[List[int], List[int]], but got {type(dims)}" | ||
) | ||
|
||
if isinstance(dims, int): | ||
assert dims >= 0 and dims <= min(a.dim(), b.dim()) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里需要给一点报错信息 |
||
dim_a = list(range(a.dim() - dims, a.dim())) | ||
dim_b = list(range(dims)) | ||
|
||
elif isinstance(dims, (list, tuple, oneflow._oneflow_internal.Tensor)): | ||
assert len(dims) == 2 | ||
dim_a = list(dims[0]) | ||
dim_b = list(dims[1]) | ||
assert ( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 同上 |
||
isinstance(dim_a[0], int) | ||
and isinstance(dim_b[0], int) | ||
and len(dim_a) == len(dim_b) | ||
) | ||
|
||
return oneflow._C.tensordot(a, b, dim_a, dim_b) |
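
To make the int branch above concrete, here is a hand trace for dims=2 with two 3-D tensors (a worked illustration, not code from the PR):

```python
# Hypothetical trace of the int branch in tensordot() above, for dims=2 and a.dim() == b.dim() == 3.
dims = 2
a_ndim, b_ndim = 3, 3
dim_a = list(range(a_ndim - dims, a_ndim))  # [1, 2]: the last `dims` axes of a
dim_b = list(range(dims))                   # [0, 1]: the first `dims` axes of b
# oneflow._C.tensordot(a, b, dim_a, dim_b) then contracts a's axes 1 and 2 with b's axes 0 and 1.
```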

@@ -0,0 +1,64 @@
""" | ||||||
Copyright 2020 The OneFlow Authors. All rights reserved. | ||||||
|
||||||
Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
you may not use this file except in compliance with the License. | ||||||
You may obtain a copy of the License at | ||||||
|
||||||
http://www.apache.org/licenses/LICENSE-2.0 | ||||||
|
||||||
Unless required by applicable law or agreed to in writing, software | ||||||
distributed under the License is distributed on an "AS IS" BASIS, | ||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
See the License for the specific language governing permissions and | ||||||
limitations under the License. | ||||||
""" | ||||||
import unittest | ||||||
import numpy as np | ||||||
import oneflow as flow | ||||||
import oneflow.unittest | ||||||
from oneflow.test_utils.automated_test_util import * | ||||||
|
||||||
|
||||||
@flow.unittest.skip_unless_1n1d() | ||||||
class TestTensordot(flow.unittest.TestCase): | ||||||
zhongshsh marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
@autotest(check_graph=True) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
def test_tensordot_intdim(test_case): | ||||||
device = random_device() | ||||||
dims = random() | ||||||
dims_list = [random().to(int).value() for i in range(dims.to(int).value() + 3)] | ||||||
x = random_tensor( | ||||||
ndim=3, | ||||||
dim0=dims_list[0], | ||||||
dim1=dims_list[1], | ||||||
dim2=dims_list[2], | ||||||
).to(device) | ||||||
y = random_tensor( | ||||||
ndim=3, | ||||||
dim0=dims_list[0 + dims.to(int).value()], | ||||||
dim1=dims_list[1 + dims.to(int).value()], | ||||||
dim2=dims_list[2 + dims.to(int).value()], | ||||||
).to(device) | ||||||
|
||||||
z = torch.tensordot(x, y, dims=3 - dims.to(int).value()) | ||||||
return z | ||||||
|
||||||
@autotest(check_graph=True, n=1) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
def test_tensordot_tensordim(test_case): | ||||||
device = random_device() | ||||||
x = random_tensor(4, 1,3,2,5).to(device) | ||||||
y = random_tensor(4, 4,2,3,5).to(device) | ||||||
z = torch.tensordot(x, y, dims=[[1,2,0],[2,1,0]]) | ||||||
return z | ||||||
|
||||||
@autotest(check_graph=True) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
def test_tensordot_broadcast(test_case): | ||||||
device = random_device() | ||||||
x = random_tensor(4, 1,1,1,1).to(device) | ||||||
y = random_tensor(4, 2,3,4,5).to(device) | ||||||
z = torch.tensordot(x, y, dims=random(high=5).to(int).value()) | ||||||
return z | ||||||
|
||||||
|
||||||
if __name__ == "__main__": | ||||||
unittest.main() |
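
Beyond the autotest comparisons above, a quick eager sanity check of the dims=2 case against NumPy could look like this (not part of the PR; it assumes eager tensors expose .numpy() and uses np.tensordot as the reference):

```python
import numpy as np
import oneflow as flow

a = flow.randn(3, 4, 5)
b = flow.randn(4, 5, 6)
out = flow.tensordot(a, b, dims=2)                 # contract last 2 axes of a with first 2 of b
ref = np.tensordot(a.numpy(), b.numpy(), axes=2)   # NumPy reference result
print(out.shape)                                   # oneflow.Size([3, 6])
print(np.allclose(out.numpy(), ref, atol=1e-5))    # expect True
```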

[Review comment] dim_a.size() and dims_b.size() have to be equal. Could that check be moved up front, and the for loop written after it? Would that be reasonable?
[Author reply] Do you mean checking that the two are the same length first, and only then processing them?
[Review comment] Yes.
[Author reply] I will change it. That said, if the equal-length check fails the operator simply reports an error, so would it really affect efficiency? I'll change it anyway; it reads more smoothly that way.
[Author reply] Done.