@@ -24,15 +24,16 @@ limitations under the License. */
2424#include " paddle/pten/common/place.h"
2525#include " paddle/pten/core/ddim.h"
2626#include " paddle/pten/core/dense_tensor.h"
27+ #include " paddle/pten/core/enforce.h"
2728#include " paddle/pten/core/utils/rw_lock.h"
2829
2930// See Note [ Why still include the fluid headers? ]
3031#include " paddle/fluid/framework/mixed_vector.h"
3132#include " paddle/fluid/memory/memcpy.h"
32- #include " paddle/fluid/platform/enforce.h"
3333
3434namespace pten {
35- class SelectedRows {
35+ class SelectedRows : public TensorBase ,
36+ public TypeInfoTraits<TensorBase, SelectedRows> {
3637 /*
3738 * @brief We can use the SelectedRows structure to reproduce a sparse table.
3839 * A sparse table is a key-value structure that the key is an `int64_t`,
@@ -51,21 +52,19 @@ class SelectedRows {
5152 public:
5253 SelectedRows (const std::vector<int64_t >& rows, const int64_t & height)
5354 : rows_(rows), height_(height) {
54- value_.reset (new pten:: DenseTensor ());
55+ value_.reset (new DenseTensor ());
5556 rwlock_.reset (new RWLock);
5657 }
5758
5859 SelectedRows () {
5960 height_ = 0 ;
60- value_.reset (new pten:: DenseTensor ());
61+ value_.reset (new DenseTensor ());
6162 rwlock_.reset (new RWLock);
6263 }
6364
64- const pten::Place& place () const { return value_-> place () ; }
65+ const DenseTensor& value () const { return * value_; }
6566
66- const pten::DenseTensor& value () const { return *value_; }
67-
68- pten::DenseTensor* mutable_value () { return value_.get (); }
67+ DenseTensor* mutable_value () { return value_.get (); }
6968
7069 int64_t height () const { return height_; }
7170
@@ -109,8 +108,8 @@ class SelectedRows {
109108 * @return a list of pair which contains the non-exists key and the index in
110109 * the value
111110 */
112- void Get (const pten:: DenseTensor& ids,
113- pten:: DenseTensor* value,
111+ void Get (const DenseTensor& ids,
112+ DenseTensor* value,
114113 bool auto_grown = false ,
115114 bool is_test = false );
116115
@@ -149,14 +148,49 @@ class SelectedRows {
149148 return pten::framework::make_ddim (dims);
150149 }
151150
151+ // / \brief Returns the name of the class for type traits.
152+ // / \return The name of the class.
153+ static const char * name () { return " SelectedRows" ; }
154+
155+ // / \brief Returns the number of elements contained in tensor.
156+ // / \return The number of elements contained in tensor.
157+ int64_t numel () const override { return value_->numel (); };
158+
159+ // / \brief Returns the dims of the tensor.
160+ // / \return The dims of the tensor.
161+ const DDim& dims () const noexcept override {
162+ return value_->dims ();
163+ // return paddle::framework::make_ddim(dims);
164+ }
165+
166+ // / \brief Returns the data type of the tensor.
167+ // / \return The data type of the tensor.
168+ DataType dtype () const noexcept override { return value_->dtype (); }
169+
170+ // / \brief Returns the data layout of the tensor.
171+ // / \return The data layout of the tensor.
172+ DataLayout layout () const noexcept override { return value_->layout (); }
173+
174+ // / \brief Returns the data place of the tensor.
175+ // / \return The data place of the tensor.
176+ const Place& place () const override { return value_->place (); };
177+
178+ // / \brief Test whether the metadata is valid.
179+ // / \return Whether the metadata is valid.
180+ bool valid () const noexcept override { return value_->valid (); }
181+
182+ // / \brief Test whether the storage is allocated.
183+ // / return Whether the storage is allocated.
184+ bool initialized () const override { return value_->initialized (); }
185+
152186 private:
153187 // Notice: rows can be duplicate. We can have {0, 4, 7, 0, 5, 7, 9} here.
154188 // SelectedRows are simply concated when adding together. Until a
155189 // SelectedRows add a Tensor, will the duplicate rows be handled.
156190 paddle::framework::Vector<int64_t > rows_;
157191 std::unordered_map<int64_t , int64_t >
158192 id_to_index_; // should not be used when rows_ has duplicate member
159- std::unique_ptr<pten:: DenseTensor> value_{nullptr };
193+ std::unique_ptr<DenseTensor> value_{nullptr };
160194 int64_t height_; // height indicates the underline tensor's height
161195 std::unique_ptr<RWLock> rwlock_{nullptr };
162196};
0 commit comments