diff --git a/Project.toml b/Project.toml index ab6403a4..a6d908ef 100644 --- a/Project.toml +++ b/Project.toml @@ -35,7 +35,7 @@ CategoricalDistributions = "0.1" ComputationalResources = "0.3" Distributions = "0.25.3" InvertedIndices = "1" -LossFunctions = "0.9" +LossFunctions = "0.10" MLJModelInterface = "1.7" Missings = "0.4, 1" OrderedCollections = "1.1" diff --git a/src/measures/loss_functions_interface.jl b/src/measures/loss_functions_interface.jl index 42e9bdbf..5d7d6125 100644 --- a/src/measures/loss_functions_interface.jl +++ b/src/measures/loss_functions_interface.jl @@ -130,7 +130,7 @@ MMI.prediction_type(::Type{<:DistanceLoss}) = :deterministic MMI.target_scitype(::Type{<:DistanceLoss}) = Union{Vec{Continuous},Vec{Count}} call(measure::DistanceLoss, yhat, y) = - LossFunctions.value(getfield(measure, :loss), yhat, y) + (getfield(measure, :loss)).(yhat, y) function call(measure::DistanceLoss, yhat, y, w::AbstractArray) return w .* call(measure, yhat, y) @@ -147,8 +147,8 @@ _scale(p) = 2p - 1 function call(measure::MarginLoss, yhat, y) probs_of_observed = broadcast(pdf, yhat, y) - return (LossFunctions.value).(getfield(measure, :loss), - _scale.(probs_of_observed), 1) + loss = getfield(measure, :loss) + return loss.(_scale.(probs_of_observed), 1) end call(measure::MarginLoss, yhat, y, w::AbstractArray) = diff --git a/test/measures/loss_functions_interface.jl b/test/measures/loss_functions_interface.jl index d5894eb8..8c59945b 100644 --- a/test/measures/loss_functions_interface.jl +++ b/test/measures/loss_functions_interface.jl @@ -42,9 +42,9 @@ end for M_ex in MARGIN_LOSSES m = eval(:(MLJBase.$M_ex())) - @test m(yhat, y) ≈ LossFunctions.value(getfield(m, :loss), yhatm, ym) + @test m(yhat, y) ≈ (getfield(m, :loss)).(yhatm, ym) @test m(yhat, y, w) ≈ - w .* LossFunctions.value(getfield(m, :loss), yhatm, ym) + w .* (getfield(m, :loss)).(yhatm, ym) end end @@ -61,8 +61,8 @@ end m_ex = MLJBase.snakecase(M_ex) @test m == eval(:(MLJBase.$m_ex)) @test m(yhat, y) ≈ - LossFunctions.value(getfield(m, :loss), yhat, y) + (getfield(m, :loss)).(yhat, y) @test m(yhat ,y, w) ≈ - w .* LossFunctions.value(getfield(m, :loss), yhat, y) + w .* (getfield(m, :loss)).(yhat, y) end end