Skip to content

Commit f47483d

Browse files
committed
Moved to method based parameters
1 parent c2a17a1 commit f47483d

File tree

1 file changed

+11
-10
lines changed

1 file changed

+11
-10
lines changed

src/net.hpp

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,20 @@
44
#include <forward_list>
55

66
namespace nn{
7+
using autodiff::Var;
8+
79
class Net{
810
public:
9-
void backward(const autodiff::Var &loss);
10-
class Parameter : public autodiff::Var{
11-
public:
12-
Parameter(Net* net, const Tensor& data)
13-
: autodiff::Var(data){ net->param_list.push_front(this); }
14-
Parameter(Net* net ,size_t x, size_t y=1, size_t z=1, size_t t=1)
15-
: autodiff::Var(x,y,z,t){ net->param_list.push_front(this); }
16-
};
17-
std::forward_list<Parameter*>& params() {return param_list;}
11+
void backward(const Var &loss);
12+
13+
Var& create_parameter(const Tensor& data) {
14+
parameters.push_front(Parameter(data));
15+
return &parameters[0];
16+
}
17+
18+
std::forward_list<Var>& params() { return parameters; }
1819
protected:
19-
std::forward_list<Parameter*> param_list;
20+
std::forward_list<Var> parameters;
2021
};
2122
} // namespace nn
2223
#endif // NET_H

0 commit comments

Comments
 (0)