-
Notifications
You must be signed in to change notification settings - Fork 825
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Dev logical_xor module #5694
Dev logical_xor module #5694
Conversation
zfu82
commented
Aug 2, 2021
•
edited
Loading
edited
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
添加API文档截图,参考 #5636
logical_and() -> Tensor | ||
|
||
See :func:`oneflow.logical_and` |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
logical_and -> logical_xor
|
||
>>> input1 = flow.Tensor(np.array([1, 0, 1]).astype(np.float32), dtype=flow.float32) | ||
>>> input2 = flow.Tensor(np.array([1, 1, 0]).astype(np.float32), dtype=flow.float32) | ||
>>> out = flow.logical_and(input1, input2) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
logical_and -> logical_xor
运行这个文件看下是否有报错
def logical_or_op_tensor(input, other): | ||
""" | ||
logical_or() -> Tensor |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
logical_or -> logical_xor
See :func:`oneflow.logical_xor` | ||
|
||
""" | ||
return LogicalXor()(input, other) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
之前因为功能的局限,才不得不用实例化 Module 的方式调用。现在有了 flow.F
下的 functional call,尽量直接调用。
可以参考这个PR 进行修改 #5754
@@ -128,6 +128,7 @@ def custom_exit(returncode): | |||
import oneflow.compatible.single_client.nn.modules.greater_equal |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
single_client 下的文件不用修改,都复原吧。
…der single_client folder
template<typename T> | ||
struct BinaryFuncOR final { | ||
static OF_DEVICE_FUNC const int8_t Invoke(const T x, const T y) { return x || y; } | ||
}; | ||
template<typename T> | ||
struct BinaryFuncAny final { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个 struct 是不是不需要了
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里是从上面移到下面了,还是需要的
python/oneflow/__init__.py
Outdated
@@ -141,6 +141,7 @@ def Sync(): | |||
import oneflow.nn.modules.greater_equal | |||
import oneflow.nn.modules.logical_and | |||
import oneflow.nn.modules.logical_or | |||
import oneflow.nn.modules.logical_xor |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个是当作 package 导入,没有必要,顺便把上面的 logical_and
,logical_or
,logical_xor
也给删了吧。
|
||
""" | ||
if other.dtype != input.dtype: | ||
other = flow.cast(other, input.dtype) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
在这里加个以shape 的检查吧。类似
assert input.shape == other.shape, "shape of input and other should be same"
因为你这个算子内部是支持 broadcast 的,但是文档和 pytorch 对齐,是element-wise的。
或者,如果想作为 pytorch 超集,就在文档钟说明支持 broadcast。并且在 doctest、test 里添加相关的例子。
Speed stats:
|