-
Notifications
You must be signed in to change notification settings - Fork 10
/
ScaledRegularization.jl
78 lines (69 loc) · 3.51 KB
/
ScaledRegularization.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
export AbstractScaledRegularization
"""
    AbstractScaledRegularization{T, S}

Nested regularization term that applies a `scalefactor` to the regularization parameter `λ`
of its inner term.

`S` must be an `AbstractParameterizedRegularization` whose type parameter is either `T` or
an `AbstractArray{T}`. Concrete subtypes must implement [`scalefactor`](@ref) and
[`innerreg`](@ref); [`λ`](@ref) is derived from both.

See also [`scalefactor`](@ref), [`λ`](@ref), [`innerreg`](@ref).
"""
abstract type AbstractScaledRegularization{T, S<:AbstractParameterizedRegularization{<:Union{T, <:AbstractArray{T}}}} <: AbstractNestedRegularization{S} end

"""
    scalefactor(reg::AbstractScaledRegularization)

Return the factor by which `reg` scales the `λ` of its inner regularization term.

The generic fallback throws an error: every concrete subtype of
`AbstractScaledRegularization` must provide its own method.
"""
scalefactor(::R) where R <: AbstractScaledRegularization = error("Scaled regularization term $R must implement scalefactor")

"""
    λ(reg::AbstractScaledRegularization)

Return `λ` of the inner regularization term `innerreg(reg)`, multiplied elementwise by
`scalefactor(reg)`.

See also [`scalefactor`](@ref), [`innerreg`](@ref).
"""
λ(reg::AbstractScaledRegularization) = λ(innerreg(reg)) .* scalefactor(reg)
export FixedScaledRegularization
"""
    FixedScaledRegularization(reg, factor)

Nested regularization term that scales the `λ` of `reg` by the constant `factor`.

`reg` may be a parameterized regularization term or a nested term wrapping one. `factor` is
stored as (and converted to) the type parameter `T` of that parameterized term.
"""
struct FixedScaledRegularization{T, S, R} <: AbstractScaledRegularization{T, S}
    reg::R
    factor::T
    # Wrapping a parameterized term directly: it is also the scaled term `S`.
    function FixedScaledRegularization(reg::R, factor) where {T, R <: AbstractParameterizedRegularization{T}}
        return new{T, R, R}(reg, factor)
    end
    # Wrapping a nested term: `S` is the parameterized term type `RN` that it nests.
    function FixedScaledRegularization(reg::R, factor) where {T, RN <: AbstractParameterizedRegularization{T}, R<:AbstractNestedRegularization{RN}}
        return new{T, RN, R}(reg, factor)
    end
end

function innerreg(reg::FixedScaledRegularization)
    return reg.reg
end

function scalefactor(reg::FixedScaledRegularization)
    return reg.factor
end
export FixedParameterRegularization
"""
    FixedParameterRegularization(reg)

Nested regularization term that discards any `λ` passed to it and instead uses `λ` from its
inner regularization term. This can be used to selectively disallow normalization.

`reg` may be a parameterized regularization term or a nested term wrapping one. The
`scalefactor` of this term is always `1.0`.
"""
struct FixedParameterRegularization{T, S, R} <: AbstractScaledRegularization{T, S}
    reg::R
    # Wrapping a parameterized term directly: it is also the scaled term `S`.
    FixedParameterRegularization(reg::R) where {T, R <: AbstractParameterizedRegularization{T}} = new{T, R, R}(reg)
    # Wrapping a nested term: `S` is the parameterized term type `RN` that it nests.
    # Must carry this type's own name; any other name defines a method on a different
    # function that still constructs a FixedParameterRegularization.
    FixedParameterRegularization(reg::R) where {T, RN <: AbstractParameterizedRegularization{T}, R<:AbstractNestedRegularization{RN}} = new{T, RN, R}(reg)
end
scalefactor(reg::FixedParameterRegularization) = 1.0
innerreg(reg::FixedParameterRegularization) = reg.reg
# Drop any incoming λ and substitute the inner term's own λ.
prox!(reg::FixedParameterRegularization, x, discard) = prox!(innerreg(reg), x, λ(innerreg(reg)))
norm(reg::FixedParameterRegularization, x, discard) = norm(innerreg(reg), x, λ(innerreg(reg)))
export AutoScaledRegularization
"""
    AutoScaledRegularization(reg)

Nested regularization term whose `scalefactor` is derived from data the first time it is
applied. The first call to `prox!` or `norm` sets the factor to the largest absolute entry
of the array `x` it receives, and that factor is kept for all later calls. Before that
first call, `scalefactor` returns `1.0`.

`reg` may be a parameterized regularization term or a nested term wrapping one.
"""
mutable struct AutoScaledRegularization{T, S, R} <: AbstractScaledRegularization{T, S}
    reg::R
    # `nothing` until the first `prox!`/`norm` call computes it from `x`.
    factor::Union{Nothing, T}
    AutoScaledRegularization(reg::R) where {T, R <: AbstractParameterizedRegularization{T}} = new{T, R, R}(reg, nothing)
    AutoScaledRegularization(reg::R) where {T, RN <: AbstractParameterizedRegularization{T}, R<:AbstractNestedRegularization{RN}} = new{T, RN, R}(reg, nothing)
end

# Set (and return) the scale factor as the largest absolute entry of `x`.
# `maximum(abs, x)` avoids materializing the temporary array that `abs.(x)` would allocate.
initFactor!(reg::AutoScaledRegularization, x::AbstractArray) = reg.factor = maximum(abs, x)

innerreg(reg::AutoScaledRegularization) = reg.reg

# A bit hacky: the factor can only be computed once `x` is seen, therefore it is hidden in
# `λ` (via `scalefactor`) and silently added in the first `prox!`/`norm` call.
scalefactor(reg::AutoScaledRegularization) = isnothing(reg.factor) ? 1.0 : reg.factor

# Return the `λ` to forward to the inner term. On the first call, the factor is initialized
# from `x` and multiplied into the incoming `λ`. That `λ` was presumably computed while
# `scalefactor` still returned 1.0 — confirm against callers. On later calls, the incoming
# `λ` is passed through unchanged.
function _autoscaledλ!(reg::AutoScaledRegularization, x, λ)
    if isnothing(reg.factor)
        initFactor!(reg, x)
        return λ * reg.factor
    end
    return λ
end

prox!(reg::AutoScaledRegularization, x, λ) = prox!(reg.reg, x, _autoscaledλ!(reg, x, λ))
norm(reg::AutoScaledRegularization, x, λ) = norm(reg.reg, x, _autoscaledλ!(reg, x, λ))