Skip to content

Commit e402ecb

Browse files
committed
parameterize matrix on how it is sliced
1 parent fc1dd88 commit e402ecb

File tree

1 file changed

+165
-89
lines changed

1 file changed

+165
-89
lines changed

inst/include/cpp11/matrix.hpp

Lines changed: 165 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -10,25 +10,94 @@
1010
#include "cpp11/sexp.hpp" // for sexp
1111

1212
namespace cpp11 {
13-
template <typename V, typename T>
14-
class matrix {
13+
14+
// matrix dimensions
15+
struct matrix_dims {
16+
protected:
17+
const int nrow_;
18+
const int ncol_;
19+
20+
public:
21+
matrix_dims(SEXP data) : nrow_(Rf_nrows(data)), ncol_(Rf_ncols(data)) {}
22+
matrix_dims(int nrow, int ncol) : nrow_(nrow), ncol_(ncol) {}
23+
24+
int nrow() const { return nrow_; }
25+
int ncol() const { return ncol_; }
26+
};
27+
28+
// base type for dimension-wise matrix access specialization
29+
struct matrix_slice {};
30+
31+
struct by_row: public matrix_slice {};
32+
struct by_column: public matrix_slice {};
33+
34+
// basic properties of matrix slices
35+
template<typename S>
36+
struct matrix_slices: public matrix_dims {
37+
public:
38+
using matrix_dims::matrix_dims;
39+
using matrix_dims::nrow;
40+
using matrix_dims::ncol;
41+
42+
int nslices() const;
43+
int slice_size() const;
44+
int slice_stride() const;
45+
int slice_offset(int pos) const;
46+
};
47+
48+
// basic properties of matrix row slices
49+
template<>
50+
struct matrix_slices<by_row>: public matrix_dims {
51+
public:
52+
using matrix_dims::matrix_dims;
53+
using matrix_dims::nrow;
54+
using matrix_dims::ncol;
55+
56+
int nslices() const { return nrow(); }
57+
int slice_size() const { return ncol(); }
58+
int slice_stride() const { return ncol(); }
59+
int slice_offset(int pos) const { return pos; }
60+
};
61+
62+
// basic properties of matrix column slices
63+
template<>
64+
struct matrix_slices<by_column>: public matrix_dims {
65+
public:
66+
using matrix_dims::matrix_dims;
67+
using matrix_dims::nrow;
68+
using matrix_dims::ncol;
69+
70+
int nslices() const { return ncol(); }
71+
int slice_size() const { return nrow(); }
72+
int slice_stride() const { return 1; }
73+
int slice_offset(int pos) const { return pos * nrow(); }
74+
};
75+
76+
template <typename V, typename T, typename S>
77+
class matrix: public matrix_slices<S> {
1578
private:
1679
V vector_;
17-
int nrow_;
1880

1981
public:
20-
class row {
82+
83+
// matrix slice: row (if S=by_row) or a column (if S=by_column)
84+
class slice {
2185
private:
2286
matrix& parent_;
23-
int row_;
87+
int index_; // slice index
88+
int offset_; // index of the first element of a slice
2489

2590
public:
26-
row(matrix& parent, R_xlen_t row) : parent_(parent), row_(row) {}
27-
T operator[](const int pos) { return parent_.vector_[row_ + (pos * parent_.nrow_)]; }
91+
slice(matrix& parent, int index) : parent_(parent), index_(index), offset_(parent.slice_offset(index)) {}
92+
inline T operator[](const int pos) const { return parent_.vector_[offset_ + stride() * pos]; }
93+
inline R_xlen_t stride() const { return parent_.slice_stride(); }
94+
inline R_xlen_t size() const { return parent_.slice_size(); }
95+
bool operator!=(const slice& rhs) { return index_ != rhs.index_; }
2896

97+
// iterates elements of a slice
2998
class iterator {
3099
private:
31-
row& row_;
100+
slice& slice_;
32101
int pos_;
33102

34103
public:
@@ -38,91 +107,71 @@ class matrix {
38107
using reference = T&;
39108
using iterator_category = std::forward_iterator_tag;
40109

41-
iterator(row& row, R_xlen_t pos) : row_(row), pos_(pos) {}
42-
iterator begin() const { return row_.parent_.vector_iterator(&this, 0); }
43-
iterator end() const { return iterator(&this, row_.size()); }
110+
iterator(slice& slice, R_xlen_t pos) : slice_(slice), pos_(pos) {}
111+
44112
inline iterator& operator++() {
45113
++pos_;
46114
return *this;
47115
}
48-
bool operator!=(const iterator& rhs) {
49-
return !(pos_ == rhs.pos_ && row_.row_ == rhs.row_.row_);
116+
inline bool operator==(const iterator& rhs) {
117+
return pos_ == rhs.pos_ && slice_.index_ == rhs.slice_.index_;
118+
}
119+
inline bool operator!=(const iterator& rhs) {
120+
return !operator==(rhs);
50121
}
51-
T operator*() const { return row_[pos_]; };
122+
T operator*() const { return slice_[pos_]; };
52123
};
53124

54125
iterator begin() { return iterator(*this, 0); }
55126
iterator end() { return iterator(*this, size()); }
56-
R_xlen_t size() const { return parent_.ncol(); }
57-
bool operator!=(const row& rhs) { return row_ != rhs.row_; }
58-
row& operator++() {
59-
++row_;
60-
return *this;
61-
}
62-
row& operator*() { return *this; }
63127
};
64-
friend row;
128+
friend slice;
65129

66-
class column {
130+
// iterates slices (rows or columns -- depending on S template param) of a matrix
131+
class iterator {
67132
private:
68-
matrix& parent_;
69-
int col_;
70-
int offset_;
133+
matrix& matrix_;
134+
int pos_;
71135

72136
public:
73-
column(matrix& parent, R_xlen_t col) : parent_(parent), col_(col), offset_(parent_.nrow() * col_) {}
74-
T operator[](const int pos) { return parent_.vector_[offset_ + pos]; }
75-
76-
class iterator {
77-
private:
78-
column& col_;
79-
int pos_;
80-
81-
public:
82-
using difference_type = std::ptrdiff_t;
83-
using value_type = T;
84-
using pointer = T*;
85-
using reference = T&;
86-
using iterator_category = std::forward_iterator_tag;
87-
88-
iterator(column& col, R_xlen_t pos) : col_(col), pos_(pos) {}
89-
iterator begin() const { return col_.parent_.vector_iterator(&this, 0); }
90-
iterator end() const { return iterator(&this, col_.size()); }
91-
inline iterator& operator++() {
92-
++pos_;
93-
return *this;
94-
}
95-
bool operator!=(const iterator& rhs) {
96-
return !(pos_ == rhs.pos_ && col_.col_ == rhs.col_.col_);
97-
}
98-
T operator*() const { return col_[pos_]; };
99-
};
100-
101-
iterator begin() { return iterator(*this, 0); }
102-
iterator end() { return iterator(*this, size()); }
103-
R_xlen_t size() const { return parent_.nrow(); }
104-
bool operator!=(const column& rhs) { return col_ != rhs.col_; }
105-
column& operator++() {
106-
++col_;
107-
return *this;
108-
}
109-
column& operator*() { return *this; }
137+
using difference_type = std::ptrdiff_t;
138+
using value_type = slice;
139+
using pointer = slice*;
140+
using reference = slice&;
141+
using iterator_category = std::forward_iterator_tag;
142+
143+
iterator(matrix& matrix, R_xlen_t pos) : matrix_(matrix), pos_(pos) {}
144+
145+
inline iterator& operator++() {
146+
++pos_;
147+
return *this;
148+
}
149+
inline bool operator==(const iterator& rhs) {
150+
return pos_ == rhs.pos_ && (&matrix_) == (&rhs.matrix_);
151+
}
152+
inline bool operator!=(const iterator& rhs) {
153+
return !operator==(rhs);
154+
}
155+
slice operator*() const { return matrix_[pos_]; };
110156
};
111-
friend column;
157+
friend iterator;
112158

113159
public:
114-
matrix(SEXP data) : vector_(data), nrow_(INTEGER_ELT(vector_.attr("dim"), 0)) {}
160+
matrix(SEXP data) : matrix_slices<S>(data), vector_(data) {}
115161

116-
template <typename V2, typename T2>
117-
matrix(const cpp11::matrix<V2, T2>& rhs) : vector_(rhs), nrow_(rhs.nrow()) {}
162+
template <typename V2, typename T2, typename S2>
163+
matrix(const cpp11::matrix<V2, T2, S2>& rhs) : matrix_slices<S>(rhs), vector_(rhs.vector_) {}
118164

119-
matrix(int nrow, int ncol) : vector_(R_xlen_t(nrow * ncol)), nrow_(nrow) {
120-
vector_.attr("dim") = {nrow, ncol};
165+
matrix(int nrow, int ncol) : matrix_slices<S>(nrow, ncol), vector_(R_xlen_t(nrow * ncol)) {
166+
vector_.attr(R_DimSymbol) = {nrow, ncol};
121167
}
122168

123-
int nrow() const { return nrow_; }
124-
125-
int ncol() const { return size() / nrow_; }
169+
using matrix_slices<S>::nrow;
170+
using matrix_slices<S>::ncol;
171+
using matrix_slices<S>::nslices;
172+
using matrix_slices<S>::slice_size;
173+
using matrix_slices<S>::slice_stride;
174+
using matrix_slices<S>::slice_offset;
126175

127176
SEXP data() const { return vector_.data(); }
128177

@@ -140,29 +189,56 @@ class matrix {
140189

141190
r_vector<r_string> names() const { return SEXP(vector_.names()); }
142191

143-
row row_at(const int pos) { return {*this, pos}; }
144-
column column_at(const int pos) { return {*this, pos}; }
192+
T at(int row, int col) const { return vector_[row + (col * nrow())]; }
193+
T operator()(int row, int col) const { return at(row, col); }
194+
slice operator[](int index) { return slice(*this, index); }
145195

146-
T at(int row, int col) { return vector_[row + (col * nrow_)]; }
147-
T operator()(int row, int col) { return at(row, col); }
196+
iterator begin() { return {*this, 0}; }
197+
iterator end() { return {*this, nslices()}; }
198+
};
148199

149-
row rows_begin() { return {*this, 0}; }
150-
row rows_end() { return {*this, nrow()}; }
200+
template<typename S>
201+
using doubles_matrix = matrix<r_vector<double>, double, S>;
202+
template<typename S>
203+
using integers_matrix = matrix<r_vector<int>, int, S>;
204+
template<typename S>
205+
using logicals_matrix = matrix<r_vector<r_bool>, r_bool, S>;
206+
template<typename S>
207+
using strings_matrix = matrix<r_vector<r_string>, r_string, S>;
151208

152-
column columns_begin() { return {*this, 0}; }
153-
column columns_end() { return {*this, ncol()}; }
154-
};
209+
using doubles_row_matrix = doubles_matrix<by_row>;
210+
using doubles_column_matrix = doubles_matrix<by_column>;
211+
212+
using integers_row_matrix = integers_matrix<by_row>;
213+
using integers_column_matrix = integers_matrix<by_column>;
214+
215+
using logicals_row_matrix = logicals_matrix<by_row>;
216+
using logicals_column_matrix = logicals_matrix<by_column>;
155217

156-
using doubles_matrix = matrix<r_vector<double>, double>;
157-
using integers_matrix = matrix<r_vector<int>, int>;
158-
using logicals_matrix = matrix<r_vector<r_bool>, r_bool>;
159-
using strings_matrix = matrix<r_vector<r_string>, r_string>;
218+
using strings_row_matrix = strings_matrix<by_row>;
219+
using strings_column_matrix = strings_matrix<by_column>;
160220

161221
namespace writable {
162-
using doubles_matrix = matrix<r_vector<double>, r_vector<double>::proxy>;
163-
using integers_matrix = matrix<r_vector<int>, r_vector<int>::proxy>;
164-
using logicals_matrix = matrix<r_vector<r_bool>, r_vector<r_bool>::proxy>;
165-
using strings_matrix = matrix<r_vector<r_string>, r_vector<r_string>::proxy>;
222+
template<typename S>
223+
using doubles_matrix = matrix<r_vector<double>, r_vector<double>::proxy, S>;
224+
template<typename S>
225+
using integers_matrix = matrix<r_vector<int>, r_vector<int>::proxy, S>;
226+
template<typename S>
227+
using logicals_matrix = matrix<r_vector<r_bool>, r_vector<r_bool>::proxy, S>;
228+
template<typename S>
229+
using strings_matrix = matrix<r_vector<r_string>, r_vector<r_string>::proxy, S>;
230+
231+
using doubles_row_matrix = doubles_matrix<by_row>;
232+
using doubles_column_matrix = doubles_matrix<by_column>;
233+
234+
using integers_row_matrix = integers_matrix<by_row>;
235+
using integers_column_matrix = integers_matrix<by_column>;
236+
237+
using logicals_row_matrix = logicals_matrix<by_row>;
238+
using logicals_column_matrix = logicals_matrix<by_column>;
239+
240+
using strings_row_matrix = strings_matrix<by_row>;
241+
using strings_column_matrix = strings_matrix<by_column>;
166242
} // namespace writable
167243

168244
// TODO: Add tests for Matrix class

0 commit comments

Comments
 (0)