Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor slice op #3444

Merged
merged 41 commits into from
Aug 11, 2020
Merged

Refactor slice op #3444

merged 41 commits into from
Aug 11, 2020

Conversation

leaves-zwx
Copy link
Contributor

@leaves-zwx leaves-zwx commented Aug 7, 2020

重构了 slice op,为了:

  1. 支持 cpu 上运行,支持了 gpu 上 float16 运算
  2. 修复之前存在的 bug
  3. 接口语义更加对齐 python slice
  4. 尽量处理各种不同的 dynamic 情况
  5. 翻新了测试代码,添加更多情况下的测试用例
  6. 添加 reverse op (使用 slice 间接实现) 并添加测试

@leaves-zwx leaves-zwx self-assigned this Aug 8, 2020
@leaves-zwx leaves-zwx marked this pull request as ready for review August 8, 2020 02:10
@@ -78,6 +78,10 @@ def _CheckGlobalFunctionReturnAnnotation(cls):
assert len(cls.__args__) > 0
for cls_arg in cls.__args__:
_CheckGlobalFunctionReturnAnnotation(cls_arg)
elif oft.OriginFrom(cls, typing.List):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个文件的修改属于另外一个 pr

@@ -342,3 +342,79 @@ def foo(
return (x,)

foo([[data]])(Test)


def test_annotation_return_List_Numpy(test_case):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个文件的修改属于另外一个 pr

#ifdef __CUDA_ARCH__
#pragma unroll
#endif
for (int64_t i = 0; i < params.ndim; ++i) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

params.ndim 作为模板参数才能保证最佳的效率

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

gpu实现已将 params.ndim 作为模版参数

SliceKernelUtil<device_type, T>::Forward(ctx->device_ctx(), params, x_tensor->dptr<T>(),
y_tensor->mut_dptr<T>());
}
bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个地方直接挂掉是不是更合适,理论上 slice 的语意遇到这个情况说不通

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

参数检查的时候发现是 empty slice 会直接返回错误

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

    if (step > 0) {
      CHECK_LT_OR_RETURN(start, stop) << "slice start must be less than stop when step > 0"
                                         ", otherwise empty result will be outputted.";
    } else {
      CHECK_GT_OR_RETURN(start, stop) << "slice start must be more than stop when step < 0"
                                         ", otherwise empty result will be outputted.";
    }

CHECK_EQ_OR_RETURN(step_vec.size(), ndim);

const SbpParallel& out_sbp = ctx->SbpParallel4ArgNameAndIndex("y", 0);
if (ctx->parallel_ctx().parallel_num() != 1 && out_sbp.has_split_parallel()) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

感觉这个检查并不需要
如果这个检查是必要的,那么我们基本上所有op都要加上这个检查

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

但如果 slice 的维度跟 split 在一个维度上,这种也不检查一下吗?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

但如果 slice 的维度跟 split 在一个维度上,这种也不检查一下吗?

是不是可以假定 sbp infer的时候不会产生这种配置

conv也是不支持在h或者w上面split的,按照这个思路是不是也要检查一下呢

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

conv 可以这样假设是没错的。但在 slice 这里不检查可能会更危险,我们常常使用的是 S(0),在这种情况下,用户不小心对 dim_0 进行了 slice,将会出现未定义行为。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

之前出过这种错误,要定位还是挺麻烦的。

@oneflow-ci-bot oneflow-ci-bot merged commit f10d10e into master Aug 11, 2020
@oneflow-ci-bot oneflow-ci-bot deleted the dev_refactor_slice_op branch August 11, 2020 10:58
@jackalcooper jackalcooper added this to the 0.1.9 milestone Aug 13, 2020
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants