Skip to content

Commit a9fdff8

Browse files
committed
remove tensor meta utils
1 parent 71d67d5 commit a9fdff8

File tree

6 files changed

+16
-44
lines changed

6 files changed

+16
-44
lines changed

paddle/fluid/framework/operator.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -1867,7 +1867,7 @@ void OperatorWithKernel::BuildPtenKernelContext(
18671867
std::move(pten::Scalar(BOOST_GET_CONST(float, attr))));
18681868
} else if (std::type_index(attr.type()) ==
18691869
std::type_index(typeid(std::string))) {
1870-
op_kernel_ctx.EmplaceBackAttr(
1870+
pt_kernel_context_->EmplaceBackAttr(
18711871
std::move(pten::Scalar(BOOST_GET_CONST(std::string, attr))));
18721872
} else {
18731873
PADDLE_THROW(platform::errors::Unimplemented(

paddle/fluid/imperative/prepared_operator.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -356,7 +356,7 @@ static void BuildDygraphPtenKernelContext(
356356
std::move(pten::Scalar(BOOST_GET_CONST(float, attr))));
357357
} else if (std::type_index(attr.type()) ==
358358
std::type_index(typeid(std::string))) {
359-
op_kernel_ctx.EmplaceBackAttr(
359+
kernel_ctx->EmplaceBackAttr(
360360
std::move(pten::Scalar(BOOST_GET_CONST(std::string, attr))));
361361
} else {
362362
PADDLE_THROW(platform::errors::Unimplemented(

paddle/pten/api/lib/creation.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ Tensor full(const std::vector<int64_t>& shape,
3838

3939
// 2. Get Device Context
4040
auto* dev_ctx = GetDeviceContextByBackend(kernel_key.backend());
41-
auto kernel_context = pten::KernelContext(*dev_ctx);
41+
auto kernel_context = pten::KernelContext(dev_ctx);
4242

4343
// 3. Auto data transform
4444
kernel_context.EmplaceBackAttr(value);

paddle/pten/api/lib/utils/tensor_utils.cc

+13-11
Original file line numberDiff line numberDiff line change
@@ -133,11 +133,12 @@ void MovesStorage(pten::DenseTensor* src, paddle::framework::LoDTensor* dst) {
133133
void ReMakePtenDenseTensor(const paddle::framework::Tensor& src,
134134
pten::DenseTensor* dst) {
135135
auto* meta = pten::CompatibleDenseTensorUtils::GetMutableMeta(dst);
136-
pten::CompatibleDenseTensorMetaUtils::SetDDim(meta, src.dims());
137-
pten::CompatibleDenseTensorMetaUtils::SetDataType(
138-
meta, pten::TransToPtenDataType(src.type()));
139-
pten::CompatibleDenseTensorMetaUtils::SetDataLayout(
140-
meta, pten::TransToPtenDataLayout(src.layout()));
136+
meta->dims = src.dims();
137+
// Since the type of DenseTensorMeta is const, const_cast must be used
138+
const_cast<DataType&>(meta->type) = pten::TransToPtenDataType(src.type());
139+
// Since the type of DenseTensorMeta is const, const_cast must be used
140+
const_cast<DataLayout&>(meta->layout) =
141+
pten::TransToPtenDataLayout(src.layout());
141142
auto* shared_storage = static_cast<SharedStorage*>(
142143
pten::CompatibleDenseTensorUtils::UnsafeGetMutableStorage(dst));
143144
PADDLE_ENFORCE_NOT_NULL(
@@ -150,12 +151,13 @@ void ReMakePtenDenseTensor(const paddle::framework::Tensor& src,
150151
void ReMakePtenDenseTensor(const paddle::framework::LoDTensor& src,
151152
pten::DenseTensor* dst) {
152153
auto* meta = pten::CompatibleDenseTensorUtils::GetMutableMeta(dst);
153-
pten::CompatibleDenseTensorMetaUtils::SetDDim(meta, src.dims());
154-
pten::CompatibleDenseTensorMetaUtils::SetDataType(
155-
meta, pten::TransToPtenDataType(src.type()));
156-
pten::CompatibleDenseTensorMetaUtils::SetDataLayout(
157-
meta, pten::TransToPtenDataLayout(src.layout()));
158-
pten::CompatibleDenseTensorMetaUtils::SetLoD(meta, src.lod());
154+
meta->dims = src.dims();
155+
// Since the type of DenseTensorMeta is const, const_cast must be used
156+
const_cast<DataType&>(meta->type) = pten::TransToPtenDataType(src.type());
157+
// Since the type of DenseTensorMeta is const, const_cast must be used
158+
const_cast<DataLayout&>(meta->layout) =
159+
pten::TransToPtenDataLayout(src.layout());
160+
SetLoD(&(meta->lod), src.lod());
159161
auto* shared_storage = static_cast<SharedStorage*>(
160162
pten::CompatibleDenseTensorUtils::UnsafeGetMutableStorage(dst));
161163
PADDLE_ENFORCE_NOT_NULL(

paddle/pten/core/compat_utils.h

-26
Original file line numberDiff line numberDiff line change
@@ -47,30 +47,4 @@ class CompatibleDenseTensorUtils {
4747
}
4848
};
4949

50-
class CompatibleDenseTensorMetaUtils {
51-
public:
52-
static void SetDDim(DenseTensorMeta* meta, const DDim& dims) {
53-
meta->dims = dims;
54-
}
55-
56-
static void SetDataType(DenseTensorMeta* meta, DataType type) {
57-
// Since the type of DenseTensorMeta is const, const_cast must be used
58-
const_cast<DataType&>(meta->type) = type;
59-
}
60-
61-
static void SetDataLayout(DenseTensorMeta* meta, DataLayout layout) {
62-
// Since the type of DenseTensorMeta is const, const_cast must be used
63-
const_cast<DataLayout&>(meta->layout) = layout;
64-
}
65-
66-
template <typename SrcLoD>
67-
static void SetLoD(DenseTensorMeta* meta, const SrcLoD& lod) {
68-
meta->lod.reserve(lod.size());
69-
meta->lod.clear();
70-
for (auto&& v : lod) {
71-
meta->lod.emplace_back(v);
72-
}
73-
}
74-
};
75-
7650
} // namespace pten

paddle/pten/core/tensor_meta.h

-4
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,6 @@ namespace pten {
3131
using DDim = paddle::framework::DDim;
3232
using LoD = std::vector<std::vector<size_t>>;
3333

34-
class CompatibleDenseTensorMetaUtils;
35-
3634
/// \brief The meta data of dense tensor. Take the structure type
3735
/// and use all default operations.
3836
///
@@ -48,8 +46,6 @@ struct DenseTensorMeta {
4846
DataLayout layout,
4947
const std::vector<std::vector<size_t>>& lod);
5048

51-
friend class CompatibleDenseTensorMetaUtils;
52-
5349
/// \brief Test whether the metadata is valid. Does not throw exceptions.
5450
/// \return Whether the metadata is valid.
5551
bool valid() const noexcept;

0 commit comments

Comments
 (0)