From 2209563ebc7b8e16dd74455583011a29b82d2404 Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Mon, 1 Jul 2024 11:49:58 +1200 Subject: [PATCH 1/3] relax restrictions on model type in resampling --- src/resampling.jl | 32 +++++++++++++++++++++----------- test/resampling.jl | 26 +++++++++++++++++++++++++- 2 files changed, 46 insertions(+), 12 deletions(-) diff --git a/src/resampling.jl b/src/resampling.jl index dd317092..da8eac72 100644 --- a/src/resampling.jl +++ b/src/resampling.jl @@ -31,10 +31,6 @@ const ERR_INVALID_OPERATION = ArgumentError( _ambiguous_operation(model, measure) = "`$measure` does not support a `model` with "* "`prediction_type(model) == :$(prediction_type(model))`. " -err_ambiguous_operation(model, measure) = ArgumentError( - _ambiguous_operation(model, measure)* - "\nUnable to infer an appropriate operation for `$measure`. "* - "Explicitly specify `operation=...` or `operations=...`. ") err_incompatible_prediction_types(model, measure) = ArgumentError( _ambiguous_operation(model, measure)* "If your model is truly making probabilistic predictions, try explicitly "* @@ -65,11 +61,25 @@ ERR_MEASURES_DETERMINISTIC(measure) = ArgumentError( "and so is not supported by `$measure`. "*LOG_AVOID ) -# ================================================================== -## MODEL TYPES THAT CAN BE EVALUATED +err_ambiguous_operation(model, measure) = ArgumentError( + _ambiguous_operation(model, measure)* + "\nUnable to infer an appropriate operation for `$measure`. "* + "Explicitly specify `operation=...` or `operations=...`. "* + "Possible value(s) are: $PREDICT_OPERATIONS_STRING. " +) + +ERR_UNSUPPORTED_PREDICTION_TYPE = ArgumentError( + """ + + The `prediction_type` of your model needs to be one of: `:deterministic`, + `:probabilistic`, or `:interval`. Does your model implement one of these operations: + $PREDICT_OPERATIONS_STRING? If so, you can try explicitly specifying `operation=...` + or `operations=...` (and consider posting an issue to have the model review it's + definition of `MLJModelInterface.prediction_type`). Otherwise, performance + evaluation is not supported. -# not exported: -const Measurable = Union{Supervised, Annotator} + """ +) # ================================================================== ## RESAMPLING STRATEGIES @@ -987,7 +997,7 @@ function _actual_operations(operation::Nothing, throw(err_ambiguous_operation(model, m)) end else - throw(err_ambiguous_operation(model, m)) + throw(ERR_UNSUPPORTED_PREDICTION_TYPE) end end end @@ -1137,7 +1147,7 @@ See also [`evaluate`](@ref), [`PerformanceEvaluation`](@ref), """ function evaluate!( - mach::Machine{<:Measurable}; + mach::Machine; resampling=CV(), measures=nothing, measure=measures, @@ -1235,7 +1245,7 @@ Returns a [`PerformanceEvaluation`](@ref) object. See also [`evaluate!`](@ref). """ -evaluate(model::Measurable, args...; cache=true, kwargs...) = +evaluate(model::Model, args...; cache=true, kwargs...) = evaluate!(machine(model, args...; cache=cache); kwargs...) # ------------------------------------------------------------------- diff --git a/test/resampling.jl b/test/resampling.jl index fbf26777..ecfd4d3d 100644 --- a/test/resampling.jl +++ b/test/resampling.jl @@ -25,6 +25,8 @@ end struct DummyInterval <: Interval end dummy_interval=DummyInterval() +struct GoofyTransformer <: Unsupervised end + dummy_measure_det(yhat, y) = 42 API.@trait( typeof(dummy_measure_det), @@ -115,6 +117,12 @@ API.@trait( MLJBase.err_ambiguous_operation(dummy_interval, LogLoss()), MLJBase._actual_operations(nothing, [LogLoss(), ], dummy_interval, 1)) + + # model not have a valid `prediction_type`: + @test_throws( + MLJBase.ERR_UNSUPPORTED_PREDICTION_TYPE, + MLJBase._actual_operations(nothing, [LogLoss(),], GoofyTransformer(), 0), + ) end @everywhere begin @@ -935,7 +943,23 @@ end end end -# DUMMY LOGGER + +# # TRANSFORMER WITH PREDICT + +struct PredictingTransformer <:Unsupervised end +MLJBase.fit(::PredictingTransformer, verbosity, X, y) = (mean(y), nothing, nothing) +MLJBase.predict(::PredictingTransformer, fitresult, X) = fill(fitresult, nrows(X)) +MLJBase.prediction_type(::Type{<:PredictingTransformer}) = :deterministic + +@testset "`Unsupervised` model with a predict" begin + X = rand(10) + y = fill(42.0, 10) + e = evaluate(PredictingTransformer(), X, y, resampling=Holdout(), measure=l2) + @test e.measurement[1] ≈ 0 +end + + +# # DUMMY LOGGER struct DummyLogger end From 140fb7bcf3f46d48120c6a35835c96992da6bbac Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Mon, 1 Jul 2024 13:37:46 +1200 Subject: [PATCH 2/3] add catch for missing target in resampling --- src/resampling.jl | 16 +++++++++++++++- test/resampling.jl | 6 ++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/src/resampling.jl b/src/resampling.jl index da8eac72..0c4ffda9 100644 --- a/src/resampling.jl +++ b/src/resampling.jl @@ -68,7 +68,7 @@ err_ambiguous_operation(model, measure) = ArgumentError( "Possible value(s) are: $PREDICT_OPERATIONS_STRING. " ) -ERR_UNSUPPORTED_PREDICTION_TYPE = ArgumentError( +const ERR_UNSUPPORTED_PREDICTION_TYPE = ArgumentError( """ The `prediction_type` of your model needs to be one of: `:deterministic`, @@ -81,6 +81,18 @@ ERR_UNSUPPORTED_PREDICTION_TYPE = ArgumentError( """ ) +const ERR_NEED_TARGET = ArgumentError( + """ + + To evaluate a model's performance you must provide a target variable `y`, as in + `evaluate(model, X, y; options...)` or + + mach = machine(model, X, y) + evaluate!(mach; options...) + + """ +) + # ================================================================== ## RESAMPLING STRATEGIES @@ -1170,6 +1182,8 @@ function evaluate!( # weights, measures, operations, and dispatches a # strategy-specific `evaluate!` + length(mach.args) > 1 || throw(ERR_NEED_TARGET) + repeats > 0 || error("Need `repeats > 0`. ") if resampling isa TrainTestPairs diff --git a/test/resampling.jl b/test/resampling.jl index ecfd4d3d..a91c28c5 100644 --- a/test/resampling.jl +++ b/test/resampling.jl @@ -948,7 +948,9 @@ end struct PredictingTransformer <:Unsupervised end MLJBase.fit(::PredictingTransformer, verbosity, X, y) = (mean(y), nothing, nothing) +MLJBase.fit(::PredictingTransformer, verbosity, X) = (nothing, nothing, nothing) MLJBase.predict(::PredictingTransformer, fitresult, X) = fill(fitresult, nrows(X)) +MLJBase.predict(::PredictingTransformer, ::Nothing, X) = nothing MLJBase.prediction_type(::Type{<:PredictingTransformer}) = :deterministic @testset "`Unsupervised` model with a predict" begin @@ -956,6 +958,10 @@ MLJBase.prediction_type(::Type{<:PredictingTransformer}) = :deterministic y = fill(42.0, 10) e = evaluate(PredictingTransformer(), X, y, resampling=Holdout(), measure=l2) @test e.measurement[1] ≈ 0 + @test_throws( + MLJBase.ERR_NEED_TARGET, + evaluate(PredictingTransformer(), X, measure=l2), + ) end From 64b94815715595dda348a45ded1e67790ac203a6 Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Mon, 1 Jul 2024 15:40:10 +1200 Subject: [PATCH 3/3] code comment typo --- test/resampling.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/resampling.jl b/test/resampling.jl index a91c28c5..62812eec 100644 --- a/test/resampling.jl +++ b/test/resampling.jl @@ -118,7 +118,7 @@ API.@trait( MLJBase._actual_operations(nothing, [LogLoss(), ], dummy_interval, 1)) - # model not have a valid `prediction_type`: + # model does not have a valid `prediction_type`: @test_throws( MLJBase.ERR_UNSUPPORTED_PREDICTION_TYPE, MLJBase._actual_operations(nothing, [LogLoss(),], GoofyTransformer(), 0),