30
30
# ```
31
31
32
32
33
-
34
33
# This tutorial uses the following packages
35
34
36
35
using JuMP
@@ -53,29 +52,36 @@ b = rand()
53
52
X = randn (N)
54
53
Y = w * X .+ b + 0.8 * randn (N);
55
54
56
- # The helper method `fitRidge` defines and solves the corresponding model.
55
+ # The helper method `fit_ridge` defines and solves the corresponding model.
56
+ # The ridge regression is modeled with quadratic programming
57
+ # (quadratic objective and linear constraints) and solved in generic methods
58
+ # of OSQP. This is not the standard way of solving the ridge regression problem
59
+ # this is done here for didactic purposes.
57
60
58
- function fitRidge (X, Y, alpha = 0.1 )
61
+ function fit_ridge (X, Y, alpha = 0.1 )
59
62
N = length (Y)
63
+ # # Initialize a JuMP Model with OSQP solver
60
64
model = Model (() -> DiffOpt. diff_optimizer (OSQP. Optimizer))
61
65
set_silent (model)
62
- @variable (model, w)
63
- @variable (model, b)
64
- @variable (model, e[1 : N])
66
+ @variable (model, w) # angular coefficient
67
+ @variable (model, b) # linear coefficient
68
+ @variable (model, e[1 : N]) # approximation error
69
+ # # constraint defining approximation error
65
70
@constraint (model, cons[i= 1 : N], e[i] == Y[i] - w * X[i] - b)
71
+ # # objective minimizing squared error and ridge penalty
66
72
@objective (
67
73
model,
68
74
Min,
69
75
dot (e, e) + alpha * (sum (w * w) + sum (b * b)),
70
76
)
71
77
optimize! (model)
72
- return model, w, b, cons
78
+ return model, w, b, cons # return model, variables and constraints references
73
79
end
74
80
75
81
76
82
# Train on the data generated.
77
83
78
- model, w, b, cons = fitRidge (X, Y)
84
+ model, w, b, cons = fit_ridge (X, Y)
79
85
ŵ, b̂ = value (w), value (b)
80
86
81
87
# We can visualize the approximating line.
@@ -130,3 +136,5 @@ p = Plots.scatter(
130
136
mi, ma = minimum (X), maximum (X)
131
137
Plots. plot! (p, [mi, ma], [mi * ŵ + b̂, ma * ŵ + b̂], color = :red , label = " " )
132
138
139
+ # Note the points in the extremes of the line segment are larger because
140
+ # moving those points can affect more the angular coefficient of the line.
0 commit comments