[Pten] Add reduce mean kernel, replace with mean API #37559
Changes from all commits: e1ea408, 4facf3e, b4b0d62, 1e8e06c, 366a3ff, 253dc18, 0033c69, e67cd57, 3383c3f, 13423cb, 8ab3a40, 3b3ec9b, e2a5f4b, 2e0343d, d4969df, 197e62c, 5e6cb33, c376481, d200fd6, 33c1b3d, 63e4391, 788e49f, fa512f4, a3437a6
@@ -21,7 +21,9 @@ namespace experimental {
 
 // TODO(chenweihang): add scale API
 // TODO(chenweihang): move mean API into stat.h/cc
-PD_DLL_DECL Tensor mean(const Tensor& x);
+PD_DLL_DECL Tensor mean(const Tensor& x,
+                        const std::vector<int64_t>& axis,
+                        bool keep_dim);
Comment on lines +24 to +26
Review comment: Should we also provide an overload of this API with default argument values?
Reply: The overload with default values will be added in a follow-up PR.
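For illustration only, the defaulted overload discussed above might look like the sketch below; the specific defaults ({} for axis, false for keep_dim) are assumptions, since the actual signature is left to the follow-up PR.

PD_DLL_DECL Tensor mean(const Tensor& x,
                        const std::vector<int64_t>& axis = {},
                        bool keep_dim = false);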
 
 PD_DLL_DECL Tensor add(const Tensor& x, const Tensor& y);
 
@@ -31,5 +33,10 @@ PD_DLL_DECL Tensor divide(const Tensor& x, const Tensor& y);
 
 PD_DLL_DECL Tensor multiply(const Tensor& x, const Tensor& y);
 
+PD_DLL_DECL Tensor sum(const Tensor& x,
+                       const std::vector<int64_t>& axis,
+                       DataType dtype,
+                       bool keep_dim);
+
Comment on lines +36 to +40
Review comment: Same as above (should there also be an overload with default argument values?).
Reply: The overload with default values will be added in a follow-up PR.
 }  // namespace experimental
 }  // namespace paddle
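For orientation, a minimal usage sketch of the two overloads added in this header follows. It is not part of the diff; the include path and the choice of pten::DataType::FLOAT32 are assumptions.

#include <vector>
#include "paddle/pten/api/include/math.h"  // assumed header for these declarations

void ReduceExample(const paddle::experimental::Tensor& x) {
  // Mean over axis 0, dropping the reduced dimension.
  paddle::experimental::Tensor m =
      paddle::experimental::mean(x, {0}, /*keep_dim=*/false);
  // Sum over axes 0 and 1, keeping the reduced dimensions, accumulating as FLOAT32.
  paddle::experimental::Tensor s = paddle::experimental::sum(
      x, {0, 1}, pten::DataType::FLOAT32, /*keep_dim=*/true);
}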
@@ -34,13 +34,44 @@ DenseTensor Sign(const ContextT& dev_ctx, const DenseTensor& x) {
 }
 
 template <typename T, typename ContextT>
-DenseTensor Mean(const ContextT& dev_ctx, const DenseTensor& x) {
-  auto out_meta = ReductionInferMeta(x.meta());
+DenseTensor Mean(const ContextT& dev_ctx,
+                 const DenseTensor& x,
+                 const std::vector<int64_t>& axis,
+                 bool keep_dim) {
+  auto out_meta = ReduceInferMeta(x.meta(), axis, keep_dim);
   const auto allocator =
       std::make_shared<paddle::experimental::DefaultAllocator>(
           dev_ctx.GetPlace());
   pten::DenseTensor dense_out(allocator, out_meta);
-  Mean<T>(dev_ctx, x, &dense_out);
+  bool reduce_all = false;
+  DataType out_dtype = pten::DataType::UNDEFINED;
+  Mean<T>(
+      dev_ctx, x, axis, keep_dim, reduce_all, x.dtype(), out_dtype, &dense_out);
   return dense_out;
 }
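The wrapper above could be exercised roughly as sketched below. The enclosing pten namespace and the float instantiation are assumptions; the device context is left generic.

// Sketch only: invoking the updated Mean wrapper over axis 0.
template <typename ContextT>
pten::DenseTensor MeanOverAxis0(const ContextT& dev_ctx,
                                const pten::DenseTensor& x) {
  return pten::Mean<float>(
      dev_ctx, x, std::vector<int64_t>{0}, /*keep_dim=*/false);
}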
+
+template <typename T, typename ContextT>
+DenseTensor Sum(const ContextT& dev_ctx,
+                const DenseTensor& x,
+                const std::vector<int64_t>& axis,
+                DataType dtype,
+                bool keep_dim) {
+  auto out_meta = ReduceInferMeta(x.meta(), axis, keep_dim);
+  const auto allocator =
+      std::make_shared<paddle::experimental::DefaultAllocator>(
+          dev_ctx.GetPlace());
+  pten::DenseTensor dense_out(allocator, out_meta);
+
+  // The real value of reduce_all will be determined in the kernel,
+  // so the default value (false) is fine here.
+  bool reduce_all = false;
+
+  if (x.dtype() == pten::DataType::BOOL || x.dtype() == pten::DataType::INT32 ||
+      x.dtype() == pten::DataType::INT64) {
+    dtype = pten::DataType::INT64;
+  }
Comment on lines +69 to +72
Review comment: I see similar logic in InferMeta. Could this logic live only in InferMeta or only in the kernel? If we move to auto-generated code, cases like this may need separate configuration, which adds complexity to the configuration options.
Reply: Will optimize this in a follow-up PR.
+
+  Sum<T>(dev_ctx, x, axis, keep_dim, reduce_all, x.dtype(), dtype, &dense_out);
+  return dense_out;
+}
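To make the dtype promotion rule under discussion concrete, here is a minimal standalone sketch (a hypothetical helper, not code from this PR; the include path is assumed): boolean and 32/64-bit integer inputs are accumulated as INT64, every other input keeps the requested dtype.

#include "paddle/pten/common/data_type.h"  // assumed location of pten::DataType

// Hypothetical helper mirroring the promotion performed in Sum above.
inline pten::DataType PromoteSumDtype(pten::DataType in_dtype,
                                      pten::DataType requested) {
  if (in_dtype == pten::DataType::BOOL ||
      in_dtype == pten::DataType::INT32 ||
      in_dtype == pten::DataType::INT64) {
    return pten::DataType::INT64;  // accumulate bools and 32/64-bit ints as INT64
  }
  return requested;  // all other inputs keep the requested accumulation dtype
}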
Review comment: The logic in mean.cu also needs to be restored.
Review comment: Note: also check that the kernel registration style is restored to match the original.
Reply: Done, thanks.