Skip to content

Commit

Permalink
first support for units in fitlinear (needs tests)
Browse files Browse the repository at this point in the history
  • Loading branch information
lmiq committed Sep 18, 2024
1 parent b99bee3 commit 4c0c877
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 34 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ LsqFit = "2fda8390-95c7-5789-9bda-21331edee243"
Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
TestItems = "1c621080-faea-4a02-84b6-bbd5e436b8fe"
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"

[weakdeps]
Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59"
Expand All @@ -22,6 +23,7 @@ LsqFit = "0.11, 0.12, 0.13, 0.14, 0.15"
Parameters = "0.12"
Statistics = "1"
TestItems = "0.1"
Unitful = "1.21.0"
julia = "1.9"

[extras]
Expand Down
1 change: 1 addition & 0 deletions src/EasyFit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ using TestItems
using Statistics
using LsqFit
using Parameters
using Unitful: ustrip, oneunit

# supertype for all fits, to help on dispatch of common methods
abstract type Fit{T<:AbstractFloat} end
Expand Down
6 changes: 4 additions & 2 deletions src/checkdata.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@ function check_size_and_type(X, c, data_type)
return vec(data_type.(X))
end

function checkdata(X::AbstractArray{T1}, Y::AbstractArray{T2}, options::Options) where {T1<:Real, T2<:Real}
function checkdata(X::AbstractArray{T1}, Y::AbstractArray{T2}, options::Options) where {T1<:Number, T2<:Number}
if length(X) != length(Y)
throw(ArgumentError("Input x and y vectors must have the same length."))
end
data_type = promote_type(Float32,T1,T2)
X = ustrip.(X)
Y = ustrip.(Y)
data_type = promote_type(Float32,eltype(X),eltype(Y))
X = check_size_and_type(X, "X", data_type)
Y = check_size_and_type(Y, "Y", data_type)
# Set some reasonable ranges for the initial guesses
Expand Down
65 changes: 33 additions & 32 deletions src/fitlinear.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@
# Linear fit
#

struct Linear{T} <: Fit{T}
a::T
b::T
R::T
x::Vector{T}
y::Vector{T}
ypred::Vector{T}
residues::Vector{T}
@kwdef struct Linear{TA,TR,TX,TY,TRES}
a::TA
b::TY
R::TR
x::Vector{TX}
y::Vector{TY}
ypred::Vector{TY}
residues::Vector{TY}
end

"""
Expand Down Expand Up @@ -50,39 +50,39 @@ function fitlinear(
X::AbstractArray{T1}, Y::AbstractArray{T2};
l::lower=lower(), u::upper=upper(), b=nothing,
options::Options=Options()
) where {T1<:Real, T2<:Real}
) where {T1<:Number, T2<:Number}
# Check units
# Check data
onex = oneunit(T1)
oney = oneunit(T2)
X, Y, data_type = checkdata(X, Y, options)
# Set bounds
vars = [VarType(:a, Number, 1), VarType(:b, Nothing, 1)]
lower, upper = setbounds(vars, l, u, data_type)
if isnothing(b)
# Set model
@. model(x, p) = p[1] * x + p[2]
# Initial point
p0 = Vector{data_type}(undef, 2)
initP!(p0, options, lower, upper)
# Fit
fit = curve_fit(model, X, Y, p0, lower=lower, upper=upper)
# Analyze results and return
R = pearson(X, Y, model, fit)
x, y, ypred = finexy(X, length(X), model, fit)
return Linear(fit.param..., R, x, y, ypred, fit.resid)
else
lower = [lower[1]]
upper = [upper[1]]
# Set model
@. model_const(x, p) = p[1] * x + b
# Initial point
p0 = Vector{data_type}(undef, 1)
initP!(p0, options, lower, upper)
# Fit
fit = curve_fit(model_const, X, Y, p0, lower=lower, upper=upper)
# Analyze results and return
R = pearson(X, Y, model_const, fit)
x, y, ypred = finexy(X, length(X), model_const, fit)
return Linear(fit.param..., b, R, x, y, ypred, fit.resid)
end
initP!(p0, options, lower, upper)
# Fit
fit = curve_fit(model, X, Y, p0, lower=lower, upper=upper)
# Analyze results and return
R = pearson(X, Y, model, fit)
x, y, ypred = finexy(X, length(X), model, fit)
return Linear(
a=fit.param[1] * oney/onex,
b=fit.param[2] * oney,
x=X .* onex,
y=Y .* oney,
R=R .* onex * oney,
ypred=ypred .* oney,
residues=fit.resid .* oney
)
end

"""
Expand Down Expand Up @@ -127,14 +127,14 @@ julia> f.(rand(10))
0.40113908380656205
```
"""
function (fit::Linear)(x::Real)
function (fit::Linear)(x::Number)
a = fit.a
b = fit.b
return a * x + b
end

function Base.show(io::IO, fit::Linear)
println(io,
println(io,chomp(
"""
------------------- Linear Fit -------------
Expand All @@ -149,8 +149,9 @@ function Base.show(io::IO, fit::Linear)
Predicted Y: ypred = [$(fit.ypred[1]), $(fit.ypred[2]), ...]
residues = [$(fit.residues[1]), $(fit.residues[2]), ...]
--------------------------------------------"""
)
--------------------------------------------
"""
))
end

export fitlinear
Expand All @@ -170,5 +171,5 @@ export fitlinear
x = Float32.(x)
y = Float32.(y)
f = fitlinear(x, y)
@test typeof(f.R) == Float32
end

0 comments on commit 4c0c877

Please sign in to comment.