-
Notifications
You must be signed in to change notification settings - Fork 3.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fidelity metrics #6116
Fidelity metrics #6116
Changes from 6 commits
c8006ce
168bb0a
668c12c
967be5f
bc30302
b33753d
36796e4
3366169
ae6d8f9
70a7336
6b6f48e
d0d51e9
d02c180
a986d59
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -0,0 +1,125 @@ | ||||||
from abc import ABC, abstractmethod | ||||||
from typing import Optional, Union | ||||||
|
||||||
import torch | ||||||
from torch import Tensor | ||||||
|
||||||
from torch_geometric.explain import Explainer, Explanation | ||||||
|
||||||
|
||||||
class ExplanationMetric(ABC):
    r"""Abstract base class for explanation metrics.

    Stores a reference to the explainer under evaluation and the index
    (node/edge) the explanation refers to.

    Args:
        explainer (Explainer, optional): The explainer to evaluate.
            (default: :obj:`None`)
        index (int or Tensor, optional): The explanation target index.
            (default: :obj:`None`)
    """
    def __init__(self, explainer: Optional['Explainer'] = None,
                 index: Optional[Union[int, Tensor]] = None) -> None:
        # Defaults added so that subclasses (e.g. `Fidelity`) may call
        # `super().__init__()` without arguments; previously this raised a
        # `TypeError` since both parameters were required.
        self.explainer = explainer
        self.index = index

    @abstractmethod
    def __call__(self, explainer: 'Explainer', **kwargs):
        r"""Computes the explanation metric for the given explainer.

        Args:
            explainer (:obj:`~torch_geometric.explain.Explainer`):
                The explainer to evaluate.
        """

    def get_inputs(self):
        r"""Obtains all the different inputs over which to compute the
        metric."""

    @abstractmethod
    def compute_metric(self):
        r"""Computes the metric over all inputs."""

    @abstractmethod
    def aggregate(self):
        r"""Aggregates the metric over all inputs."""
|
||||||
|
||||||
class Fidelity(ExplanationMetric):
    r"""Fidelity+/- explanation metric, as defined in
    `"GraphFramEx: Towards Systematic Evaluation of Explainability Methods
    for Graph Neural Networks" <https://arxiv.org/abs/2206.09677>`_.

    Args:
        explainer (Explainer, optional): The explainer to evaluate.
            (default: :obj:`None`)
        index (int or Tensor, optional): The explanation target index.
            (default: :obj:`None`)
    """
    def __init__(self, explainer: Optional['Explainer'] = None,
                 index: Optional[Union[int, Tensor]] = None) -> None:
        # Bug fix: the base class `__init__` requires `explainer` and
        # `index`, so the previous no-argument `super().__init__()` call
        # raised a `TypeError`. Forward them explicitly instead.
        super().__init__(explainer, index)
