-
Notifications
You must be signed in to change notification settings - Fork 825
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add oneflow.tensordot interface #7968
Conversation
@@ -659,6 +659,24 @@ def OneFlow_SiluGradOp : OneFlow_BaseOp<"silu_grad", [NoSideEffect, DeclareOpInt | |||
let has_data_type_infer_fn = 1; | |||
} | |||
|
|||
def OneFlow_TensorDotOp : OneFlow_BaseOp<"tensordot", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
不适用OpBuilder构建的op不需要更新这里
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改
|
||
|
||
def tensordot(a, b, dims: Union[int, List[List[int]]] = 2): | ||
if not isinstance(dims, (oneflow._oneflow_internal.Tensor, int, list, tuple)): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
View latest API docs preview at: https://staging.oneflow.info/docs/Oneflow-Inc/oneflow/pr/7968/ |
CI failed when running job: cuda-module. PR label automerge has been removed |
Speed stats:
|
View latest API docs preview at: https://staging.oneflow.info/docs/Oneflow-Inc/oneflow/pr/7968/ |
CI failed when running job: cpu-module. PR label automerge has been removed |
CI failed when running job: cuda-module. PR label automerge has been removed |
Speed stats:
|
Speed stats:
|
View latest API docs preview at: https://staging.oneflow.info/docs/Oneflow-Inc/oneflow/pr/7968/ |
Speed stats:
|
View latest API docs preview at: https://staging.oneflow.info/docs/Oneflow-Inc/oneflow/pr/7968/ |
此PR完成了:
算子实现的方法是在functor层,执行tensor permute -> tensor reshape -> functional::matmul -> tensor reshape的操作。
文档截图:
这里没写公式,因为我觉得那个非常难理解,而是加了一段代码的演示,告诉用户tensordot和哪一系列基本操作等价。
需要注意的地方: