-
Notifications
You must be signed in to change notification settings - Fork 5.8k
[PTen] Add variable transform to/from ptenTensor and add cast kernel #36916
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[PTen] Add variable transform to/from ptenTensor and add cast kernel #36916
Conversation
Thanks for your contribution! |
Fix cast kernel refactor bugs
paddle/fluid/framework/tensor.cc
Outdated
@@ -209,5 +209,7 @@ void Tensor::ResetHolderWithType(std::shared_ptr<memory::Allocation> holder, | |||
type_ = type; | |||
} | |||
|
|||
void Tensor::setType(const proto::VarType::Type type) { type_ = type; } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
setType -> set_type
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
@@ -552,14 +552,13 @@ class Reshape2Op : public ReshapeOp { | |||
const framework::ExecutionContext &ctx) const override { | |||
auto multi_inputs = ctx.MultiInput<framework::Tensor>("ShapeTensor"); | |||
if (multi_inputs.size() > 0) { | |||
return framework::KernelSignature( | |||
"reshape2.mulhost.mid", {"X", "ShapeTensor"}, {}, {"XShape", "Out"}); | |||
return framework::KernelSignature("reshape2.mulhost", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
为什么XShape输出可以删除?这里mul和host后缀最好拆分一下
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
和负责的rd 沟通过,先保持现状
@@ -21,6 +21,8 @@ namespace experimental { | |||
|
|||
PD_DLL_DECL Tensor flatten(const Tensor& x, int start_axis, int stop_axis); | |||
|
|||
Tensor cast(const Tensor& x, DataType out_dtype); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
需要增加PD_DLL_DECL声明
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
paddle/pten/common/data_type.h
Outdated
#define PTEN_PRIVATE_CASE_TYPE(NAME, enum_type, type, ...) \ | ||
PTEN_PRIVATE_CASE_TYPE_USING_HINT(NAME, enum_type, type, data_t, __VA_ARGS__) | ||
|
||
#define PTEN_DISPATCH_ALL_TYPES(TYPE, NAME, ...) \ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里建议直接复用ext/dispatch.h中的实现,或者在ext/dispatch.h新增宏
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
paddle/pten/infermeta/unary.cc
Outdated
@@ -74,6 +74,12 @@ DenseTensorMeta FlattenInferShape(const DenseTensorMeta& x_meta, | |||
return return_meta; | |||
} | |||
|
|||
DenseTensorMeta CastInferShape(const DenseTensorMeta& x_meta, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
建议直接命名为InferMeta吧,后面也需要统一改
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
paddle/fluid/framework/operator.cc
Outdated
size_t start_idx = | ||
(i == 0 ? 0 : pt_kernel_context_->InputRangeAt(i - 1).second); | ||
size_t end_idx = start_idx + ins_vector.size(); | ||
|
||
if (pt_kernel_context_->InputsSize() == start_idx) { | ||
paddle::SmallVector<std::shared_ptr<pten::TensorBase>> tmp_inputs; | ||
for (auto* var : ins_vector) { | ||
tmp_inputs.emplace_back( | ||
experimental::MakePtenTensorBaseFromVar(*var, in_def)); | ||
} | ||
pt_kernel_context_->EmplaceBackInputs(std::move(tmp_inputs)); | ||
} else { | ||
} else if (pt_kernel_context_->InputsSize() > start_idx) { | ||
size_t input_size = pt_kernel_context_->InputsSize(); | ||
for (size_t j = 0; j < ins_vector.size(); ++j) { | ||
if (input_size > i + j) { | ||
if (input_size > start_idx + j) { | ||
experimental::ReMakePtenDenseTensorFromVar( | ||
*ins_vector[j], in_def, | ||
pt_kernel_context_->MutableInputAt<pten::DenseTensor>(i + j)); | ||
pt_kernel_context_->MutableInputAt<pten::DenseTensor>(start_idx + | ||
j)); | ||
} else { | ||
pt_kernel_context_->EmplaceBackInputWithoutSetRange( | ||
experimental::MakePtenTensorBaseFromVar(*ins_vector[j], in_def)); | ||
} | ||
// TODO(chenweihang): adapt multi-input case later | ||
} | ||
pt_kernel_context_->MutableInputRangeAt(i) = |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这部分逻辑建议加上详细注释,方便代码阅读理解
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已经添加注释
paddle/pten/api/lib/manipulation.cc
Outdated
@@ -60,6 +60,40 @@ PD_DLL_DECL Tensor flatten(const Tensor& x, int start_axis, int stop_axis) { | |||
return out; | |||
} | |||
|
|||
Tensor cast(const Tensor& x, DataType out_dtype) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
需要增加PD_DLL_DECL声明
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
paddle/fluid/framework/operator.cc
Outdated
"error start index when trying to set new tensor to inputs, start " | ||
"index is `%d`, but current pt_kernel_context_.inputs.size() is " | ||
"`%d` ", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
报错信息首字母大写,结尾加.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
paddle/fluid/framework/operator.cc
Outdated
"error start index when trying to set new tensor to inputs, start " | ||
"index is `%d`, but current pt_kernel_context_.outputs.size() is " | ||
"`%d` ", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
"error start index when trying to set new tensor to inputs, start " | ||
"index is `%d`, but current pt_kernel_context_.inputs.size() is " | ||
"`%d` ", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
报错信息首字母大写,结尾加.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
"error start index when trying to set new tensor to inputs, start " | ||
"index is `%d`, but current pt_kernel_context_.outputs.size() is " | ||
"`%d` ", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
报错信息首字母大写,结尾加.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
auto* tensor = variable->GetMutable<framework::SelectedRows>(); | ||
auto dtype = pten::TransToProtoVarType(src->dtype()); | ||
|
||
if (tensor->value().IsInitialized()) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if内部没有处理逻辑是否可以和else合并?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已经合并
ReshapeFromVectorValImpl(dev_ctx, x, shape, out, false); | ||
auto out_meta = InferShapeFromVecValue(x.meta(), shape); | ||
if (&x == out) { | ||
LOG(INFO) << "out_meta dims:" << out_meta.dims; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LOG(INFO)是调试加的代码吗?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已删除
paddle/fluid/framework/tensor.cc
Outdated
@@ -209,5 +209,7 @@ void Tensor::ResetHolderWithType(std::shared_ptr<memory::Allocation> holder, | |||
type_ = type; | |||
} | |||
|
|||
void Tensor::set_type(const proto::VarType::Type type) { type_ = type; } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
const proto::VarType::Type -> const proto::VarType::Type&
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
@@ -1901,5 +1961,26 @@ void OperatorWithKernel::BuildPtenKernelContext( | |||
} | |||
} | |||
|
|||
void OperatorWithKernel::WriteBackToOutputs(RuntimeContext* ctx) const { | |||
// auto& input_names = std::get<0>(pt_kernel_signature_->args); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
如果是无用的注释,建议在下个PR移除
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
好的
PR types
New features
PR changes
Others
Describe
本PR 主要改动点如下: