diff --git a/CHANGELOG.md b/CHANGELOG.md index a808d7b22..181e70b7b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,14 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), ## Version [1.3.6] +### Changed + +- Slight changes to the implementation of `ProbeGenerator` (no longer calling a redundant `hinge_loss` function for all other generators). + +### Added + +- Added a warning message to the `ProbeGenerator` pointing to the issues with with current implementation. + ## Version [1.3.5] - 2024-10-28 ### Changed diff --git a/src/generators/gradient_based/generators.jl b/src/generators/gradient_based/generators.jl index a1ac3381a..4f1c48401 100644 --- a/src/generators/gradient_based/generators.jl +++ b/src/generators/gradient_based/generators.jl @@ -65,6 +65,7 @@ function ProbeGenerator(; penalty=[Objectives.distance_l1, Objectives.hinge_loss], kwargs..., ) + @warn "The `ProbeGenerator` is currenlty not working adequately. In particular, gradients are not computed with respect to the Hinge loss term proposed in the paper. It is still possible, however, to use this generator to achieve a desired invalidation rate. See issue [#376](https://github.com/JuliaTrustworthyAI/CounterfactualExplanations.jl/issues/376) for details." user_loss = Objectives.losses_catalogue[loss] return GradientBasedGenerator(; loss=user_loss, penalty=penalty, λ=λ, kwargs...) end diff --git a/src/generators/gradient_based/utils.jl b/src/generators/gradient_based/utils.jl index 2425aab1a..409bc1aaa 100644 --- a/src/generators/gradient_based/utils.jl +++ b/src/generators/gradient_based/utils.jl @@ -16,6 +16,11 @@ By default, gradient-based search is considered to have converged as soon as the function Convergence.conditions_satisfied( generator::AbstractGradientBasedGenerator, ce::AbstractCounterfactualExplanation ) + if !(ce.convergence isa Convergence.GeneratorConditionsConvergence) + # Temporary fix due to the fact that `ProbeGenerator` relies on `InvalidationRateConvergence`. + @warn "Checking for generator conditions convergence is not implemented for this generator type. Return `false`." maxlog=1 + return false + end Δcounterfactual_state = ∇(generator, ce) Δcounterfactual_state = CounterfactualExplanations.apply_mutability( ce, Δcounterfactual_state