Skip to content

Commit 3cbf8e1

Browse files
authored
Merge pull request #2833 from Franzi2114/feature/issue-2682-operands-and-partials-8
Feature/issue 2682 operands and partials 8
2 parents 95abd90 + bbdd1fb commit 3cbf8e1

File tree

6 files changed

+153
-37
lines changed

6 files changed

+153
-37
lines changed

stan/math/fwd/functor/operands_and_partials.hpp

Lines changed: 50 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@ class ops_partials_edge<Dx, fvar<Dx>> {
2222
: partial_(0), partials_(partial_), operand_(op) {}
2323

2424
private:
25-
template <typename, typename, typename, typename, typename, typename>
25+
template <typename, typename, typename, typename, typename, typename,
26+
typename, typename, typename>
2627
friend class stan::math::operands_and_partials;
2728
const Op& operand_;
2829

@@ -62,19 +63,25 @@ class ops_partials_edge<Dx, fvar<Dx>> {
6263
* @tparam Op3 type of the third operand
6364
* @tparam Op4 type of the fourth operand
6465
* @tparam Op5 type of the fifth operand
66+
* @tparam Op6 type of the sixth operand
67+
* @tparam Op7 type of the seventh operand
68+
* @tparam Op8 type of the eighth operand
6569
* @tparam T_return_type return type of the expression. This defaults
6670
* to a template metaprogram that calculates the scalar promotion of
67-
* Op1 -- Op5
71+
* Op1 -- Op8
6872
*/
6973
template <typename Op1, typename Op2, typename Op3, typename Op4, typename Op5,
70-
typename Dx>
71-
class operands_and_partials<Op1, Op2, Op3, Op4, Op5, fvar<Dx>> {
74+
typename Op6, typename Op7, typename Op8, typename Dx>
75+
class operands_and_partials<Op1, Op2, Op3, Op4, Op5, Op6, Op7, Op8, fvar<Dx>> {
7276
public:
7377
internal::ops_partials_edge<Dx, std::decay_t<Op1>> edge1_;
7478
internal::ops_partials_edge<Dx, std::decay_t<Op2>> edge2_;
7579
internal::ops_partials_edge<Dx, std::decay_t<Op3>> edge3_;
7680
internal::ops_partials_edge<Dx, std::decay_t<Op4>> edge4_;
7781
internal::ops_partials_edge<Dx, std::decay_t<Op5>> edge5_;
82+
internal::ops_partials_edge<Dx, std::decay_t<Op6>> edge6_;
83+
internal::ops_partials_edge<Dx, std::decay_t<Op7>> edge7_;
84+
internal::ops_partials_edge<Dx, std::decay_t<Op8>> edge8_;
7885
using T_return_type = fvar<Dx>;
7986
explicit operands_and_partials(const Op1& o1) : edge1_(o1) {}
8087
operands_and_partials(const Op1& o1, const Op2& o2)
@@ -87,6 +94,35 @@ class operands_and_partials<Op1, Op2, Op3, Op4, Op5, fvar<Dx>> {
8794
operands_and_partials(const Op1& o1, const Op2& o2, const Op3& o3,
8895
const Op4& o4, const Op5& o5)
8996
: edge1_(o1), edge2_(o2), edge3_(o3), edge4_(o4), edge5_(o5) {}
97+
operands_and_partials(const Op1& o1, const Op2& o2, const Op3& o3,
98+
const Op4& o4, const Op5& o5, const Op6& o6)
99+
: edge1_(o1),
100+
edge2_(o2),
101+
edge3_(o3),
102+
edge4_(o4),
103+
edge5_(o5),
104+
edge6_(o6) {}
105+
operands_and_partials(const Op1& o1, const Op2& o2, const Op3& o3,
106+
const Op4& o4, const Op5& o5, const Op6& o6,
107+
const Op7& o7)
108+
: edge1_(o1),
109+
edge2_(o2),
110+
edge3_(o3),
111+
edge4_(o4),
112+
edge5_(o5),
113+
edge6_(o6),
114+
edge7_(o7) {}
115+
operands_and_partials(const Op1& o1, const Op2& o2, const Op3& o3,
116+
const Op4& o4, const Op5& o5, const Op6& o6,
117+
const Op7& o7, const Op8& o8)
118+
: edge1_(o1),
119+
edge2_(o2),
120+
edge3_(o3),
121+
edge4_(o4),
122+
edge5_(o5),
123+
edge6_(o6),
124+
edge7_(o7),
125+
edge8_(o8) {}
90126

91127
/** \ingroup type_trait
92128
* Build the node to be stored on the autodiff graph.
@@ -102,8 +138,8 @@ class operands_and_partials<Op1, Op2, Op3, Op4, Op5, fvar<Dx>> {
102138
* @return the value with its derivative
103139
*/
104140
T_return_type build(Dx value) {
105-
Dx deriv
106-
= edge1_.dx() + edge2_.dx() + edge3_.dx() + edge4_.dx() + edge5_.dx();
141+
Dx deriv = edge1_.dx() + edge2_.dx() + edge3_.dx() + edge4_.dx()
142+
+ edge5_.dx() + edge6_.dx() + edge7_.dx() + edge8_.dx();
107143
return T_return_type(value, deriv);
108144
}
109145
};
@@ -124,7 +160,8 @@ class ops_partials_edge<Dx, std::vector<fvar<Dx>>> {
124160
operands_(ops) {}
125161

126162
private:
127-
template <typename, typename, typename, typename, typename, typename>
163+
template <typename, typename, typename, typename, typename, typename,
164+
typename, typename, typename>
128165
friend class stan::math::operands_and_partials;
129166
const Op& operands_;
130167

@@ -150,7 +187,8 @@ class ops_partials_edge<Dx, Eigen::Matrix<fvar<Dx>, R, C>> {
150187
operands_(ops) {}
151188

152189
private:
153-
template <typename, typename, typename, typename, typename, typename>
190+
template <typename, typename, typename, typename, typename, typename,
191+
typename, typename, typename>
154192
friend class stan::math::operands_and_partials;
155193
const Op& operands_;
156194

@@ -178,7 +216,8 @@ class ops_partials_edge<Dx, std::vector<Eigen::Matrix<fvar<Dx>, R, C>>> {
178216
}
179217

180218
private:
181-
template <typename, typename, typename, typename, typename, typename>
219+
template <typename, typename, typename, typename, typename, typename,
220+
typename, typename, typename>
182221
friend class stan::math::operands_and_partials;
183222
const Op& operands_;
184223

@@ -207,7 +246,8 @@ class ops_partials_edge<Dx, std::vector<std::vector<fvar<Dx>>>> {
207246
}
208247

209248
private:
210-
template <typename, typename, typename, typename, typename, typename>
249+
template <typename, typename, typename, typename, typename, typename,
250+
typename, typename, typename>
211251
friend class stan::math::operands_and_partials;
212252
const Op& operands_;
213253

stan/math/opencl/rev/operands_and_partials.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@ class ops_partials_edge<double, var_value<Op>,
2727
operands_(ops) {}
2828

2929
private:
30-
template <typename, typename, typename, typename, typename, typename>
30+
template <typename, typename, typename, typename, typename, typename,
31+
typename, typename, typename>
3132
friend class stan::math::operands_and_partials;
3233
var_value<Op> operands_;
3334
static constexpr int size() noexcept { return 0; }

stan/math/prim/functor/operands_and_partials.hpp

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,10 @@
1212
namespace stan {
1313
namespace math {
1414
template <typename Op1 = double, typename Op2 = double, typename Op3 = double,
15-
typename Op4 = double, typename Op5 = double,
16-
typename T_return_type = return_type_t<Op1, Op2, Op3, Op4, Op5>>
15+
typename Op4 = double, typename Op5 = double, typename Op6 = double,
16+
typename Op7 = double, typename Op8 = double,
17+
typename T_return_type
18+
= return_type_t<Op1, Op2, Op3, Op4, Op5, Op6, Op7, Op8>>
1719
class operands_and_partials; // Forward declaration
1820

1921
namespace internal {
@@ -70,7 +72,8 @@ class ops_partials_edge<ViewElt, Op, require_st_arithmetic<Op>> {
7072
static constexpr int size() noexcept { return 0; } // reverse mode
7173

7274
private:
73-
template <typename, typename, typename, typename, typename, typename>
75+
template <typename, typename, typename, typename, typename, typename,
76+
typename, typename, typename>
7477
friend class stan::math::operands_and_partials;
7578
};
7679
template <typename ViewElt, typename Op>
@@ -100,7 +103,7 @@ constexpr double
100103
*
101104
* This base template is instantiated when all operands are
102105
* primitives and we don't want to calculate derivatives at
103-
* all. So all Op1 - Op5 must be arithmetic primitives
106+
* all. So all Op1 - Op8 must be arithmetic primitives
104107
* like int or double. This is controlled with the
105108
* T_return_type type parameter.
106109
*
@@ -109,12 +112,15 @@ constexpr double
109112
* @tparam Op3 type of the third operand
110113
* @tparam Op4 type of the fourth operand
111114
* @tparam Op5 type of the fifth operand
115+
* @tparam Op6 type of the sixth operand
116+
* @tparam Op7 type of the seventh operand
117+
* @tparam Op8 type of the eighth operand
112118
* @tparam T_return_type return type of the expression. This defaults
113119
* to calling a template metaprogram that calculates the scalar
114-
* promotion of Op1..Op4
120+
* promotion of Op1..Op8
115121
*/
116122
template <typename Op1, typename Op2, typename Op3, typename Op4, typename Op5,
117-
typename T_return_type>
123+
typename Op6, typename Op7, typename Op8, typename T_return_type>
118124
class operands_and_partials {
119125
public:
120126
explicit operands_and_partials(const Op1& /* op1 */) noexcept {}
@@ -126,6 +132,17 @@ class operands_and_partials {
126132
operands_and_partials(const Op1& /* op1 */, const Op2& /* op2 */,
127133
const Op3& /* op3 */, const Op4& /* op4 */,
128134
const Op5& /* op5 */) noexcept {}
135+
operands_and_partials(const Op1& /* op1 */, const Op2& /* op2 */,
136+
const Op3& /* op3 */, const Op4& /* op4 */,
137+
const Op5& /* op5 */, const Op6& /* op6 */) noexcept {}
138+
operands_and_partials(const Op1& /* op1 */, const Op2& /* op2 */,
139+
const Op3& /* op3 */, const Op4& /* op4 */,
140+
const Op5& /* op5 */, const Op6& /* op6 */,
141+
const Op7& /* op7 */) noexcept {}
142+
operands_and_partials(const Op1& /* op1 */, const Op2& /* op2 */,
143+
const Op3& /* op3 */, const Op4& /* op4 */,
144+
const Op5& /* op5 */, const Op6& /* op6 */,
145+
const Op7& /* op7 */, const Op8& /* op8 */) noexcept {}
129146

130147
/** \ingroup type_trait
131148
* Build the node to be stored on the autodiff graph.
@@ -148,6 +165,9 @@ class operands_and_partials {
148165
internal::ops_partials_edge<double, std::decay_t<Op3>> edge3_;
149166
internal::ops_partials_edge<double, std::decay_t<Op4>> edge4_;
150167
internal::ops_partials_edge<double, std::decay_t<Op5>> edge5_;
168+
internal::ops_partials_edge<double, std::decay_t<Op6>> edge6_;
169+
internal::ops_partials_edge<double, std::decay_t<Op7>> edge7_;
170+
internal::ops_partials_edge<double, std::decay_t<Op8>> edge8_;
151171
};
152172

153173
} // namespace math

stan/math/rev/functor/operands_and_partials.hpp

Lines changed: 66 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@ class ops_partials_edge<double, var> {
3434
: partial_(0), partials_(partial_), operand_(op) {}
3535

3636
private:
37-
template <typename, typename, typename, typename, typename, typename>
37+
template <typename, typename, typename, typename, typename, typename,
38+
typename, typename, typename>
3839
friend class stan::math::operands_and_partials;
3940
var operand_;
4041
static constexpr int size() noexcept { return 1; }
@@ -109,15 +110,22 @@ inline void update_adjoints(StdVec1& x, const Vec2& y, const vari& z) {
109110
* @tparam Op3 type of the third operand
110111
* @tparam Op4 type of the fourth operand
111112
* @tparam Op5 type of the fifth operand
113+
* @tparam Op6 type of the sixth operand
114+
* @tparam Op7 type of the seventh operand
115+
* @tparam Op8 type of the eighth operand
112116
*/
113-
template <typename Op1, typename Op2, typename Op3, typename Op4, typename Op5>
114-
class operands_and_partials<Op1, Op2, Op3, Op4, Op5, var> {
117+
template <typename Op1, typename Op2, typename Op3, typename Op4, typename Op5,
118+
typename Op6, typename Op7, typename Op8>
119+
class operands_and_partials<Op1, Op2, Op3, Op4, Op5, Op6, Op7, Op8, var> {
115120
public:
116121
internal::ops_partials_edge<double, std::decay_t<Op1>> edge1_;
117122
internal::ops_partials_edge<double, std::decay_t<Op2>> edge2_;
118123
internal::ops_partials_edge<double, std::decay_t<Op3>> edge3_;
119124
internal::ops_partials_edge<double, std::decay_t<Op4>> edge4_;
120125
internal::ops_partials_edge<double, std::decay_t<Op5>> edge5_;
126+
internal::ops_partials_edge<double, std::decay_t<Op6>> edge6_;
127+
internal::ops_partials_edge<double, std::decay_t<Op7>> edge7_;
128+
internal::ops_partials_edge<double, std::decay_t<Op8>> edge8_;
121129

122130
explicit operands_and_partials(const Op1& o1) : edge1_(o1) {}
123131
operands_and_partials(const Op1& o1, const Op2& o2)
@@ -130,6 +138,35 @@ class operands_and_partials<Op1, Op2, Op3, Op4, Op5, var> {
130138
operands_and_partials(const Op1& o1, const Op2& o2, const Op3& o3,
131139
const Op4& o4, const Op5& o5)
132140
: edge1_(o1), edge2_(o2), edge3_(o3), edge4_(o4), edge5_(o5) {}
141+
operands_and_partials(const Op1& o1, const Op2& o2, const Op3& o3,
142+
const Op4& o4, const Op5& o5, const Op6& o6)
143+
: edge1_(o1),
144+
edge2_(o2),
145+
edge3_(o3),
146+
edge4_(o4),
147+
edge5_(o5),
148+
edge6_(o6) {}
149+
operands_and_partials(const Op1& o1, const Op2& o2, const Op3& o3,
150+
const Op4& o4, const Op5& o5, const Op6& o6,
151+
const Op7& o7)
152+
: edge1_(o1),
153+
edge2_(o2),
154+
edge3_(o3),
155+
edge4_(o4),
156+
edge5_(o5),
157+
edge6_(o6),
158+
edge7_(o7) {}
159+
operands_and_partials(const Op1& o1, const Op2& o2, const Op3& o3,
160+
const Op4& o4, const Op5& o5, const Op6& o6,
161+
const Op7& o7, const Op8& o8)
162+
: edge1_(o1),
163+
edge2_(o2),
164+
edge3_(o3),
165+
edge4_(o4),
166+
edge5_(o5),
167+
edge6_(o6),
168+
edge7_(o7),
169+
edge8_(o8) {}
133170

134171
/** \ingroup type_trait
135172
* Build the node to be stored on the autodiff graph.
@@ -150,8 +187,11 @@ class operands_and_partials<Op1, Op2, Op3, Op4, Op5, var> {
150187
operand2 = edge2_.operand(), partial2 = edge2_.partial(),
151188
operand3 = edge3_.operand(), partial3 = edge3_.partial(),
152189
operand4 = edge4_.operand(), partial4 = edge4_.partial(),
153-
operand5 = edge5_.operand(),
154-
partial5 = edge5_.partial()](const auto& vi) mutable {
190+
operand5 = edge5_.operand(), partial5 = edge5_.partial(),
191+
operand6 = edge6_.operand(), partial6 = edge6_.partial(),
192+
operand7 = edge7_.operand(), partial7 = edge7_.partial(),
193+
operand8 = edge8_.operand(),
194+
partial8 = edge8_.partial()](const auto& vi) mutable {
155195
if (!is_constant<Op1>::value) {
156196
internal::update_adjoints(operand1, partial1, vi);
157197
}
@@ -167,6 +207,15 @@ class operands_and_partials<Op1, Op2, Op3, Op4, Op5, var> {
167207
if (!is_constant<Op5>::value) {
168208
internal::update_adjoints(operand5, partial5, vi);
169209
}
210+
if (!is_constant<Op6>::value) {
211+
internal::update_adjoints(operand6, partial6, vi);
212+
}
213+
if (!is_constant<Op7>::value) {
214+
internal::update_adjoints(operand7, partial7, vi);
215+
}
216+
if (!is_constant<Op8>::value) {
217+
internal::update_adjoints(operand8, partial8, vi);
218+
}
170219
});
171220
}
172221
};
@@ -186,7 +235,8 @@ class ops_partials_edge<double, std::vector<var>> {
186235
operands_(op.begin(), op.end()) {}
187236

188237
private:
189-
template <typename, typename, typename, typename, typename, typename>
238+
template <typename, typename, typename, typename, typename, typename,
239+
typename, typename, typename>
190240
friend class stan::math::operands_and_partials;
191241
Op operands_;
192242

@@ -207,7 +257,8 @@ class ops_partials_edge<double, Op, require_eigen_st<is_var, Op>> {
207257
operands_(ops) {}
208258

209259
private:
210-
template <typename, typename, typename, typename, typename, typename>
260+
template <typename, typename, typename, typename, typename, typename,
261+
typename, typename, typename>
211262
friend class stan::math::operands_and_partials;
212263
arena_t<Op> operands_;
213264
inline int size() const noexcept { return this->operands_.size(); }
@@ -228,7 +279,8 @@ class ops_partials_edge<double, var_value<Op>, require_eigen_t<Op>> {
228279
operands_(ops) {}
229280

230281
private:
231-
template <typename, typename, typename, typename, typename, typename>
282+
template <typename, typename, typename, typename, typename, typename,
283+
typename, typename, typename>
232284
friend class stan::math::operands_and_partials;
233285
var_value<Op> operands_;
234286

@@ -256,7 +308,8 @@ class ops_partials_edge<double, std::vector<Eigen::Matrix<var, R, C>>> {
256308
}
257309

258310
private:
259-
template <typename, typename, typename, typename, typename, typename>
311+
template <typename, typename, typename, typename, typename, typename,
312+
typename, typename, typename>
260313
friend class stan::math::operands_and_partials;
261314
Op operands_;
262315

@@ -286,7 +339,8 @@ class ops_partials_edge<double, std::vector<std::vector<var>>> {
286339
}
287340

288341
private:
289-
template <typename, typename, typename, typename, typename, typename>
342+
template <typename, typename, typename, typename, typename, typename,
343+
typename, typename, typename>
290344
friend class stan::math::operands_and_partials;
291345
Op operands_;
292346
inline int size() const noexcept {
@@ -311,7 +365,8 @@ class ops_partials_edge<double, std::vector<var_value<Op>>,
311365
}
312366

313367
private:
314-
template <typename, typename, typename, typename, typename, typename>
368+
template <typename, typename, typename, typename, typename, typename,
369+
typename, typename, typename>
315370
friend class stan::math::operands_and_partials;
316371
std::vector<var_value<Op>, arena_allocator<var_value<Op>>> operands_;
317372

test/unit/math/prim/functor/operands_and_partials_test.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@ TEST(MathMetaPrim, OperandsAndPartials) {
77
operands_and_partials<double> o1(1.0);
88
operands_and_partials<double, double, double, double> o2(2.0, 3.0, 4.0, 5.0);
99

10-
// This is size 10 because of the two empty broadcast arrays in each edge
11-
EXPECT_EQ(10, sizeof(o2));
10+
// This is size 16 because of the two empty broadcast arrays in each edge
11+
EXPECT_EQ(16, sizeof(o2));
1212

1313
EXPECT_FLOAT_EQ(27.1, o1.build(27.1));
1414
}

0 commit comments

Comments
 (0)