From 4d3ff50fe78e8a7516afd7849f77bfb2de264731 Mon Sep 17 00:00:00 2001 From: pat-alt Date: Wed, 30 Oct 2024 14:58:58 +0100 Subject: [PATCH] now maybe --- CHANGELOG.md | 1 + src/generators/gradient_based/generators.jl | 44 ++++++++++++++++----- src/generators/gradient_based/utils.jl | 2 +- 3 files changed, 36 insertions(+), 11 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 181e70b7b..61b129eef 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), ### Added - Added a warning message to the `ProbeGenerator` pointing to the issues with with current implementation. +- Added links to papers to all docstrings for generators. ## Version [1.3.5] - 2024-10-28 diff --git a/src/generators/gradient_based/generators.jl b/src/generators/gradient_based/generators.jl index 4f1c48401..0fd0f8cb7 100644 --- a/src/generators/gradient_based/generators.jl +++ b/src/generators/gradient_based/generators.jl @@ -5,49 +5,63 @@ function GenericGenerator(; λ::AbstractFloat=0.1, kwargs...) return GradientBasedGenerator(; penalty=default_distance, λ=λ, kwargs...) end -"Constructor for `ECCoGenerator`. This corresponds to the generator proposed in https://arxiv.org/abs/2312.10648, without the conformal set size penalty." +const DOC_ECCCo = "For details, see Altmeyer et al. ([2024](https://ojs.aaai.org/index.php/AAAI/article/view/28956))." + +"Constructor for `ECCoGenerator`. This corresponds to the generator proposed in https://arxiv.org/abs/2312.10648, without the conformal set size penalty. $DOC_ECCCo" function ECCoGenerator(; λ::Vector{<:AbstractFloat}=[0.1, 1.0], kwargs...) _penalties = [default_distance, Objectives.energy_constraint] return GradientBasedGenerator(; penalty=_penalties, λ=λ, kwargs...) end -"Constructor for `WachterGenerator`." +const DOC_Wachter = "For details, see Wachter et al. ([2018](https://arxiv.org/abs/1711.00399))." + +"Constructor for `WachterGenerator`. 
$DOC_Wachter" function WachterGenerator(; λ::AbstractFloat=0.1, kwargs...) return GradientBasedGenerator(; penalty=Objectives.distance_mad, λ=λ, kwargs...) end -"Constructor for `DiCEGenerator`." +const DOC_DiCE = "For details, see Mothilal et al. ([2020](https://arxiv.org/abs/1905.07697))." + +"Constructor for `DiCEGenerator`. $DOC_DiCE" function DiCEGenerator(; λ::Vector{<:AbstractFloat}=[0.1, 0.1], kwargs...) _penalties = [default_distance, Objectives.ddp_diversity] return GradientBasedGenerator(; penalty=_penalties, λ=λ, kwargs...) end -"Constructor for `ClaPGenerator`." +const DOC_SaTML = "For details, see Altmeyer et al. ([2023](https://ieeexplore.ieee.org/abstract/document/10136130))." + +"Constructor for `ClaPGenerator`. $DOC_SaTML" function ClaPROARGenerator(; λ::Vector{<:AbstractFloat}=[0.1, 0.5], kwargs...) _penalties = [default_distance, Objectives.model_loss_penalty] return GradientBasedGenerator(; penalty=_penalties, λ=λ, kwargs...) end -"Constructor for `GravitationalGenerator`." +"Constructor for `GravitationalGenerator`. $DOC_SaTML" function GravitationalGenerator(; λ::Vector{<:AbstractFloat}=[0.1, 0.5], kwargs...) _penalties = [default_distance, Objectives.distance_from_target] return GradientBasedGenerator(; penalty=_penalties, λ=λ, kwargs...) end -"Constructor for `REVISEGenerator`." +const DOC_REVISE = "For details, see Joshi et al. ([2019](https://arxiv.org/abs/1907.09615))." + +"Constructor for `REVISEGenerator`. $DOC_REVISE" function REVISEGenerator(; λ::AbstractFloat=0.1, latent_space=true, kwargs...) return GradientBasedGenerator(; penalty=default_distance, λ=λ, latent_space=latent_space, kwargs... ) end -"Constructor for `GreedyGenerator`." +const DOC_Greedy = "For details, see Schut et al. ([2021](https://proceedings.mlr.press/v130/schut21a/schut21a.pdf))." + +"Constructor for `GreedyGenerator`. $DOC_Greedy" function GreedyGenerator(; η=0.1, n=nothing, kwargs...) 
opt = CounterfactualExplanations.Generators.JSMADescent(; η=η, n=n) return GradientBasedGenerator(; penalty=default_distance, λ=0.0, opt=opt, kwargs...) end -"Constructor for `CLUEGenerator`." +const DOC_CLUE = "For details, see Antoran et al. ([2021](https://arxiv.org/abs/2006.06848))." + +"Constructor for `CLUEGenerator`. $DOC_CLUE" function CLUEGenerator(; λ::AbstractFloat=0.1, latent_space=true, kwargs...) return GradientBasedGenerator(; loss=predictive_entropy, @@ -58,14 +72,24 @@ function CLUEGenerator(; λ::AbstractFloat=0.1, latent_space=true, kwargs...) ) end -"Constructor for `ProbeGenerator`." +const DOC_Probe = "For details, see Pawelczyk et al. ([2022](https://proceedings.mlr.press/v151/pawelczyk22a/pawelczyk22a.pdf))." + +const DOC_Probe_warn = "The `ProbeGenerator` is currently not working adequately. In particular, gradients are not computed with respect to the Hinge loss term proposed in the paper. It is still possible, however, to use this generator to achieve a desired invalidation rate. See issue [#376](https://github.com/JuliaTrustworthyAI/CounterfactualExplanations.jl/issues/376) for details." + +""" +Constructor for `ProbeGenerator`. $DOC_Probe + +## Warning + +$DOC_Probe_warn +""" function ProbeGenerator(; λ::Vector{<:AbstractFloat}=[0.1, 1.0], loss::Symbol=:logitbinarycrossentropy, penalty=[Objectives.distance_l1, Objectives.hinge_loss], kwargs..., ) - @warn "The `ProbeGenerator` is currenlty not working adequately. In particular, gradients are not computed with respect to the Hinge loss term proposed in the paper. It is still possible, however, to use this generator to achieve a desired invalidation rate. See issue [#376](https://github.com/JuliaTrustworthyAI/CounterfactualExplanations.jl/issues/376) for details." + @warn DOC_Probe_warn user_loss = Objectives.losses_catalogue[loss] return GradientBasedGenerator(; loss=user_loss, penalty=penalty, λ=λ, kwargs...) 
end diff --git a/src/generators/gradient_based/utils.jl b/src/generators/gradient_based/utils.jl index 409bc1aaa..5fae252e5 100644 --- a/src/generators/gradient_based/utils.jl +++ b/src/generators/gradient_based/utils.jl @@ -16,7 +16,7 @@ By default, gradient-based search is considered to have converged as soon as the function Convergence.conditions_satisfied( generator::AbstractGradientBasedGenerator, ce::AbstractCounterfactualExplanation ) - if !(ce.convergence isa Convergence.GeneratorConditionsConvergence) + if !hasfield(typeof(ce.convergence), :gradient_tol) # Temporary fix due to the fact that `ProbeGenerator` relies on `InvalidationRateConvergence`. @warn "Checking for generator conditions convergence is not implemented for this generator type. Return `false`." maxlog=1 return false