@@ -27,20 +27,24 @@ struct RestrictPtrTraits {
2727};
2828#endif
2929
30- template <typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t >
30+ template <
31+ typename T,
32+ size_t N,
33+ template <typename U> class PtrTraits = DefaultPtrTraits,
34+ typename index_t = int64_t >
3135class TensorAccessorBase {
32- public:
36+ public:
3337 typedef typename PtrTraits<T>::PtrType PtrType;
3438
3539 C10_HOST_DEVICE TensorAccessorBase (
3640 PtrType data_,
3741 const index_t * sizes_,
3842 const index_t * strides_)
39- : data_(data_) /* , sizes_(sizes_), strides_(strides_)*/ {
43+ : data_(data_) /* , sizes_(sizes_), strides_(strides_)*/ {
4044 // Originally, TensorAccessor is a view of sizes and strides as
4145 // these are ArrayRef instances. Until torch::stable supports
4246 // ArrayRef-like features, we store copies of sizes and strides:
43- for (auto i= 0 ; i < N; ++i) {
47+ for (auto i = 0 ; i < N; ++i) {
4448 this ->sizes_ [i] = sizes_[i];
4549 this ->strides_ [i] = strides_[i];
4650 }
@@ -52,7 +56,8 @@ class TensorAccessorBase {
5256 C10_HOST_DEVICE const PtrType data () const {
5357 return data_;
5458 }
55- protected:
59+
60+ protected:
5661 PtrType data_;
5762 /*
5863 const index_t* sizes_;
@@ -64,48 +69,65 @@ class TensorAccessorBase {
6469 index_t strides_[N];
6570};
6671
67- template <typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t >
68- class TensorAccessor : public TensorAccessorBase <T,N,PtrTraits,index_t > {
69- public:
72+ template <
73+ typename T,
74+ size_t N,
75+ template <typename U> class PtrTraits = DefaultPtrTraits,
76+ typename index_t = int64_t >
77+ class TensorAccessor : public TensorAccessorBase <T, N, PtrTraits, index_t > {
78+ public:
7079 typedef typename PtrTraits<T>::PtrType PtrType;
7180
7281 C10_HOST_DEVICE TensorAccessor (
7382 PtrType data_,
7483 const index_t * sizes_,
7584 const index_t * strides_)
76- : TensorAccessorBase<T, N, PtrTraits, index_t>(data_,sizes_,strides_) {}
85+ : TensorAccessorBase<T, N, PtrTraits, index_t>(data_, sizes_, strides_) {}
7786
78- C10_HOST_DEVICE TensorAccessor<T, N - 1 , PtrTraits, index_t > operator [](index_t i) {
79- return TensorAccessor<T,N-1 ,PtrTraits,index_t >(this ->data_ + this ->strides_ [0 ]*i,this ->sizes_ +1 ,this ->strides_ +1 );
87+ C10_HOST_DEVICE TensorAccessor<T, N - 1 , PtrTraits, index_t > operator [](
88+ index_t i) {
89+ return TensorAccessor<T, N - 1 , PtrTraits, index_t >(
90+ this ->data_ + this ->strides_ [0 ] * i,
91+ this ->sizes_ + 1 ,
92+ this ->strides_ + 1 );
8093 }
8194
82- C10_HOST_DEVICE const TensorAccessor<T, N-1 , PtrTraits, index_t > operator [](index_t i) const {
83- return TensorAccessor<T,N-1 ,PtrTraits,index_t >(this ->data_ + this ->strides_ [0 ]*i,this ->sizes_ +1 ,this ->strides_ +1 );
95+ C10_HOST_DEVICE const TensorAccessor<T, N - 1 , PtrTraits, index_t > operator [](
96+ index_t i) const {
97+ return TensorAccessor<T, N - 1 , PtrTraits, index_t >(
98+ this ->data_ + this ->strides_ [0 ] * i,
99+ this ->sizes_ + 1 ,
100+ this ->strides_ + 1 );
84101 }
85102};
86103
87- template <typename T, template <typename U> class PtrTraits , typename index_t >
88- class TensorAccessor <T,1 ,PtrTraits,index_t > : public TensorAccessorBase<T,1 ,PtrTraits,index_t > {
89- public:
104+ template <typename T, template <typename U> class PtrTraits , typename index_t >
105+ class TensorAccessor <T, 1 , PtrTraits, index_t >
106+ : public TensorAccessorBase<T, 1 , PtrTraits, index_t > {
107+ public:
90108 typedef typename PtrTraits<T>::PtrType PtrType;
91109
92110 C10_HOST_DEVICE TensorAccessor (
93111 PtrType data_,
94112 const index_t * sizes_,
95113 const index_t * strides_)
96- : TensorAccessorBase<T, 1, PtrTraits, index_t>(data_,sizes_,strides_) {}
97- C10_HOST_DEVICE T & operator [](index_t i) {
114+ : TensorAccessorBase<T, 1, PtrTraits, index_t>(data_, sizes_, strides_) {}
115+ C10_HOST_DEVICE T& operator [](index_t i) {
98116 // NOLINTNEXTLINE(clang-analyzer-core.NullDereference)
99- return this ->data_ [this ->strides_ [0 ]* i];
117+ return this ->data_ [this ->strides_ [0 ] * i];
100118 }
101- C10_HOST_DEVICE const T & operator [](index_t i) const {
102- return this ->data_ [this ->strides_ [0 ]* i];
119+ C10_HOST_DEVICE const T& operator [](index_t i) const {
120+ return this ->data_ [this ->strides_ [0 ] * i];
103121 }
104122};
105123
106- template <typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t >
124+ template <
125+ typename T,
126+ size_t N,
127+ template <typename U> class PtrTraits = DefaultPtrTraits,
128+ typename index_t = int64_t >
107129class GenericPackedTensorAccessorBase {
108- public:
130+ public:
109131 typedef typename PtrTraits<T>::PtrType PtrType;
110132 C10_HOST GenericPackedTensorAccessorBase (
111133 PtrType data_,
@@ -116,13 +138,15 @@ class GenericPackedTensorAccessorBase {
116138 std::copy (strides_, strides_ + N, std::begin (this ->strides_ ));
117139 }
118140
119- template <typename source_index_t , class = std::enable_if_t <std::is_same_v<source_index_t , int64_t >>>
141+ template <
142+ typename source_index_t ,
143+ class = std::enable_if_t <std::is_same_v<source_index_t , int64_t >>>
120144 C10_HOST GenericPackedTensorAccessorBase (
121145 PtrType data_,
122146 const source_index_t * sizes_,
123147 const source_index_t * strides_)
124148 : data_(data_) {
125- for (auto i= 0 ; i < N; ++i) {
149+ for (auto i = 0 ; i < N; ++i) {
126150 this ->sizes_ [i] = sizes_[i];
127151 this ->strides_ [i] = strides_[i];
128152 }
@@ -134,7 +158,8 @@ class GenericPackedTensorAccessorBase {
134158 C10_HOST_DEVICE const PtrType data () const {
135159 return data_;
136160 }
137- protected:
161+
162+ protected:
138163 PtrType data_;
139164 // NOLINTNEXTLINE(*c-arrays*)
140165 index_t sizes_[N];
@@ -150,68 +175,101 @@ class GenericPackedTensorAccessorBase {
150175 }
151176};
152177
153- template <typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t >
154- class GenericPackedTensorAccessor : public GenericPackedTensorAccessorBase <T,N,PtrTraits,index_t > {
155- public:
178+ template <
179+ typename T,
180+ size_t N,
181+ template <typename U> class PtrTraits = DefaultPtrTraits,
182+ typename index_t = int64_t >
183+ class GenericPackedTensorAccessor
184+ : public GenericPackedTensorAccessorBase<T, N, PtrTraits, index_t > {
185+ public:
156186 typedef typename PtrTraits<T>::PtrType PtrType;
157187
158188 C10_HOST GenericPackedTensorAccessor (
159189 PtrType data_,
160190 const index_t * sizes_,
161191 const index_t * strides_)
162- : GenericPackedTensorAccessorBase<T, N, PtrTraits, index_t>(data_, sizes_, strides_) {}
192+ : GenericPackedTensorAccessorBase<T, N, PtrTraits, index_t>(
193+ data_,
194+ sizes_,
195+ strides_) {}
163196
164197 // if index_t is not int64_t, we want to have an int64_t constructor
165- template <typename source_index_t , class = std::enable_if_t <std::is_same_v<source_index_t , int64_t >>>
198+ template <
199+ typename source_index_t ,
200+ class = std::enable_if_t <std::is_same_v<source_index_t , int64_t >>>
166201 C10_HOST GenericPackedTensorAccessor (
167202 PtrType data_,
168203 const source_index_t * sizes_,
169204 const source_index_t * strides_)
170- : GenericPackedTensorAccessorBase<T, N, PtrTraits, index_t>(data_, sizes_, strides_) {}
205+ : GenericPackedTensorAccessorBase<T, N, PtrTraits, index_t>(
206+ data_,
207+ sizes_,
208+ strides_) {}
171209
172- C10_DEVICE TensorAccessor<T, N - 1 , PtrTraits, index_t > operator [](index_t i) {
210+ C10_DEVICE TensorAccessor<T, N - 1 , PtrTraits, index_t > operator [](
211+ index_t i) {
173212 index_t * new_sizes = this ->sizes_ + 1 ;
174213 index_t * new_strides = this ->strides_ + 1 ;
175- return TensorAccessor<T,N-1 ,PtrTraits,index_t >(this ->data_ + this ->strides_ [0 ]*i, new_sizes, new_strides);
214+ return TensorAccessor<T, N - 1 , PtrTraits, index_t >(
215+ this ->data_ + this ->strides_ [0 ] * i, new_sizes, new_strides);
176216 }
177217
178- C10_DEVICE const TensorAccessor<T, N - 1 , PtrTraits, index_t > operator [](index_t i) const {
218+ C10_DEVICE const TensorAccessor<T, N - 1 , PtrTraits, index_t > operator [](
219+ index_t i) const {
179220 const index_t * new_sizes = this ->sizes_ + 1 ;
180221 const index_t * new_strides = this ->strides_ + 1 ;
181- return TensorAccessor<T,N-1 ,PtrTraits,index_t >(this ->data_ + this ->strides_ [0 ]*i, new_sizes, new_strides);
222+ return TensorAccessor<T, N - 1 , PtrTraits, index_t >(
223+ this ->data_ + this ->strides_ [0 ] * i, new_sizes, new_strides);
182224 }
183225};
184226
185- template <typename T, template <typename U> class PtrTraits , typename index_t >
186- class GenericPackedTensorAccessor <T,1 ,PtrTraits,index_t > : public GenericPackedTensorAccessorBase<T,1 ,PtrTraits,index_t > {
187- public:
227+ template <typename T, template <typename U> class PtrTraits , typename index_t >
228+ class GenericPackedTensorAccessor <T, 1 , PtrTraits, index_t >
229+ : public GenericPackedTensorAccessorBase<T, 1 , PtrTraits, index_t > {
230+ public:
188231 typedef typename PtrTraits<T>::PtrType PtrType;
189232 C10_HOST GenericPackedTensorAccessor (
190233 PtrType data_,
191234 const index_t * sizes_,
192235 const index_t * strides_)
193- : GenericPackedTensorAccessorBase<T, 1, PtrTraits, index_t>(data_, sizes_, strides_) {}
236+ : GenericPackedTensorAccessorBase<T, 1, PtrTraits, index_t>(
237+ data_,
238+ sizes_,
239+ strides_) {}
194240
195- template <typename source_index_t , class = std::enable_if_t <std::is_same_v<source_index_t , int64_t >>>
241+ template <
242+ typename source_index_t ,
243+ class = std::enable_if_t <std::is_same_v<source_index_t , int64_t >>>
196244 C10_HOST GenericPackedTensorAccessor (
197245 PtrType data_,
198246 const source_index_t * sizes_,
199247 const source_index_t * strides_)
200- : GenericPackedTensorAccessorBase<T, 1, PtrTraits, index_t>(data_, sizes_, strides_) {}
248+ : GenericPackedTensorAccessorBase<T, 1, PtrTraits, index_t>(
249+ data_,
250+ sizes_,
251+ strides_) {}
201252
202- C10_DEVICE T & operator [](index_t i) {
253+ C10_DEVICE T& operator [](index_t i) {
203254 return this ->data_ [this ->strides_ [0 ] * i];
204255 }
205256 C10_DEVICE const T& operator [](index_t i) const {
206- return this ->data_ [this ->strides_ [0 ]* i];
257+ return this ->data_ [this ->strides_ [0 ] * i];
207258 }
208-
209259};
210260
211- template <typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits>
212- using PackedTensorAccessor32 = GenericPackedTensorAccessor<T, N, PtrTraits, int32_t >;
261+ template <
262+ typename T,
263+ size_t N,
264+ template <typename U> class PtrTraits = DefaultPtrTraits>
265+ using PackedTensorAccessor32 =
266+ GenericPackedTensorAccessor<T, N, PtrTraits, int32_t >;
213267
214- template <typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits>
215- using PackedTensorAccessor64 = GenericPackedTensorAccessor<T, N, PtrTraits, int64_t >;
268+ template <
269+ typename T,
270+ size_t N,
271+ template <typename U> class PtrTraits = DefaultPtrTraits>
272+ using PackedTensorAccessor64 =
273+ GenericPackedTensorAccessor<T, N, PtrTraits, int64_t >;
216274
217- } // namespace torchaudio::stable
275+ } // namespace torchaudio::stable
0 commit comments