Skip to content

Commit

Permalink
now maybe
Browse files Browse the repository at this point in the history
  • Loading branch information
pat-alt committed Oct 30, 2024
1 parent e366cfc commit 4d3ff50
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 11 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
### Added

- Added a warning message to the `ProbeGenerator` pointing to the issues with the current implementation.
- Added links to papers to all docstrings for generators.

## Version [1.3.5] - 2024-10-28

Expand Down
44 changes: 34 additions & 10 deletions src/generators/gradient_based/generators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,49 +5,63 @@ function GenericGenerator(; λ::AbstractFloat=0.1, kwargs...)
return GradientBasedGenerator(; penalty=default_distance, λ=λ, kwargs...)
end

"Constructor for `ECCoGenerator`. This corresponds to the generator proposed in https://arxiv.org/abs/2312.10648, without the conformal set size penalty."
const DOC_ECCCo = "For details, see Altmeyer et al. ([2024](https://ojs.aaai.org/index.php/AAAI/article/view/28956))."

"Constructor for `ECCoGenerator`. This corresponds to the generator proposed in https://arxiv.org/abs/2312.10648, without the conformal set size penalty. $DOC_ECCCo"
function ECCoGenerator(; λ::Vector{<:AbstractFloat}=[0.1, 1.0], kwargs...)
    # Distance to the factual plus the ECCCo energy constraint, weighted by λ.
    return GradientBasedGenerator(;
        penalty=[default_distance, Objectives.energy_constraint], λ=λ, kwargs...
    )
end

"Constructor for `WachterGenerator`."
const DOC_Wachter = "For details, see Wachter et al. ([2018](https://arxiv.org/abs/1711.00399))."

"Constructor for `WachterGenerator`. $DOC_Wachter"
function WachterGenerator(; λ::AbstractFloat=0.1, kwargs...)
    # Wachter et al. penalize distance to the factual scaled by the MAD.
    penalty = Objectives.distance_mad
    return GradientBasedGenerator(; penalty=penalty, λ=λ, kwargs...)
end

"Constructor for `DiCEGenerator`."
const DOC_DiCE = "For details, see Mothilal et al. ([2020](https://arxiv.org/abs/1905.07697))."

"Constructor for `DiCEGenerator`. $DOC_DiCE"
function DiCEGenerator(; λ::Vector{<:AbstractFloat}=[0.1, 0.1], kwargs...)
    # Proximity term plus the DPP diversity term from the DiCE paper.
    return GradientBasedGenerator(;
        penalty=[default_distance, Objectives.ddp_diversity], λ=λ, kwargs...
    )
end

"Constructor for `ClaPGenerator`."
const DOC_SaTML = "For details, see Altmeyer et al. ([2023](https://ieeexplore.ieee.org/abstract/document/10136130))."

"Constructor for `ClaPROARGenerator`. $DOC_SaTML"
function ClaPROARGenerator(; λ::Vector{<:AbstractFloat}=[0.1, 0.5], kwargs...)
    # Distance to the factual plus a penalty on the induced model loss (ClaPROAR).
    _penalties = [default_distance, Objectives.model_loss_penalty]
    return GradientBasedGenerator(; penalty=_penalties, λ=λ, kwargs...)
end

"Constructor for `GravitationalGenerator`."
"Constructor for `GravitationalGenerator`. $DOC_SaTML"
function GravitationalGenerator(; λ::Vector{<:AbstractFloat}=[0.1, 0.5], kwargs...)
    # Distance to the factual plus the "gravitational" pull towards the target.
    penalties = [default_distance, Objectives.distance_from_target]
    return GradientBasedGenerator(; penalty=penalties, λ=λ, kwargs...)
end

"Constructor for `REVISEGenerator`."
const DOC_REVISE = "For details, see Joshi et al. ([2019](https://arxiv.org/abs/1907.09615))."

"Constructor for `REVISEGenerator`. $DOC_REVISE"
function REVISEGenerator(; λ::AbstractFloat=0.1, latent_space=true, kwargs...)
    # REVISE searches in the latent space of a generative model by default.
    return GradientBasedGenerator(;
        penalty=default_distance, λ=λ, latent_space=latent_space, kwargs...
    )
end

"Constructor for `GreedyGenerator`."
const DOC_Greedy = "For details, see Schut et al. ([2021](https://proceedings.mlr.press/v130/schut21a/schut21a.pdf))."

"Constructor for `GreedyGenerator`. $DOC_Greedy"
function GreedyGenerator(; η=0.1, n=nothing, kwargs...)
    # Greedy search is implemented through the JSMA-style descent rule; the
    # penalty weight is zeroed out so only the loss gradient drives updates.
    descent = CounterfactualExplanations.Generators.JSMADescent(; η=η, n=n)
    return GradientBasedGenerator(;
        penalty=default_distance, λ=0.0, opt=descent, kwargs...
    )
end

"Constructor for `CLUEGenerator`."
const DOC_CLUE = "For details, see Antoran et al. ([2021](https://arxiv.org/abs/2006.06848))."

"Constructor for `CLUEGenerator`. $DOC_CLUE"
function CLUEGenerator(; λ::AbstractFloat=0.1, latent_space=true, kwargs...)
return GradientBasedGenerator(;
loss=predictive_entropy,
Expand All @@ -58,14 +72,24 @@ function CLUEGenerator(; λ::AbstractFloat=0.1, latent_space=true, kwargs...)
)
end

"Constructor for `ProbeGenerator`."
const DOC_Probe = "For details, see Pawelczyk et al. ([2022](https://proceedings.mlr.press/v151/pawelczyk22a/pawelczyk22a.pdf))."

# Shared warning text for `ProbeGenerator` (typo "currenlty" fixed to "currently").
const DOC_Probe_warn = "The `ProbeGenerator` is currently not working adequately. In particular, gradients are not computed with respect to the Hinge loss term proposed in the paper. It is still possible, however, to use this generator to achieve a desired invalidation rate. See issue [#376](https://github.com/JuliaTrustworthyAI/CounterfactualExplanations.jl/issues/376) for details."

"""
Constructor for `ProbeGenerator`. $DOC_Probe

## Warning

$DOC_Probe_warn
"""
function ProbeGenerator(;
    λ::Vector{<:AbstractFloat}=[0.1, 1.0],
    loss::Symbol=:logitbinarycrossentropy,
    penalty=[Objectives.distance_l1, Objectives.hinge_loss],
    kwargs...,
)
    # Emit the warning once per construction so users know about the known
    # limitation (issue #376). A stale hard-coded duplicate of this message
    # (with a typo) was removed — warn exactly once, from the shared constant.
    @warn DOC_Probe_warn
    # Resolve the loss symbol to the corresponding function from the catalogue.
    user_loss = Objectives.losses_catalogue[loss]
    return GradientBasedGenerator(; loss=user_loss, penalty=penalty, λ=λ, kwargs...)
end
2 changes: 1 addition & 1 deletion src/generators/gradient_based/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ By default, gradient-based search is considered to have converged as soon as the
function Convergence.conditions_satisfied(
generator::AbstractGradientBasedGenerator, ce::AbstractCounterfactualExplanation
)
if !(ce.convergence isa Convergence.GeneratorConditionsConvergence)
if !hasfield(ce.convergence, :gradient_tol)
# Temporary fix due to the fact that `ProbeGenerator` relies on `InvalidationRateConvergence`.
@warn "Checking for generator conditions convergence is not implemented for this generator type. Return `false`." maxlog=1
return false
Expand Down

0 comments on commit 4d3ff50

Please sign in to comment.