|
||||||
|
||||||
def fidelity(explainer: Explainer, explanation: Explanation,
             target: Optional[Tensor] = None,
             index: Optional[Union[int, Tensor]] = None,
             output_type: str = 'raw', **kwargs):
    r"""Evaluates the fidelity of an
    :class:`~torch_geometric.explain.Explainer` and a given
    :class:`~torch_geometric.explain.Explanation` produced by it.

    Args:
        explainer (:obj:`~torch_geometric.explain.Explainer`): The explainer
            to evaluate.
        explanation (:obj:`~torch_geometric.explain.Explanation`): The
            explanation to evaluate.
        target (Tensor, optional): The target prediction. If not provided, it
            is inferred from :obj:`explainer`. (default: :obj:`None`)
        index (int or Tensor, optional): The explanation target index. For
            node- and edge-level tasks it signifies the nodes/edges
            explained, respectively. For graph-level tasks it is assumed to
            be :obj:`None`. (default: :obj:`None`)
        output_type (str, optional): If set to :obj:`'prob'`, predictions of
            models returning log-probabilities are exponentiated before the
            metric is computed. (default: :obj:`'raw'`)

    Returns:
        dict: A dictionary with keys :obj:`'fidelity+'` and
        :obj:`'fidelity-'`.
    """
    metric_dict = {}

    task_level = explainer.model_config.task_level

    if index != explanation.get('index'):
        raise ValueError(f'Index ({index}) does not match '
                         f'explanation.index ({explanation.index}).')

    # Input graphs: the explanation subgraph is evaluated for fidelity- and
    # its complement for fidelity+.
    explanation_graph = explanation.get_explanation_subgraph()
    complement_graph = explanation.get_complement_subgraph()

    # Bug fix: the target was previously re-computed unconditionally, so a
    # user-supplied `target` was silently ignored. Only infer it when absent,
    # using the full (unmasked) input.
    if target is None:
        target = explainer.get_target(x=explanation.x,
                                      edge_index=explanation.edge_index,
                                      **kwargs)

    # Model predictions on the explanation subgraph and on its complement:
    explanation_prediction = explainer.get_prediction(
        x=explanation_graph.x, edge_index=explanation_graph.edge_index,
        **kwargs)
    complement_prediction = explainer.get_prediction(
        x=complement_graph.x, edge_index=complement_graph.edge_index,
        **kwargs)

    # Convert log-probabilities to probabilities if requested.
    # NOTE(review): changed `explainer.model.return_type` to
    # `explainer.model_config.return_type` for consistency with the other
    # `model_config` accesses in this function — confirm against `Explainer`.
    if (output_type == 'prob'
            and explainer.model_config.return_type == 'log_probs'):
        target = torch.exp(target)
        explanation_prediction = torch.exp(explanation_prediction)
        complement_prediction = torch.exp(complement_prediction)

    # Select the predictions to compare, based on the task level:
    if task_level == 'graph':
        if index is not None:
            # Bug fix: the `ValueError` was constructed but never raised,
            # and the f-string contained a stray `f` inside the braces.
            raise ValueError(f'Index for graph level task should be None, '
                             f'got ({index})')
        # Graph-level predictions are compared as-is.
    elif task_level == 'edge':
        # Edge-level fidelity is not defined yet; raise instead of silently
        # computing a metric over the raw predictions. TODO (blaz)
        raise NotImplementedError
    elif task_level == 'node':
        # Restrict the comparison to the explained node(s):
        target = target[index]
        explanation_prediction = explanation_prediction[index]
        complement_prediction = complement_prediction[index]
    else:
        raise NotImplementedError

    with torch.no_grad():
        if explainer.model_config.mode == 'regression':
            metric_dict['fidelity-'] = torch.mean(
                torch.abs(target - explanation_prediction))
            metric_dict['fidelity+'] = torch.mean(
                torch.abs(target - complement_prediction))
        elif explainer.model_config.mode == 'classification':
            # Bug fix: `torch.mean` is not implemented for boolean tensors,
            # so the equality mask is cast to float before averaging.
            metric_dict['fidelity-'] = torch.mean(
                (target == explanation_prediction).float())
            metric_dict['fidelity+'] = torch.mean(
                (target == complement_prediction).float())
        else:
            raise NotImplementedError

    return metric_dict
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
import torch | ||
|
||
|
||
def hard_threshold(mask_dict, threshold):
    """Binarize every mask in a dictionary of masks.

    Entries strictly greater than `threshold` become 1.0; all other
    entries become 0.0. A new dictionary is returned and the input is
    left untouched.
    """
    thresholded = {}
    for name, mask in mask_dict.items():
        thresholded[name] = (mask > threshold).float()
    return thresholded
|
||
|
||
def topk_threshold(mask_dict, threshold, hard=False):
    """Keep only the top-`threshold` entries of each mask in a dictionary.

    Args:
        mask_dict (dict): Mapping from name to mask tensor. Updated
            in-place and returned.
        threshold (int): The number of entries to keep per mask.
        hard (bool, optional): If set to :obj:`True`, kept entries are set
            to 1.0 instead of retaining their original values.
            (default: :obj:`False`)
    """
    for key, mask in mask_dict.items():
        if threshold >= mask.numel():
            # The mask has no more than `threshold` elements: keep all of
            # them. Bug fix: this guard now also covers the soft
            # (`hard=False`) case, since `torch.topk` raises a RuntimeError
            # when `k` exceeds the number of elements.
            if hard:
                mask_dict[key] = torch.ones_like(mask)
            continue

        value, index = torch.topk(mask.flatten(), k=threshold)

        out = torch.zeros_like(mask.flatten())
        if hard:
            out[index] = 1.0
        else:
            out[index] = value
        mask_dict[key] = out.reshape(mask.size())

    return mask_dict
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we need anything other than
explanation
now given other values are added in toExplanation
inExplainer
.