Skip to content

Commit 5d62b45

Browse files
committed
improve ridge sensib tutorial
1 parent 46a244f commit 5d62b45

File tree

1 file changed

+16
-8
lines changed

1 file changed

+16
-8
lines changed

docs/src/examples/sensitivity-analysis-ridge.jl

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030
# ```
3131

3232

33-
3433
# This tutorial uses the following packages
3534

3635
using JuMP
@@ -53,29 +52,36 @@ b = rand()
5352
X = randn(N)
5453
Y = w * X .+ b + 0.8 * randn(N);
5554

56-
# The helper method `fitRidge` defines and solves the corresponding model.
55+
# The helper method `fit_ridge` defines and solves the corresponding model.
56+
# The ridge regression is modeled with quadratic programming
57+
# (quadratic objective and linear constraints) and solved in generic methods
58+
# of OSQP. This is not the standard way of solving the ridge regression problem
59+
# this is done here for didactic purposes.
5760

58-
function fitRidge(X, Y, alpha = 0.1)
61+
function fit_ridge(X, Y, alpha = 0.1)
5962
N = length(Y)
63+
## Initialize a JuMP Model with OSQP solver
6064
model = Model(() -> DiffOpt.diff_optimizer(OSQP.Optimizer))
6165
set_silent(model)
62-
@variable(model, w)
63-
@variable(model, b)
64-
@variable(model, e[1:N])
66+
@variable(model, w) # angular coefficient
67+
@variable(model, b) # linear coefficient
68+
@variable(model, e[1:N]) # approximation error
69+
## constraint defining approximation error
6570
@constraint(model, cons[i=1:N], e[i] == Y[i] - w * X[i] - b)
71+
## objective minimizing squared error and ridge penalty
6672
@objective(
6773
model,
6874
Min,
6975
dot(e, e) + alpha * (sum(w * w) + sum(b * b)),
7076
)
7177
optimize!(model)
72-
return model, w, b, cons
78+
return model, w, b, cons # return model, variables and constraints references
7379
end
7480

7581

7682
# Train on the data generated.
7783

78-
model, w, b, cons = fitRidge(X, Y)
84+
model, w, b, cons = fit_ridge(X, Y)
7985
ŵ, b̂ = value(w), value(b)
8086

8187
# We can visualize the approximating line.
@@ -130,3 +136,5 @@ p = Plots.scatter(
130136
mi, ma = minimum(X), maximum(X)
131137
Plots.plot!(p, [mi, ma], [mi *+ b̂, ma *+ b̂], color = :red, label = "")
132138

139+
# Note the points in the extremes of the line segment are larger because
140+
# moving those points can affect more the angular coefficient of the line.

0 commit comments

Comments
 (0)