-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathlasso_train.h
45 lines (39 loc) · 1.02 KB
/
lasso_train.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
/// Free-function declaration; no definition is visible in this header.
/// Takes an ndarray by value together with an integer `patch_size` and
/// returns an ndarray by value.
/// NOTE(review): presumably builds patches of size `patch_size` from `x` --
/// confirm the exact layout against the definition.
ndarray make_patch(ndarray x, int patch_size);
/// Base solver class. Declares a constructor taking (Y, X, lambda, rho), a
/// set of public ndarray/float members, and one virtual customization point,
/// generate_transform_matrix(), which derived classes may replace.
/// NOTE(review): the name suggests the Alternating Direction Method of
/// Multipliers -- confirm against the implementation file.
///
/// Member declaration order is kept as-is on purpose: it determines the
/// order in which members are initialized by the constructor.
class admm
{
private:
    // Internal working state. NOTE(review): names suggest cached inverse
    // matrices and per-iteration variables -- confirm against the definitions.
    ndarray inv_;
    ndarray w_t;
    ndarray z_t;
    ndarray h_t;
    // Zero-initialized so it never holds an indeterminate value if the
    // constructor does not set it.
    float thresh = 0.0f;
    ndarray inv_matrix_XTy;
    ndarray inv_matrix_DT;

public:
    ndarray D;

    /// @param Y       ndarray passed by non-const reference.
    /// @param X       ndarray passed by non-const reference.
    /// @param lambda  float parameter (same name as the `lambda` member).
    /// @param rho     float parameter (same name as the `rho` member).
    admm(ndarray &Y, ndarray &X, float lambda, float rho);

    /// Virtual because this class is used as a polymorphic base: without it,
    /// deleting a derived object through an admm* is undefined behaviour.
    /// Note: declaring a destructor suppresses the implicit move operations,
    /// so moving an admm falls back to the (still implicit) copy operations.
    virtual ~admm() = default;

    ndarray X;
    ndarray Y;
    ndarray debug;
    // Defaulted to zero so they are never read uninitialized; the constructor
    // is expected to overwrite them.
    float lambda = 0.0f;
    float rho = 0.0f;

    void fit();
    void init();

    /// Customization point for derived classes.
    /// NOTE(review): if this is called (directly or via init()) from the admm
    /// constructor, the base version runs rather than a derived override --
    /// confirm the call site in the implementation file.
    /// @param n_features  integer size argument.
    virtual void generate_transform_matrix(int n_features);

    /// Takes an ndarray by non-const reference and returns an ndarray by value.
    ndarray soft_thresholding(ndarray &x);

    /// Takes an ndarray by non-const reference and returns an ndarray by value.
    ndarray hard_thresholding(ndarray &x);

    /// @param iter  integer argument; presumably the iteration count -- confirm.
    void train(int iter);

    /// Returns an ndarray by value.
    ndarray get_sparse_vec();
};
class fused_lasso : public admm
{
private:
/* data */
public:
fused_lasso(ndarray &Y, ndarray &X, float lambda, float rho, float sparse_coef=1.0, float fused_coef=1.0);
void generate_transform_matrix(int n_features);
float sparse_coef, fused_coef;
// ~fused_lasso();
};