Add oneflow.tensordot interface #7968

Merged
merged 155 commits into master from feat-add_tensordot_op on May 25, 2022
Changes from 14 commits
Commits (155)
49120fe
add tensordot functor
marigoold Apr 6, 2022
7b9fada
remove debug code
marigoold Apr 6, 2022
a9dce8f
add python wrapper
marigoold Apr 6, 2022
517f8e6
add docs
marigoold Apr 6, 2022
9d2ee97
fix bugs
marigoold Apr 7, 2022
2432572
fix bugs and decrease ndim in unittest from 5 to 4 because of kSliceM…
marigoold Apr 7, 2022
962fa05
Merge branch 'master' into feat-add_tensordot_op
marigoold Apr 7, 2022
7e81b00
add broadcast tensordot test, remove redundant code
marigoold Apr 7, 2022
a6f02ce
Merge branch 'master' into feat-add_tensordot_op
marigoold Apr 8, 2022
26ef1cd
Merge branch 'master' into feat-add_tensordot_op
marigoold Apr 11, 2022
cca8a7f
Merge branch 'master' into feat-add_tensordot_op
marigoold Apr 11, 2022
220c4b7
Merge branch 'master' into feat-add_tensordot_op
marigoold Apr 18, 2022
18ef83b
Merge branch 'master' into feat-add_tensordot_op
marigoold Apr 20, 2022
a3abefe
add ir op
marigoold Apr 20, 2022
70be1f4
remove redundant code and format code
marigoold Apr 20, 2022
5f8aaeb
add negative dimension support, add related unittest
marigoold Apr 20, 2022
d27873f
add test for tensor/tuple dims, refine python wrapper
marigoold Apr 21, 2022
8071a65
add functor for integer dims, refine unittest
marigoold Apr 21, 2022
3bd49ec
add tuple dims for unittest
marigoold Apr 21, 2022
ced6c78
replace oneflow._internel.Tensor with oneflow.Tensor
marigoold Apr 21, 2022
5f3bb2c
format code
marigoold Apr 21, 2022
71305b5
add support for single element tensor, add related unittest
marigoold Apr 21, 2022
6b07e23
add assert information
marigoold Apr 21, 2022
4f25207
add more docs
marigoold Apr 21, 2022
558cff7
add reference from numpy
marigoold Apr 21, 2022
dbbebd3
Merge branch 'master' into feat-add_tensordot_op
marigoold Apr 21, 2022
2fd14df
Merge branch 'master' into feat-add_tensordot_op
marigoold Apr 22, 2022
16f73b5
fix error message, fix typo in docs
marigoold Apr 22, 2022
18e6f73
Merge branch 'master' into feat-add_tensordot_op
marigoold Apr 22, 2022
972971e
refine code, format code
marigoold Apr 22, 2022
71a6bb9
fix typo
marigoold Apr 22, 2022
7cf5e0c
refine code
marigoold Apr 22, 2022
c928ef0
add eager global test, fix bug of docstr, add checklist, fix bug of l…
marigoold Apr 22, 2022
b6a9dd7
replace print warning with warnings.warn()
marigoold Apr 22, 2022
991683c
Merge branch 'master' into feat-add_tensordot_op
marigoold Apr 23, 2022
f175d29
Merge branch 'master' into feat-add_tensordot_op
marigoold Apr 23, 2022
756d93c
Merge branch 'master' into feat-add_tensordot_op
marigoold Apr 24, 2022
4869ecc
Merge branch 'master' into feat-add_tensordot_op
marigoold Apr 24, 2022
5c6b0f8
format code
marigoold Apr 25, 2022
8aea698
Merge branch 'feat-add_tensordot_op' of github.com:Oneflow-Inc/oneflo…
marigoold Apr 25, 2022
2d5e2bc
Merge branch 'master' into feat-add_tensordot_op
marigoold Apr 25, 2022
819d605
Merge branch 'master' into feat-add_tensordot_op
marigoold Apr 26, 2022
2b3e3d8
refine reminding information
marigoold Apr 26, 2022
ede5583
modify int64list to int32list in interface
marigoold Apr 26, 2022
d1cdce6
change some int64 in dims related param to int32
marigoold Apr 26, 2022
2971dae
Merge branch 'feat-add_tensordot_op' of github.com:Oneflow-Inc/oneflo…
marigoold Apr 26, 2022
247f05e
match error message with torch
marigoold Apr 26, 2022
6d640c7
format code, add check for recurring dim
marigoold Apr 26, 2022
b0077ac
add error message unittest
marigoold Apr 26, 2022
33c9453
rename variables of dims_a/b to dot_dims_a/b
marigoold Apr 26, 2022
3739493
Merge branch 'master' into feat-add_tensordot_op
marigoold Apr 27, 2022
a7fbd5e
refine code
marigoold Apr 27, 2022
ddf7e60
Merge branch 'feat-add_tensordot_op' of github.com:Oneflow-Inc/oneflo…
marigoold Apr 27, 2022
650729e
Merge branch 'master' into feat-add_tensordot_op
marigoold Apr 28, 2022
5036097
Merge branch 'master' into feat-add_tensordot_op
marigoold Apr 29, 2022
7445b4a
Merge branch 'master' into feat-add_tensordot_op
marigoold Apr 29, 2022
486c08b
Merge branch 'master' into feat-add_tensordot_op
marigoold May 1, 2022
48532d8
rename variables and format code
marigoold May 4, 2022
663a500
Merge branch 'feat-add_tensordot_op' of github.com:Oneflow-Inc/oneflo…
marigoold May 4, 2022
f497880
Merge branch 'master' into feat-add_tensordot_op
mergify[bot] May 5, 2022
b1f52fe
refine code, adjust conditional judgement order
marigoold May 5, 2022
256d28b
Merge branch 'feat-add_tensordot_op' of github.com:Oneflow-Inc/oneflo…
marigoold May 5, 2022
48655ba
merge some checks (check lt and ge) into check_or_return
marigoold May 5, 2022
02972d7
Merge branch 'master' into feat-add_tensordot_op
mergify[bot] May 5, 2022
958ccf8
Merge branch 'master' into feat-add_tensordot_op
mergify[bot] May 5, 2022
1537951
Merge branch 'master' into feat-add_tensordot_op
mergify[bot] May 5, 2022
a6ced7f
Merge branch 'master' into feat-add_tensordot_op
mergify[bot] May 5, 2022
89c04aa
Merge branch 'master' into feat-add_tensordot_op
mergify[bot] May 6, 2022
e8af06c
Merge branch 'master' into feat-add_tensordot_op
mergify[bot] May 6, 2022
bd2a82f
Merge branch 'master' into feat-add_tensordot_op
mergify[bot] May 6, 2022
3f38b77
Merge branch 'master' into feat-add_tensordot_op
mergify[bot] May 6, 2022
f355c87
Merge branch 'master' into feat-add_tensordot_op
mergify[bot] May 6, 2022
e7866f4
Merge branch 'master' into feat-add_tensordot_op
mergify[bot] May 7, 2022
aa746a1
Merge branch 'master' into feat-add_tensordot_op
mergify[bot] May 7, 2022
19ce92d
Merge branch 'master' into feat-add_tensordot_op
mergify[bot] May 7, 2022
168b255
Merge branch 'master' into feat-add_tensordot_op
mergify[bot] May 7, 2022
c7d6eae
Merge branch 'master' into feat-add_tensordot_op
mergify[bot] May 7, 2022
8a67d56
Merge branch 'master' into feat-add_tensordot_op
mergify[bot] May 8, 2022
22f3e21
Merge branch 'master' into feat-add_tensordot_op
mergify[bot] May 8, 2022
5a9ec24
Merge branch 'master' into feat-add_tensordot_op
mergify[bot] May 8, 2022
98c9fa8
Merge branch 'master' into feat-add_tensordot_op
mergify[bot] May 8, 2022
e0df9ad
Merge branch 'master' into feat-add_tensordot_op
mergify[bot] May 8, 2022
33b5de7
Merge branch 'master' into feat-add_tensordot_op
mergify[bot] May 9, 2022
4b8c524
Merge branch 'master' into feat-add_tensordot_op
mergify[bot] May 9, 2022
a25138c
Merge branch 'master' into feat-add_tensordot_op
mergify[bot] May 9, 2022
231d277
Merge branch 'master' into feat-add_tensordot_op
mergify[bot] May 9, 2022
dd631c1
Merge branch 'master' into feat-add_tensordot_op
mergify[bot] May 10, 2022
fa5daba
Merge branch 'master' into feat-add_tensordot_op
mergify[bot] May 10, 2022
58b8b47
Merge branch 'master' into feat-add_tensordot_op
mergify[bot] May 10, 2022
e885bfd
Merge branch 'master' into feat-add_tensordot_op
mergify[bot] May 10, 2022
59bf754
Merge branch 'master' into feat-add_tensordot_op
mergify[bot] May 10, 2022
bf20b05
Merge branch 'master' into feat-add_tensordot_op
mergify[bot] May 10, 2022
9d0f40f
Merge branch 'master' into feat-add_tensordot_op
mergify[bot] May 11, 2022
7a74f4d
Merge branch 'master' into feat-add_tensordot_op
mergify[bot] May 11, 2022
a2cee98
Merge branch 'master' into feat-add_tensordot_op
mergify[bot] May 11, 2022
eb58d8c
Merge branch 'master' into feat-add_tensordot_op
mergify[bot] May 11, 2022
1ff7c81
Merge branch 'master' into feat-add_tensordot_op
mergify[bot] May 11, 2022
143b410
Merge branch 'master' into feat-add_tensordot_op
mergify[bot] May 11, 2022
7572c88
Merge branch 'master' into feat-add_tensordot_op
mergify[bot] May 12, 2022
d535414
Merge branch 'master' into feat-add_tensordot_op
mergify[bot] May 12, 2022
af8cd54
Merge branch 'master' into feat-add_tensordot_op
mergify[bot] May 12, 2022
712813f
Merge branch 'master' into feat-add_tensordot_op
mergify[bot] May 12, 2022
5ccd4f5
Merge branch 'master' into feat-add_tensordot_op
mergify[bot] May 12, 2022
aaefd02
Merge branch 'master' into feat-add_tensordot_op
mergify[bot] May 13, 2022
61520ed
Merge branch 'master' into feat-add_tensordot_op
mergify[bot] May 13, 2022
3d0ad3e
Merge branch 'master' into feat-add_tensordot_op
mergify[bot] May 13, 2022
09b9a31
Merge branch 'master' into feat-add_tensordot_op
mergify[bot] May 13, 2022
48810f6
Merge branch 'master' into feat-add_tensordot_op
mergify[bot] May 13, 2022
033233e
Merge branch 'master' into feat-add_tensordot_op
mergify[bot] May 14, 2022
874783d
Merge branch 'master' into feat-add_tensordot_op
mergify[bot] May 14, 2022
43e92e9
Merge branch 'master' into feat-add_tensordot_op
mergify[bot] May 14, 2022
addcdde
Merge branch 'master' into feat-add_tensordot_op
mergify[bot] May 14, 2022
9fb5b26
Merge branch 'master' into feat-add_tensordot_op
mergify[bot] May 14, 2022
663b0fe
Merge branch 'master' into feat-add_tensordot_op
mergify[bot] May 15, 2022
1529761
Merge branch 'master' into feat-add_tensordot_op
mergify[bot] May 15, 2022
4b0b0a8
Merge branch 'master' into feat-add_tensordot_op
mergify[bot] May 15, 2022
3af98c1
Merge branch 'master' into feat-add_tensordot_op
mergify[bot] May 15, 2022
9e668ef
Merge branch 'master' into feat-add_tensordot_op
mergify[bot] May 15, 2022
198d72d
Merge branch 'master' into feat-add_tensordot_op
mergify[bot] May 16, 2022
d9e96b2
Merge branch 'master' into feat-add_tensordot_op
mergify[bot] May 16, 2022
dff3d56
Merge branch 'master' into feat-add_tensordot_op
mergify[bot] May 16, 2022
1730f23
Merge branch 'master' into feat-add_tensordot_op
mergify[bot] May 16, 2022
ef2f618
Merge branch 'master' into feat-add_tensordot_op
mergify[bot] May 17, 2022
4883fb9
Merge branch 'master' into feat-add_tensordot_op
mergify[bot] May 17, 2022
c5307f3
Merge branch 'master' into feat-add_tensordot_op
mergify[bot] May 17, 2022
eb95035
Merge branch 'master' into feat-add_tensordot_op
mergify[bot] May 17, 2022
a2c6e47
Merge branch 'master' into feat-add_tensordot_op
mergify[bot] May 17, 2022
94152d5
Merge branch 'master' into feat-add_tensordot_op
mergify[bot] May 18, 2022
32e2196
Merge branch 'master' into feat-add_tensordot_op
mergify[bot] May 18, 2022
e0178b2
Merge branch 'master' into feat-add_tensordot_op
mergify[bot] May 18, 2022
d2569e8
Merge branch 'master' into feat-add_tensordot_op
mergify[bot] May 18, 2022
157de26
Merge branch 'master' into feat-add_tensordot_op
mergify[bot] May 18, 2022
fb305a1
Merge branch 'master' into feat-add_tensordot_op
mergify[bot] May 18, 2022
cc9728f
Merge branch 'master' into feat-add_tensordot_op
mergify[bot] May 18, 2022
8ea6194
Merge branch 'master' into feat-add_tensordot_op
mergify[bot] May 18, 2022
47280df
Merge branch 'master' into feat-add_tensordot_op
mergify[bot] May 18, 2022
dfe09ca
Merge branch 'master' into feat-add_tensordot_op
mergify[bot] May 18, 2022
e10f7b6
Merge branch 'master' into feat-add_tensordot_op
mergify[bot] May 18, 2022
538b10c
Merge branch 'master' into feat-add_tensordot_op
mergify[bot] May 19, 2022
76d2c63
Merge branch 'master' into feat-add_tensordot_op
mergify[bot] May 19, 2022
390d207
Merge branch 'master' into feat-add_tensordot_op
mergify[bot] May 19, 2022
55d4f60
Merge branch 'master' into feat-add_tensordot_op
mergify[bot] May 19, 2022
53db592
Merge branch 'master' into feat-add_tensordot_op
mergify[bot] May 19, 2022
452f4ee
auto format by CI
oneflow-ci-bot May 19, 2022
9ab61be
Merge branch 'master' into feat-add_tensordot_op
mergify[bot] May 19, 2022
cc5ba4a
Merge branch 'master' into feat-add_tensordot_op
mergify[bot] May 20, 2022
0b7f042
Merge branch 'master' into feat-add_tensordot_op
mergify[bot] May 20, 2022
ed6158a
Merge branch 'master' into feat-add_tensordot_op
mergify[bot] May 20, 2022
6f6f12a
Merge branch 'master' into feat-add_tensordot_op
mergify[bot] May 20, 2022
9ffd573
Merge branch 'master' into feat-add_tensordot_op
mergify[bot] May 20, 2022
83c4f7d
Merge branch 'master' into feat-add_tensordot_op
marigoold May 25, 2022
33e0d0a
reduce dims in unittest to avoid oom
marigoold May 25, 2022
b99408b
Merge branch 'master' into feat-add_tensordot_op
mergify[bot] May 25, 2022
bed895d
reduce dims in unittest to avoid oom
marigoold May 25, 2022
e26efbb
Merge branch 'feat-add_tensordot_op' of https://github.com/Oneflow-In…
marigoold May 25, 2022
1 change: 1 addition & 0 deletions docs/source/oneflow.rst
@@ -148,6 +148,7 @@ oneflow
tan,
tanh,
tensor,
tensordot,
tile,
transpose,
t,
5 changes: 5 additions & 0 deletions oneflow/core/functional/functional_api.yaml
@@ -902,6 +902,11 @@
Double alpha=1.0) => BatchMatMul"
bind_python: True

- name: "tensordot"
signature:
"Tensor (Tensor a, Tensor b, Int64List dims_a, Int64List dims_b) => TensorDot"
bind_python: True

- name: "l1_loss"
signature: "Tensor(Tensor input, Tensor target, String reduction) => L1Loss"
bind_python: True
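For context, bind_python: True exposes the functor to Python as oneflow._C.tensordot, taking the two dim lists separately (the single dims argument is normalized in the Python wrapper further down). A minimal sketch of calling the binding directly, assuming a OneFlow build that includes this PR:

import oneflow as flow

a = flow.randn(3, 4, 5)
b = flow.randn(4, 5, 6)
# Contract dim 1 of a with dim 0 of b, and dim 2 of a with dim 1 of b.
z = flow._C.tensordot(a, b, [1, 2], [0, 1])
print(z.shape)  # expected: oneflow.Size([3, 6])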
88 changes: 88 additions & 0 deletions oneflow/core/functional/impl/nn_functor.cpp
@@ -264,6 +264,93 @@ class BatchMatMulFunctor {
std::shared_ptr<OpExpr> batch_matmul_op_;
};

class TensorDotFunctor {
public:
Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& a, const std::shared_ptr<Tensor>& b,
const std::vector<int64_t>& dims_a,
const std::vector<int64_t>& dims_b) const {
CHECK_EQ_OR_RETURN(dims_a.size(), dims_b.size())
<< "dims1 and dims2 must have same size, got " << dims_a.size() << " and " << dims_b.size();
Contributor:
dim_a.size() and dims_b.size() must be equal. Could this check be moved up front, followed by a plain for loop? Would that be reasonable?

Contributor Author:
Do you mean validating that the two lengths are equal first, before doing any processing?

Contributor:
Yes.

Contributor Author:
I'll change it. That said, if the length check fails the operator just errors out, so there shouldn't be any efficiency concern? I'll change it anyway though; it reads more smoothly that way.

Contributor Author:
Done.

Contributor:
Here, dim_a and dim_b have not been declared yet.

Contributor Author:
Sorry, I changed it in a hurry. Fixed now.


if (dims_a.empty() && dims_b.empty()) {
std::vector<int64_t> shape_sum;
Contributor:
Here you can construct the DimVector directly and just update its values; there is no need to build a std::vector first and construct a DimVector from it later:

Suggested change
std::vector<int64_t> shape_sum;
DimVector shape_sum(a->shape()->NumAxes() + b->shape()->NumAxes());
for() {shape_sum[i] = xxx; }

shape_sum.reserve(a->shape()->NumAxes() + b->shape()->NumAxes());
for (int64_t i = 0; i < a->shape()->NumAxes(); i++) {
shape_sum.emplace_back(a->shape()->At(i));
}
for (int64_t i = 0; i < b->shape()->NumAxes(); i++) {
shape_sum.emplace_back(b->shape()->At(i));
}
auto reshape_a = JUST(Reshape(a, Shape(DimVector{-1, 1})));
auto reshape_b = JUST(Reshape(b, Shape(DimVector{1, -1})));
return JUST(Reshape(JUST(functional::MatMul(reshape_a, reshape_b, false, false, 1.0)),
Shape(DimVector(shape_sum.begin(), shape_sum.end()))));
}
std::vector<bool> if_dot_dims_a(a->shape()->NumAxes(), false);
std::vector<bool> if_dot_dims_b(b->shape()->NumAxes(), false);
for (auto i : dims_a) if_dot_dims_a[i] = true;
Contributor:
Either make the concrete type behind auto clear, or use const auto; otherwise do not use auto. Please check the other occurrences too.

for (auto i : dims_b) if_dot_dims_b[i] = true;

std::vector<int32_t> broadcast_dims_a, broadcast_dims_b;
for (int64_t i = 0; i < dims_a.size(); i++) {
int64_t size_a = a->shape()->At(dims_a[i]);
int64_t size_b = b->shape()->At(dims_b[i]);
if (size_a == 1) {
Contributor:
Should this condition be tightened to if (size_a == 1 && size_b > 1)? Likewise for the branch below.

broadcast_dims_b.emplace_back(dims_b[i]);
} else if (size_b == 1) {
broadcast_dims_a.emplace_back(dims_a[i]);
} else {
CHECK_EQ_OR_RETURN(size_a, size_b) << "The corresponding dim must be equal, got " << size_a
Contributor:
Same here: the error message needs to follow the standard format, and a test should be added for it.

<< " in tensor a and " << size_b << " in tensor b";
}
}
auto reduced_a = a;
auto reduced_b = b;
if (!broadcast_dims_a.empty())
reduced_a = JUST(functional::ReduceSum(a, broadcast_dims_a, true));
if (!broadcast_dims_b.empty())
reduced_b = JUST(functional::ReduceSum(b, broadcast_dims_b, true));

int64_t rsize_a = 1, rsize_b = 1;
Contributor:
Use consistent camelCase for these names; they are too casual as written.

std::vector<int64_t> rshape_a, rshape_b;
std::vector<int32_t> pa, pb;
Contributor:
If a vector's length is known in advance, you can construct it at that size directly instead of calling reserve afterwards.

pa.reserve(a->shape()->NumAxes());
pb.reserve(b->shape()->NumAxes());
std::vector<int64_t> non_dot_dims;
Contributor:
The vector types on lines 361 and 357 can be unified.

for (int64_t i = 0; i < a->shape()->NumAxes(); i++) {
if (!if_dot_dims_a[i]) {
non_dot_dims.emplace_back(a->shape()->At(i));
pa.emplace_back(i);
rsize_a *= reduced_a->shape()->At(i);
rshape_a.emplace_back(reduced_a->shape()->At(i));
}
}

for (auto i : dims_a) pa.emplace_back(i);
Contributor:
Same kind of issue in these places: auto is overused.

for (auto i : dims_b) pb.emplace_back(i);

for (int64_t i = 0; i < b->shape()->NumAxes(); i++) {
if (!if_dot_dims_b[i]) {
non_dot_dims.emplace_back(b->shape()->At(i));
pb.emplace_back(i);
rsize_b *= reduced_b->shape()->At(i);
rshape_b.emplace_back(reduced_b->shape()->At(i));
}
}
rshape_a.insert(rshape_a.end(), rshape_b.begin(), rshape_b.end());

int64_t dshape = 1;
for (auto i : dims_a) dshape *= reduced_a->shape()->At(i);
auto permute_a =
JUST(Reshape(JUST(Permute(reduced_a, pa)), Shape(DimVector({rsize_a, dshape}))));
auto permute_b =
JUST(Reshape(JUST(Permute(reduced_b, pb)), Shape(DimVector({dshape, rsize_b}))));

return Reshape(JUST(functional::MatMul(permute_a, permute_b, false, false, 1.0)),
Shape(DimVector({rshape_a.begin(), rshape_a.end()})));
}
};
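The main path above is a permute-reshape-matmul: the non-contracted ("free") dims of each tensor keep their original order, a is permuted to (free dims..., dot dims) and b to (dot dims..., free dims), both are flattened to 2-D, multiplied, and the result is reshaped to the concatenated free shapes. A NumPy sketch of that idea (ignoring the empty-dims outer-product branch and the size-1 broadcast/ReduceSum handling; tensordot_sketch is a hypothetical name, not part of this PR):

import numpy as np

def tensordot_sketch(a, b, dims_a, dims_b):
    free_a = [d for d in range(a.ndim) if d not in dims_a]  # non-contracted dims of a
    free_b = [d for d in range(b.ndim) if d not in dims_b]  # non-contracted dims of b
    k = int(np.prod([a.shape[d] for d in dims_a]))          # total contracted size
    m = int(np.prod([a.shape[d] for d in free_a]))
    n = int(np.prod([b.shape[d] for d in free_b]))
    # Permute so the contracted dims are adjacent, then collapse each side to 2-D.
    a2 = a.transpose(free_a + list(dims_a)).reshape(m, k)
    b2 = b.transpose(list(dims_b) + free_b).reshape(k, n)
    out_shape = [a.shape[d] for d in free_a] + [b.shape[d] for d in free_b]
    return (a2 @ b2).reshape(out_shape)

a = np.random.randn(3, 4, 5)
b = np.random.randn(4, 5, 6)
assert np.allclose(tensordot_sketch(a, b, [1, 2], [0, 1]),
                   np.tensordot(a, b, axes=([1, 2], [0, 1])))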

class FusedMLPFunctor {
public:
FusedMLPFunctor() {
@@ -2721,6 +2808,7 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
m.add_functor<impl::DeConv3dFunctor>("Deconv3d");
m.add_functor<impl::MatMulFunctor>("MatMul");
m.add_functor<impl::BatchMatMulFunctor>("BatchMatMul");
m.add_functor<impl::TensorDotFunctor>("TensorDot");
m.add_functor<impl::FusedMLPFunctor>("FusedMLP");
m.add_functor<impl::LayerNormFunctor>("LayerNorm");
m.add_functor<impl::LayerNormAffineFunctor>("LayerNormAffine");
20 changes: 19 additions & 1 deletion oneflow/ir/include/OneFlow/OneFlowUserOps.td
@@ -211,7 +211,7 @@ class OneFlow_NormalizationAddReluBaseOp : OneFlow_BaseOp<"normalization_add_rel
#endif // GET_ONEFLOW_BASE_OP_DEFINITIONS

// Group: BINARY
// bias_add, cast_like, celu_grad, diag_grad, diagonal_grad, dot, dropout_grad, elementwise_maximum, elementwise_minimum, elu_grad, floordiv, gelu_grad, grid_sample, hardsigmoid_grad, hardshrink_grad, hardswish_grad, l1_l2_regularize_gradient, leaky_relu_grad, masked_fill, mish_grad, multiply, narrow_grad, pow, prelu, relu_grad, selu_grad, sigmoid_grad, silu_grad, softshrink_grad, threshold_grad, tf_prelu, unfold_tensor_grad, xdivy, xlogy
// bias_add, cast_like, celu_grad, diag_grad, diagonal_grad, dot, dropout_grad, elementwise_maximum, elementwise_minimum, elu_grad, floordiv, gelu_grad, grid_sample, hardsigmoid_grad, hardshrink_grad, hardswish_grad, l1_l2_regularize_gradient, leaky_relu_grad, masked_fill, mish_grad, multiply, narrow_grad, pow, prelu, relu_grad, selu_grad, sigmoid_grad, silu_grad, softshrink_grad, tensor_dot, threshold_grad, tf_prelu, unfold_tensor_grad, xdivy, xlogy
// Total: 34

#ifdef GET_ONEFLOW_BINARY_OP_DEFINITIONS
@@ -659,6 +659,24 @@ def OneFlow_SiluGradOp : OneFlow_BaseOp<"silu_grad", [NoSideEffect, DeclareOpInt
let has_data_type_infer_fn = 1;
}

def OneFlow_TensorDotOp : OneFlow_BaseOp<"tensordot", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
Contributor:
Ops that are not built via OpBuilder do not need an entry here.

Contributor Author:
Done.

let input = (ins
OneFlow_Tensor:$x,
OneFlow_Tensor:$y
);
let output = (outs
OneFlow_Tensor:$z
);
let attrs = (ins
DefaultValuedAttr<SI32Attr, "0">:$axis_a,
DefaultValuedAttr<SI32Attr, "0">:$axis_b
);
let has_logical_tensor_desc_infer_fn = 1;
let has_physical_tensor_desc_infer_fn = 1;
let has_get_sbp_fn = 1;
let has_data_type_infer_fn = 1;
}

def OneFlow_ThresholdGradOp : OneFlow_BaseOp<"threshold_grad", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
let input = (ins
OneFlow_Tensor:$x,
1 change: 1 addition & 0 deletions python/oneflow/__init__.py
@@ -364,6 +364,7 @@ def atexit_hook(hook):
from oneflow.nn.modules.global_cast import to_local_op as to_local
from oneflow.nn.modules.where import where_op as where
from oneflow.nn.modules.scatter import *
from oneflow.nn.functional import tensordot
Contributor:
This feels odd; the flow.nn.functional.tensordot interface should not be needed, so why go through this import chain?

from oneflow.ops.stateful_ops import StatefulOp as stateful_op
from oneflow.ops.initializer_util import constant_initializer
from oneflow.ops.initializer_util import glorot_normal_initializer
1 change: 1 addition & 0 deletions python/oneflow/framework/docstr/__init__.py
@@ -68,3 +68,4 @@
from .argsort import *
from .module import *
from .util_ops import *
from .tensordot import *
54 changes: 54 additions & 0 deletions python/oneflow/framework/docstr/tensordot.py
@@ -0,0 +1,54 @@
"""
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import oneflow
from oneflow.framework.docstr.utils import add_docstr

add_docstr(
oneflow.tensordot,
r"""
tensordot(a, b, dims=Union[int, Tensor, Tuple]) -> Tensor
Contributor:
This does not follow the convention either; see how other operators adapted from the PyTorch docs write it.


Compute tensor dot along given dimensions.

Given two tensors a and b, and dims, which contains two lists of dimension indices, `tensordot` traverses both lists and computes the tensor dot product along every dimension pair.

Args:
a: The first input tensor to compute tensordot
b: The second input tensor to compute tensordot
dims: int or array-like

Returns:
oneflow.Tensor: The result tensor

For example:

.. code-block:: python

>>> import oneflow as flow
>>> a = flow.randn(3, 4, 5)
>>> b = flow.randn(4, 5, 6)
>>> flow.tensordot(a, b, dims=2).shape
oneflow.Size([3, 6])
>>> b = flow.randn(5, 6, 7)
>>> flow.tensordot(a, b, dims=1).shape
oneflow.Size([12, 42])
>>> b = flow.randn(3, 4, 7)
>>> flow.tensordot(a, b, dims=[[0, 1], [0, 1]]).shape
oneflow.Size([5, 7])

""",
)
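As a quick cross-check of the documented semantics: an integer dims=N contracts the last N dimensions of a with the first N dimensions of b, matching numpy.tensordot. A small sketch, assuming a build that includes this PR:

import numpy as np
import oneflow as flow

a = flow.randn(3, 4, 5)
b = flow.randn(4, 5, 6)
# dims=2 contracts a's trailing (4, 5) dims with b's leading (4, 5) dims.
z = flow.tensordot(a, b, dims=2)
ref = np.tensordot(a.numpy(), b.numpy(), axes=2)
print(z.shape)                                 # oneflow.Size([3, 6])
print(np.allclose(z.numpy(), ref, atol=1e-5))  # True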
1 change: 1 addition & 0 deletions python/oneflow/nn/functional/__init__.py
@@ -26,6 +26,7 @@
from .functional_maxpool import max_pool1d
from .functional_maxpool import max_pool2d
from .functional_maxpool import max_pool3d
from .functional_tensordot import tensordot
from oneflow._C import adaptive_avg_pool1d
from oneflow._C import adaptive_avg_pool2d
from oneflow._C import adaptive_avg_pool3d
40 changes: 40 additions & 0 deletions python/oneflow/nn/functional/functional_tensordot.py
@@ -0,0 +1,40 @@
"""
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import oneflow
from typing import Union, List, Tuple

def tensordot(a, b, dims: Union[int, List[List[int]]] = 2):
Contributor:
The Union here does not seem to account for Tensor. Also, this Python file is not really necessary; removing the Python wrapper entirely would be more efficient. See how flow.split removed its Python wrapper.

if not isinstance(dims, (oneflow._oneflow_internal.Tensor, int, list, tuple)):

raise TypeError(
f"oneflow.tensordot expects dims to be one of oneflow.Tensor, int, List[List[int]] or Tuple[List[int], List[int]], but got {type(dims)}"
)

if isinstance(dims, int):
assert dims >= 0 and dims <= min(a.dim(), b.dim())
Contributor:
An error message should be provided here.

dim_a = list(range(a.dim() - dims, a.dim()))
dim_b = list(range(dims))

elif isinstance(dims, (list, tuple, oneflow._oneflow_internal.Tensor)):
assert len(dims) == 2
dim_a = list(dims[0])
dim_b = list(dims[1])
assert (
Contributor:
Same as above.

isinstance(dim_a[0], int)
and isinstance(dim_b[0], int)
and len(dim_a) == len(dim_b)
)

return oneflow._C.tensordot(a, b, dim_a, dim_b)
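The wrapper's only job is to normalize dims into the two explicit index lists the C++ functor expects. A standalone sketch of that normalization under the same rules (normalize_dims is a hypothetical helper, not part of this PR):

from typing import List, Tuple, Union

def normalize_dims(ndim_a: int, ndim_b: int,
                   dims: Union[int, List[List[int]]]) -> Tuple[List[int], List[int]]:
    if isinstance(dims, int):
        # An integer N contracts the last N dims of a with the first N dims of b.
        if not 0 <= dims <= min(ndim_a, ndim_b):
            raise ValueError(f"dims must be in [0, {min(ndim_a, ndim_b)}], got {dims}")
        return list(range(ndim_a - dims, ndim_a)), list(range(dims))
    # Otherwise dims is a pair of equal-length lists of dimension indices.
    dim_a, dim_b = list(dims[0]), list(dims[1])
    if len(dim_a) != len(dim_b):
        raise ValueError("dims[0] and dims[1] must have the same length")
    return dim_a, dim_b

print(normalize_dims(3, 3, 2))                       # ([1, 2], [0, 1])
print(normalize_dims(4, 4, [[1, 2, 0], [2, 1, 0]]))  # ([1, 2, 0], [2, 1, 0])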
64 changes: 64 additions & 0 deletions python/oneflow/test/modules/test_tensordot.py
@@ -0,0 +1,64 @@
"""
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
import numpy as np
import oneflow as flow
import oneflow.unittest
from oneflow.test_utils.automated_test_util import *


@flow.unittest.skip_unless_1n1d()
class TestTensordot(flow.unittest.TestCase):
@autotest(check_graph=True)
Contributor:

Suggested change
@autotest(check_graph=True)
@autotest(n=5)

def test_tensordot_intdim(test_case):
device = random_device()
dims = random()
dims_list = [random().to(int).value() for i in range(dims.to(int).value() + 3)]
x = random_tensor(
ndim=3,
dim0=dims_list[0],
dim1=dims_list[1],
dim2=dims_list[2],
).to(device)
y = random_tensor(
ndim=3,
dim0=dims_list[0 + dims.to(int).value()],
dim1=dims_list[1 + dims.to(int).value()],
dim2=dims_list[2 + dims.to(int).value()],
).to(device)

z = torch.tensordot(x, y, dims=3 - dims.to(int).value())
return z

@autotest(check_graph=True, n=1)
Contributor:

Suggested change
@autotest(check_graph=True, n=1)
@autotest(n=1)

def test_tensordot_tensordim(test_case):
device = random_device()
x = random_tensor(4, 1, 3, 2, 5).to(device)
y = random_tensor(4, 4, 2, 3, 5).to(device)
z = torch.tensordot(x, y, dims=[[1, 2, 0], [2, 1, 0]])
return z

@autotest(check_graph=True)
Contributor:

Suggested change
@autotest(check_graph=True)
@autotest(n=5)

def test_tensordot_broadcast(test_case):
device = random_device()
x = random_tensor(4, 1, 1, 1, 1).to(device)
y = random_tensor(4, 2, 3, 4, 5).to(device)
z = torch.tensordot(x, y, dims=random(high=5).to(int).value())
return z


if __name__ == "__main__":
unittest.main()