-
Notifications
You must be signed in to change notification settings - Fork 825
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Refactor slice op #3444
Refactor slice op #3444
Conversation
…dev_refactor_slice_op
@@ -78,6 +78,10 @@ def _CheckGlobalFunctionReturnAnnotation(cls): | |||
assert len(cls.__args__) > 0 | |||
for cls_arg in cls.__args__: | |||
_CheckGlobalFunctionReturnAnnotation(cls_arg) | |||
elif oft.OriginFrom(cls, typing.List): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个文件的修改属于另外一个 pr
@@ -342,3 +342,79 @@ def foo( | |||
return (x,) | |||
|
|||
foo([[data]])(Test) | |||
|
|||
|
|||
def test_annotation_return_List_Numpy(test_case): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个文件的修改属于另外一个 pr
oneflow/user/kernels/slice_util.h
Outdated
#ifdef __CUDA_ARCH__ | ||
#pragma unroll | ||
#endif | ||
for (int64_t i = 0; i < params.ndim; ++i) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
params.ndim 作为模板参数才能保证最佳的效率
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
gpu实现已将 params.ndim 作为模版参数
…c/oneflow into dev_refactor_slice_op
SliceKernelUtil<device_type, T>::Forward(ctx->device_ctx(), params, x_tensor->dptr<T>(), | ||
y_tensor->mut_dptr<T>()); | ||
} | ||
bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个地方直接挂掉是不是更合适,理论上 slice 的语意遇到这个情况说不通
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
参数检查的时候发现是 empty slice 会直接返回错误
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if (step > 0) {
CHECK_LT_OR_RETURN(start, stop) << "slice start must be less than stop when step > 0"
", otherwise empty result will be outputted.";
} else {
CHECK_GT_OR_RETURN(start, stop) << "slice start must be more than stop when step < 0"
", otherwise empty result will be outputted.";
}
…c/oneflow into dev_refactor_slice_op
CHECK_EQ_OR_RETURN(step_vec.size(), ndim); | ||
|
||
const SbpParallel& out_sbp = ctx->SbpParallel4ArgNameAndIndex("y", 0); | ||
if (ctx->parallel_ctx().parallel_num() != 1 && out_sbp.has_split_parallel()) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
感觉这个检查并不需要
如果这个检查是必要的,那么我们基本上所有op都要加上这个检查
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
但如果 slice 的维度跟 split 在一个维度上,这种也不检查一下吗?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
但如果 slice 的维度跟 split 在一个维度上,这种也不检查一下吗?
是不是可以假定 sbp infer的时候不会产生这种配置
conv也是不支持在h或者w上面split的,按照这个思路是不是也要检查一下呢
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
conv 可以这样假设是没错的。但在 slice 这里不检查可能会更危险,我们常常使用的是 S(0),在这种情况下,用户不小心对 dim_0 进行了 slice,将会出现未定义行为。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
之前出过这种错误,要定位还是挺麻烦的。
…c/oneflow into dev_refactor_slice_op
…dev_refactor_slice_op
重构了 slice op,为了: