Skip to content

Commit 3e80253

Browse files
authored
[Move selected_rows PR #4] SelectedRows inherits from TensorBase. (PaddlePaddle#39162)
* Added selected_rows and rw_lock to pten * Renamed the unit test target to fix CI * Removed Class SelectedRows in Fluid, changed include/cmake relationship, use pten::SelectedRows in Fluid * Remove rw_lock.h,rw_lock_test.cc in fluid * Use pten::RWLock and pten::AutoRDLock, fix CI * Use pten::SelectedRows * Use pten::SelectedRows * Fix to pass NPU CI * Selected_Rows inherits from TensorBase * Use pten::SelectedRows, to pass NPU CI * To fix NPU CI * To fix NPU CI again * Use paddle/pten/core/enforce and polish code
1 parent d9acc87 commit 3e80253

File tree

2 files changed

+46
-12
lines changed

2 files changed

+46
-12
lines changed

paddle/pten/core/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ cc_library(pten_device_context SRCS device_context.cc DEPS tensor_base )
2424

2525
cc_library(meta_tensor SRCS meta_tensor.cc DEPS tensor_base tensor_meta dense_tensor)
2626
cc_library(infermeta_utils SRCS infermeta_utils.cc DEPS meta_tensor)
27-
cc_library(selected_rows SRCS selected_rows.cc DEPS dense_tensor mixed_vector enforce ddim)
27+
cc_library(selected_rows SRCS selected_rows.cc DEPS dense_tensor mixed_vector pten_enforce ddim)
2828

2929
cc_test(unroll_array_ops_test SRCS unroll_array_ops_test.cc)
3030
cc_library(ddim SRCS ddim.cc DEPS eigen3 boost enforce)

paddle/pten/core/selected_rows.h

Lines changed: 45 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -24,15 +24,16 @@ limitations under the License. */
2424
#include "paddle/pten/common/place.h"
2525
#include "paddle/pten/core/ddim.h"
2626
#include "paddle/pten/core/dense_tensor.h"
27+
#include "paddle/pten/core/enforce.h"
2728
#include "paddle/pten/core/utils/rw_lock.h"
2829

2930
// See Note [ Why still include the fluid headers? ]
3031
#include "paddle/fluid/framework/mixed_vector.h"
3132
#include "paddle/fluid/memory/memcpy.h"
32-
#include "paddle/fluid/platform/enforce.h"
3333

3434
namespace pten {
35-
class SelectedRows {
35+
class SelectedRows : public TensorBase,
36+
public TypeInfoTraits<TensorBase, SelectedRows> {
3637
/*
3738
* @brief We can use the SelectedRows structure to reproduce a sparse table.
3839
* A sparse table is a key-value structure that the key is an `int64_t`,
@@ -51,21 +52,19 @@ class SelectedRows {
5152
public:
5253
SelectedRows(const std::vector<int64_t>& rows, const int64_t& height)
5354
: rows_(rows), height_(height) {
54-
value_.reset(new pten::DenseTensor());
55+
value_.reset(new DenseTensor());
5556
rwlock_.reset(new RWLock);
5657
}
5758

5859
SelectedRows() {
5960
height_ = 0;
60-
value_.reset(new pten::DenseTensor());
61+
value_.reset(new DenseTensor());
6162
rwlock_.reset(new RWLock);
6263
}
6364

64-
const pten::Place& place() const { return value_->place(); }
65+
const DenseTensor& value() const { return *value_; }
6566

66-
const pten::DenseTensor& value() const { return *value_; }
67-
68-
pten::DenseTensor* mutable_value() { return value_.get(); }
67+
DenseTensor* mutable_value() { return value_.get(); }
6968

7069
int64_t height() const { return height_; }
7170

@@ -109,8 +108,8 @@ class SelectedRows {
109108
* @return a list of pair which contains the non-exists key and the index in
110109
* the value
111110
*/
112-
void Get(const pten::DenseTensor& ids,
113-
pten::DenseTensor* value,
111+
void Get(const DenseTensor& ids,
112+
DenseTensor* value,
114113
bool auto_grown = false,
115114
bool is_test = false);
116115

@@ -149,14 +148,49 @@ class SelectedRows {
149148
return pten::framework::make_ddim(dims);
150149
}
151150

151+
/// \brief Returns the name of the class for type traits.
152+
/// \return The name of the class.
153+
static const char* name() { return "SelectedRows"; }
154+
155+
/// \brief Returns the number of elements contained in tensor.
156+
/// \return The number of elements contained in tensor.
157+
int64_t numel() const override { return value_->numel(); };
158+
159+
/// \brief Returns the dims of the tensor.
160+
/// \return The dims of the tensor.
161+
const DDim& dims() const noexcept override {
162+
return value_->dims();
163+
// return paddle::framework::make_ddim(dims);
164+
}
165+
166+
/// \brief Returns the data type of the tensor.
167+
/// \return The data type of the tensor.
168+
DataType dtype() const noexcept override { return value_->dtype(); }
169+
170+
/// \brief Returns the data layout of the tensor.
171+
/// \return The data layout of the tensor.
172+
DataLayout layout() const noexcept override { return value_->layout(); }
173+
174+
/// \brief Returns the data place of the tensor.
175+
/// \return The data place of the tensor.
176+
const Place& place() const override { return value_->place(); };
177+
178+
/// \brief Test whether the metadata is valid.
179+
/// \return Whether the metadata is valid.
180+
bool valid() const noexcept override { return value_->valid(); }
181+
182+
/// \brief Test whether the storage is allocated.
183+
/// return Whether the storage is allocated.
184+
bool initialized() const override { return value_->initialized(); }
185+
152186
private:
153187
// Notice: rows can be duplicate. We can have {0, 4, 7, 0, 5, 7, 9} here.
154188
// SelectedRows are simply concated when adding together. Until a
155189
// SelectedRows add a Tensor, will the duplicate rows be handled.
156190
paddle::framework::Vector<int64_t> rows_;
157191
std::unordered_map<int64_t, int64_t>
158192
id_to_index_; // should not be used when rows_ has duplicate member
159-
std::unique_ptr<pten::DenseTensor> value_{nullptr};
193+
std::unique_ptr<DenseTensor> value_{nullptr};
160194
int64_t height_; // height indicates the underline tensor's height
161195
std::unique_ptr<RWLock> rwlock_{nullptr};
162196
};

0 commit comments

Comments
 (0)