10
10
#include " cpp11/sexp.hpp" // for sexp
11
11
12
12
namespace cpp11 {
13
- template <typename V, typename T>
14
- class matrix {
13
+
14
+ // matrix dimensions
15
+ struct matrix_dims {
16
+ protected:
17
+ const int nrow_;
18
+ const int ncol_;
19
+
20
+ public:
21
+ matrix_dims (SEXP data) : nrow_(Rf_nrows(data)), ncol_(Rf_ncols(data)) {}
22
+ matrix_dims (int nrow, int ncol) : nrow_(nrow), ncol_(ncol) {}
23
+
24
+ int nrow () const { return nrow_; }
25
+ int ncol () const { return ncol_; }
26
+ };
27
+
28
+ // base type for dimension-wise matrix access specialization
29
+ struct matrix_slice {};
30
+
31
+ struct by_row : public matrix_slice {};
32
+ struct by_column : public matrix_slice {};
33
+
34
+ // basic properties of matrix slices
35
+ template <typename S>
36
+ struct matrix_slices : public matrix_dims {
37
+ public:
38
+ using matrix_dims::matrix_dims;
39
+ using matrix_dims::nrow;
40
+ using matrix_dims::ncol;
41
+
42
+ int nslices () const ;
43
+ int slice_size () const ;
44
+ int slice_stride () const ;
45
+ int slice_offset (int pos) const ;
46
+ };
47
+
48
+ // basic properties of matrix row slices
49
+ template <>
50
+ struct matrix_slices <by_row>: public matrix_dims {
51
+ public:
52
+ using matrix_dims::matrix_dims;
53
+ using matrix_dims::nrow;
54
+ using matrix_dims::ncol;
55
+
56
+ int nslices () const { return nrow (); }
57
+ int slice_size () const { return ncol (); }
58
+ int slice_stride () const { return ncol (); }
59
+ int slice_offset (int pos) const { return pos; }
60
+ };
61
+
62
+ // basic properties of matrix column slices
63
+ template <>
64
+ struct matrix_slices <by_column>: public matrix_dims {
65
+ public:
66
+ using matrix_dims::matrix_dims;
67
+ using matrix_dims::nrow;
68
+ using matrix_dims::ncol;
69
+
70
+ int nslices () const { return ncol (); }
71
+ int slice_size () const { return nrow (); }
72
+ int slice_stride () const { return 1 ; }
73
+ int slice_offset (int pos) const { return pos * nrow (); }
74
+ };
75
+
76
+ template <typename V, typename T, typename S>
77
+ class matrix : public matrix_slices <S> {
15
78
private:
16
79
V vector_;
17
- int nrow_;
18
80
19
81
public:
20
- class row {
82
+
83
+ // matrix slice: row (if S=by_row) or a column (if S=by_column)
84
+ class slice {
21
85
private:
22
86
matrix& parent_;
23
- int row_;
87
+ int index_; // slice index
88
+ int offset_; // index of the first element of a slice
24
89
25
90
public:
26
- row (matrix& parent, R_xlen_t row) : parent_(parent), row_(row) {}
27
- T operator [](const int pos) { return parent_.vector_ [row_ + (pos * parent_.nrow_ )]; }
91
+ slice (matrix& parent, int index) : parent_(parent), index_(index), offset_(parent.slice_offset(index)) {}
92
+ inline T operator [](const int pos) const { return parent_.vector_ [offset_ + stride () * pos]; }
93
+ inline R_xlen_t stride () const { return parent_.slice_stride (); }
94
+ inline R_xlen_t size () const { return parent_.slice_size (); }
95
+ bool operator !=(const slice& rhs) { return index_ != rhs.index_ ; }
28
96
97
+ // iterates elements of a slice
29
98
class iterator {
30
99
private:
31
- row& row_ ;
100
+ slice& slice_ ;
32
101
int pos_;
33
102
34
103
public:
@@ -38,91 +107,71 @@ class matrix {
38
107
using reference = T&;
39
108
using iterator_category = std::forward_iterator_tag;
40
109
41
- iterator (row& row, R_xlen_t pos) : row_(row), pos_(pos) {}
42
- iterator begin () const { return row_.parent_ .vector_iterator (&this , 0 ); }
43
- iterator end () const { return iterator (&this , row_.size ()); }
110
+ iterator (slice& slice, R_xlen_t pos) : slice_(slice), pos_(pos) {}
111
+
44
112
inline iterator& operator ++() {
45
113
++pos_;
46
114
return *this ;
47
115
}
48
- bool operator !=(const iterator& rhs) {
49
- return !(pos_ == rhs.pos_ && row_.row_ == rhs.row_ .row_ );
116
+ inline bool operator ==(const iterator& rhs) {
117
+ return pos_ == rhs.pos_ && slice_.index_ == rhs.slice_ .index_ ;
118
+ }
119
+ inline bool operator !=(const iterator& rhs) {
120
+ return !operator ==(rhs);
50
121
}
51
- T operator *() const { return row_ [pos_]; };
122
+ T operator *() const { return slice_ [pos_]; };
52
123
};
53
124
54
125
iterator begin () { return iterator (*this , 0 ); }
55
126
iterator end () { return iterator (*this , size ()); }
56
- R_xlen_t size () const { return parent_.ncol (); }
57
- bool operator !=(const row& rhs) { return row_ != rhs.row_ ; }
58
- row& operator ++() {
59
- ++row_;
60
- return *this ;
61
- }
62
- row& operator *() { return *this ; }
63
127
};
64
- friend row ;
128
+ friend slice ;
65
129
66
- class column {
130
+ // iterates slices (rows or columns -- depending on S template param) of a matrix
131
+ class iterator {
67
132
private:
68
- matrix& parent_;
69
- int col_;
70
- int offset_;
133
+ matrix& matrix_;
134
+ int pos_;
71
135
72
136
public:
73
- column (matrix& parent, R_xlen_t col) : parent_(parent), col_(col), offset_(parent_.nrow() * col_) {}
74
- T operator [](const int pos) { return parent_.vector_ [offset_ + pos]; }
75
-
76
- class iterator {
77
- private:
78
- column& col_;
79
- int pos_;
80
-
81
- public:
82
- using difference_type = std::ptrdiff_t ;
83
- using value_type = T;
84
- using pointer = T*;
85
- using reference = T&;
86
- using iterator_category = std::forward_iterator_tag;
87
-
88
- iterator (column& col, R_xlen_t pos) : col_(col), pos_(pos) {}
89
- iterator begin () const { return col_.parent_ .vector_iterator (&this , 0 ); }
90
- iterator end () const { return iterator (&this , col_.size ()); }
91
- inline iterator& operator ++() {
92
- ++pos_;
93
- return *this ;
94
- }
95
- bool operator !=(const iterator& rhs) {
96
- return !(pos_ == rhs.pos_ && col_.col_ == rhs.col_ .col_ );
97
- }
98
- T operator *() const { return col_[pos_]; };
99
- };
100
-
101
- iterator begin () { return iterator (*this , 0 ); }
102
- iterator end () { return iterator (*this , size ()); }
103
- R_xlen_t size () const { return parent_.nrow (); }
104
- bool operator !=(const column& rhs) { return col_ != rhs.col_ ; }
105
- column& operator ++() {
106
- ++col_;
107
- return *this ;
108
- }
109
- column& operator *() { return *this ; }
137
+ using difference_type = std::ptrdiff_t ;
138
+ using value_type = slice;
139
+ using pointer = slice*;
140
+ using reference = slice&;
141
+ using iterator_category = std::forward_iterator_tag;
142
+
143
+ iterator (matrix& matrix, R_xlen_t pos) : matrix_(matrix), pos_(pos) {}
144
+
145
+ inline iterator& operator ++() {
146
+ ++pos_;
147
+ return *this ;
148
+ }
149
+ inline bool operator ==(const iterator& rhs) {
150
+ return pos_ == rhs.pos_ && (&matrix_) == (&rhs.matrix_ );
151
+ }
152
+ inline bool operator !=(const iterator& rhs) {
153
+ return !operator ==(rhs);
154
+ }
155
+ slice operator *() const { return matrix_[pos_]; };
110
156
};
111
- friend column ;
157
+ friend iterator ;
112
158
113
159
public:
114
- matrix (SEXP data) : vector_ (data), nrow_(INTEGER_ELT( vector_.attr( " dim " ), 0 ) ) {}
160
+ matrix (SEXP data) : matrix_slices<S> (data), vector_(data ) {}
115
161
116
- template <typename V2, typename T2>
117
- matrix (const cpp11::matrix<V2, T2>& rhs) : vector_ (rhs), nrow_ (rhs.nrow() ) {}
162
+ template <typename V2, typename T2, typename S2 >
163
+ matrix (const cpp11::matrix<V2, T2, S2 >& rhs) : matrix_slices<S> (rhs), vector_ (rhs.vector_ ) {}
118
164
119
- matrix (int nrow, int ncol) : vector_(R_xlen_t(nrow * ncol)), nrow_(nrow ) {
120
- vector_.attr (" dim " ) = {nrow, ncol};
165
+ matrix (int nrow, int ncol) : matrix_slices<S>(nrow, ncol), vector_(R_xlen_t(nrow * ncol)) {
166
+ vector_.attr (R_DimSymbol ) = {nrow, ncol};
121
167
}
122
168
123
- int nrow () const { return nrow_; }
124
-
125
- int ncol () const { return size () / nrow_; }
169
+ using matrix_slices<S>::nrow;
170
+ using matrix_slices<S>::ncol;
171
+ using matrix_slices<S>::nslices;
172
+ using matrix_slices<S>::slice_size;
173
+ using matrix_slices<S>::slice_stride;
174
+ using matrix_slices<S>::slice_offset;
126
175
127
176
SEXP data () const { return vector_.data (); }
128
177
@@ -140,29 +189,56 @@ class matrix {
140
189
141
190
r_vector<r_string> names () const { return SEXP (vector_.names ()); }
142
191
143
- row row_at (const int pos) { return {*this , pos}; }
144
- column column_at (const int pos) { return {*this , pos}; }
192
+ T at (int row, int col) const { return vector_[row + (col * nrow ())]; }
193
+ T operator ()(int row, int col) const { return at (row, col); }
194
+ slice operator [](int index) { return slice (*this , index); }
145
195
146
- T at (int row, int col) { return vector_[row + (col * nrow_)]; }
147
- T operator ()(int row, int col) { return at (row, col); }
196
+ iterator begin () { return {*this , 0 }; }
197
+ iterator end () { return {*this , nslices ()}; }
198
+ };
148
199
149
- row rows_begin () { return {*this , 0 }; }
150
- row rows_end () { return {*this , nrow ()}; }
200
+ template <typename S>
201
+ using doubles_matrix = matrix<r_vector<double >, double , S>;
202
+ template <typename S>
203
+ using integers_matrix = matrix<r_vector<int >, int , S>;
204
+ template <typename S>
205
+ using logicals_matrix = matrix<r_vector<r_bool>, r_bool, S>;
206
+ template <typename S>
207
+ using strings_matrix = matrix<r_vector<r_string>, r_string, S>;
151
208
152
- column columns_begin () { return {*this , 0 }; }
153
- column columns_end () { return {*this , ncol ()}; }
154
- };
209
+ using doubles_row_matrix = doubles_matrix<by_row>;
210
+ using doubles_column_matrix = doubles_matrix<by_column>;
211
+
212
+ using integers_row_matrix = integers_matrix<by_row>;
213
+ using integers_column_matrix = integers_matrix<by_column>;
214
+
215
+ using logicals_row_matrix = logicals_matrix<by_row>;
216
+ using logicals_column_matrix = logicals_matrix<by_column>;
155
217
156
- using doubles_matrix = matrix<r_vector<double >, double >;
157
- using integers_matrix = matrix<r_vector<int >, int >;
158
- using logicals_matrix = matrix<r_vector<r_bool>, r_bool>;
159
- using strings_matrix = matrix<r_vector<r_string>, r_string>;
218
+ using strings_row_matrix = strings_matrix<by_row>;
219
+ using strings_column_matrix = strings_matrix<by_column>;
160
220
161
221
namespace writable {
162
- using doubles_matrix = matrix<r_vector<double >, r_vector<double >::proxy>;
163
- using integers_matrix = matrix<r_vector<int >, r_vector<int >::proxy>;
164
- using logicals_matrix = matrix<r_vector<r_bool>, r_vector<r_bool>::proxy>;
165
- using strings_matrix = matrix<r_vector<r_string>, r_vector<r_string>::proxy>;
222
+ template <typename S>
223
+ using doubles_matrix = matrix<r_vector<double >, r_vector<double >::proxy, S>;
224
+ template <typename S>
225
+ using integers_matrix = matrix<r_vector<int >, r_vector<int >::proxy, S>;
226
+ template <typename S>
227
+ using logicals_matrix = matrix<r_vector<r_bool>, r_vector<r_bool>::proxy, S>;
228
+ template <typename S>
229
+ using strings_matrix = matrix<r_vector<r_string>, r_vector<r_string>::proxy, S>;
230
+
231
+ using doubles_row_matrix = doubles_matrix<by_row>;
232
+ using doubles_column_matrix = doubles_matrix<by_column>;
233
+
234
+ using integers_row_matrix = integers_matrix<by_row>;
235
+ using integers_column_matrix = integers_matrix<by_column>;
236
+
237
+ using logicals_row_matrix = logicals_matrix<by_row>;
238
+ using logicals_column_matrix = logicals_matrix<by_column>;
239
+
240
+ using strings_row_matrix = strings_matrix<by_row>;
241
+ using strings_column_matrix = strings_matrix<by_column>;
166
242
} // namespace writable
167
243
168
244
// TODO: Add tests for Matrix class
0 commit comments