Fidelity metrics #6116

Merged: 14 commits, Dec 19, 2022
8 changes: 3 additions & 5 deletions torch_geometric/explain/__init__.py
@@ -2,11 +2,9 @@
from .explanation import Explanation
from .algorithm import * # noqa
from .explainer import Explainer
from .metrics import ExplanationMetric

__all__ = [
'ExplainerConfig',
'ModelConfig',
'ThresholdConfig',
'Explanation',
'Explainer',
'ExplainerConfig', 'ModelConfig', 'ThresholdConfig', 'Explanation',
'Explainer', 'ExplanationMetric'
]
3 changes: 2 additions & 1 deletion torch_geometric/explain/algorithm/gnn_explainer.py
@@ -137,7 +137,8 @@ def forward(
node_feat_mask, node_mask = node_mask, None

return Explanation(x=x, edge_index=edge_index, edge_mask=edge_mask,
node_mask=node_mask, node_feat_mask=node_feat_mask)
node_mask=node_mask, node_feat_mask=node_feat_mask,
index=index)

def _train(
self,
70 changes: 32 additions & 38 deletions torch_geometric/explain/explainer.py
@@ -14,6 +14,7 @@
ThresholdConfig,
ThresholdType,
)
from torch_geometric.explain.util import hard_threshold, topk_threshold


class Explainer:
@@ -38,6 +39,7 @@ def __init__(
):
self.model = model
self.algorithm = algorithm
self.base_explanation = None

self.explainer_config = ExplainerConfig.cast(explainer_config)
self.model_config = ModelConfig.cast(model_config)
@@ -110,18 +112,7 @@ def __call__(
**kwargs: additional arguments to pass to the GNN.
"""
# Choose the `target` depending on the explanation type:
explanation_type = self.explainer_config.explanation_type
if explanation_type == ExplanationType.phenomenon:
if target is None:
raise ValueError(
f"The 'target' has to be provided for the explanation "
f"type '{explanation_type.value}'")
elif explanation_type == ExplanationType.model:
if target is not None:
warnings.warn(
f"The 'target' should not be provided for the explanation "
f"type '{explanation_type.value}'")
target = self.get_prediction(x=x, edge_index=edge_index, **kwargs)
target = self.get_target(target, x, edge_index, **kwargs)

training = self.model.training
self.model.eval()
@@ -138,8 +129,28 @@

self.model.train(training)

# store unprocessed explanation
self.base_explanation = explanation

return self._post_process(explanation)

def get_target(self, target, x, edge_index, **kwargs) -> Tensor:
r"""Returns the target to explain: a user-provided target is required
for 'phenomenon' explanations, while for 'model' explanations the
model's own prediction is used."""
explanation_type = self.explainer_config.explanation_type

if explanation_type == ExplanationType.phenomenon:
if target is None:
raise ValueError(
f"The 'target' has to be provided for the explanation "
f"type '{explanation_type.value}'")
elif explanation_type == ExplanationType.model:
if target is not None:
warnings.warn(
f"The 'target' should not be provided for the explanation "
f"type '{explanation_type.value}'")
target = self.get_prediction(x=x, edge_index=edge_index, **kwargs)

return target

def _post_process(self, explanation: Explanation) -> Explanation:
R"""Post-processes the explanation mask according to the thresholding
method and the user configuration.
@@ -168,32 +179,15 @@ def _threshold(self, explanation: Explanation) -> Explanation:
}

if self.threshold_config.type == ThresholdType.hard:
mask_dict = {
key: (mask > self.threshold_config.value).float()
for key, mask in mask_dict.items()
}

elif self.threshold_config.type in [
ThresholdType.topk,
ThresholdType.topk_hard,
]:
for key, mask in mask_dict.items():
if self.threshold_config.value >= mask.numel():
if self.threshold_config.type != ThresholdType.topk:
mask_dict[key] = torch.ones_like(mask)
continue

value, index = torch.topk(
mask.flatten(),
k=self.threshold_config.value,
)

out = torch.zeros_like(mask.flatten())
if self.threshold_config.type == ThresholdType.topk:
out[index] = value
else:
out[index] = 1.0
mask_dict[key] = out.reshape(mask.size())
mask_dict = hard_threshold(mask_dict, self.threshold_config.value)

elif self.threshold_config.type == ThresholdType.topk:
mask_dict = topk_threshold(mask_dict, self.threshold_config.value,
hard=False)

elif self.threshold_config.type == ThresholdType.topk_hard:
mask_dict = topk_threshold(mask_dict, self.threshold_config.value,
hard=True)

else:
raise NotImplementedError
8 changes: 7 additions & 1 deletion torch_geometric/explain/explanation.py
@@ -1,5 +1,5 @@
import copy
from typing import List, Optional
from typing import List, Optional, Union

from torch import Tensor

@@ -14,6 +14,9 @@ class Explanation(Data):
also hold the original graph if needed.

Args:
index (Union[int, Tensor], optional): The index of the model
output that the explanation explains. Can be a single
index or a tensor of indices. (default: :obj:`None`)
node_mask (Tensor, optional): Node-level mask with shape
:obj:`[num_nodes]`. (default: :obj:`None`)
edge_mask (Tensor, optional): Edge-level mask with shape
@@ -26,13 +29,15 @@
"""
def __init__(
self,
index: Optional[Union[int, Tensor]] = None,
node_mask: Optional[Tensor] = None,
edge_mask: Optional[Tensor] = None,
node_feat_mask: Optional[Tensor] = None,
edge_feat_mask: Optional[Tensor] = None,
**kwargs,
):
super().__init__(
index=index,
node_mask=node_mask,
edge_mask=edge_mask,
node_feat_mask=node_feat_mask,
@@ -52,6 +57,7 @@ def validate(self, raise_on_error: bool = True) -> bool:
r"""Validates the correctness of the explanation"""
status = super().validate()

# TODO check that index is in node_mask
if 'node_mask' in self and self.num_nodes != self.node_mask.size(0):
status = False
warn_or_raise(
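
The TODO above (check that index is contained in node_mask) could look roughly like the following sketch inside Explanation.validate(); warn_or_raise, status, and raise_on_error are the names already used by validate() in this diff, the rest is an assumption:

import torch

# Hypothetical bounds check for the stored 'index' attribute:
if 'index' in self and 'node_mask' in self:
    index = torch.as_tensor(self.index).view(-1)
    if index.numel() > 0 and int(index.max()) >= self.node_mask.size(0):
        status = False
        warn_or_raise(
            f"'index' refers to node {int(index.max())}, but 'node_mask' "
            f"only holds {self.node_mask.size(0)} entries", raise_on_error)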
125 changes: 125 additions & 0 deletions torch_geometric/explain/metrics.py
@@ -0,0 +1,125 @@
from abc import ABC, abstractmethod
from typing import Optional, Union

import torch
from torch import Tensor

from torch_geometric.explain import Explainer, Explanation


class ExplanationMetric(ABC):
r"""Abstract base class for explanation metrics."""
def __init__(self, explainer, index) -> None:
self.explainer = explainer
self.index = index

@abstractmethod
def __call__(self, explainer: Explainer, **kwargs):
r"""Computes the explanation metric for given explainer and explanation
Args:
explainer :obj:`~torch_geometric.explain.Explainer`
The explainer to evaluate
"""

def get_inputs(self):
r"""Obtain inputs all different inputs over which to compute the
metrics."""

@abstractmethod
def compute_metric(self):
r"""Compute the metric over all inputs."""

@abstractmethod
def aggregate(self):
r"""Aggregate metrics over all inputs"""


class Fidelity(ExplanationMetric):
r"""Fidelity+/- Explanation Metric as
defined in https://arxiv.org/abs/2206.09677"""
def __init__(self) -> None:
super().__init__()


def fidelity(explainer: Explainer, explanation: Explanation,
target: Optional[Tensor] = None,
index: Optional[Union[int, Tensor]] = None,
output_type: str = 'raw', **kwargs):
r"""Evaluate the fidelity of Explainer and given
explanation produced by explainer

Args:
explainer :obj:`~torch_geometric.explain.Explainer`
The explainer to evaluate
explanation :obj:`~torch_teometric.explain.Explanation`
The explanation to evaluate
target (Tensor, optional): The target prediction, if not provided it
is inferred from obj:`explainer`, defaults to obj:`None`
index (Union[int, Tensor]): The explanation target index, for node-
and edge- level task it signifies the nodes/edges explained
respectively, for graph-level tasks it is assumed to be None,
defaults to obj:`None`
Member (review comment): Suggested change: keep only

    explanation :obj:`~torch_geometric.explain.Explanation`
        The explanation to evaluate

and drop the explainer, target, and index entries from the Args list above. Do we need anything other than explanation now, given that the other values are added into Explanation in Explainer?

"""
metric_dict = {}

task_level = explainer.model_config.task_level

if index != explanation.get('index'):
raise ValueError(f'Index ({index}) does not match '
f'explanation.index ({explanation.index}).')

# get input graphs
explanation_graph = explanation.get_explanation_subgraph() # for fid-
complement_graph = explanation.get_complement_subgraph() # for fid+

# get target
target = explainer.get_target(x=explanation.x,
                              edge_index=explanation.edge_index,
                              **kwargs)  # using full explanation

Contributor Author (review comment): @rusty1s and @dufourc1, any ideas on how to make sure the right input is passed to get_target here, as well as to the get_prediction() calls on lines 81 and 84? I am unsure how to do this, given that the get_prediction calls take subgraphs as inputs, while explanation.get_explanation_subgraph and explanation.get_complement_subgraph don't return the modified node and edge attributes needed to get the appropriate predictions. Am I missing a simple way to achieve this?

Member (review comment): If I understand correctly, I think you have to pass the input that is explained with the model (i.e., the one that was used to generate the explanation). I'm not sure you can always guarantee that an Explanation instance has an x or even an edge_index attribute. You would need to apply the masks of the explanation to the original input to get the explanation_graph and complement_graph, but maybe I'm missing something in what you want to do?

Member (review comment): @BlazStojanovic is this still blocking you?
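A rough sketch of the approach described in the review comments above (applying the explanation masks to the original inputs instead of relying on get_explanation_subgraph/get_complement_subgraph); the helper name and the hard-mask assumption are hypothetical and not part of this PR:

import torch

def masked_inputs(x, edge_index, node_mask, edge_mask, complement=False):
    # Zero out node features and drop edges according to (hard) explanation
    # masks; complement=True inverts the masks, as needed for fidelity+.
    node_keep = ~node_mask.bool() if complement else node_mask.bool()
    edge_keep = ~edge_mask.bool() if complement else edge_mask.bool()
    x_masked = x * node_keep.view(-1, 1).to(x.dtype)
    edge_index_masked = edge_index[:, edge_keep]
    return x_masked, edge_index_masked

The two outputs could then be passed to explainer.get_prediction() in place of the explanation_graph / complement_graph attributes used below.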

# get predictions
explanation_prediction = explainer.get_prediction(
x=explanation_graph.x, edge_index=explanation_graph.edge_index,
**kwargs)
complement_prediction = explainer.get_prediction(
x=complement_graph.x, edge_index=complement_graph.edge_index, **kwargs)

# fix logprob to prob
if output_type == 'prob' and explainer.model_config.return_type == 'log_probs':
target = torch.exp(target)
explanation_prediction = torch.exp(explanation_prediction)
complement_prediction = torch.exp(complement_prediction)

# based on task level
if task_level == 'graph':
if index is not None:
raise ValueError(
    f'Index for graph level task should be None, got ({index})')
# evaluate whole entry
pass
elif task_level == 'edge':
# get edge prediction
pass # TODO (blaz)
Contributor Author (review comment): Node-level and graph-level tasks make sense to me. For nodes we have a clearly indexed output (each row of the prediction corresponds to a node), and for full graphs we only have one prediction anyway. That makes it clear what to compare in order to get fidelities (in both classification and regression tasks).

But I am not certain about edge-level tasks. Say, for example, that when obtaining the explanation for an edge-level task with the explainer class, we use:

explanation = explainer(
        x,
        edge_index,
        target=target,
        index=index,
        edge_label_index=edge_label_index,
    )

What exactly does the index refer to? Are we explaining the prediction(s) for the edge(s) with indices in index, which correspond to edges in edge_index, or something else? If so, how can we easily compare the predictions, given that the edge_index changes between the induced and complement graphs of an explanation?

Member (@dufourc1, Dec 7, 2022, review comment): Maybe the tests can help?

edge_label_index = torch.tensor([[0, 1, 2], [3, 4, 5]])

I haven't taken a look at the edge-level explainers in detail, but it seems that for explaining the edge between node 0 and 1 you would have edge_label_index = torch.tensor([[0], [1]]).

Member (review comment): @BlazStojanovic I think we should leave this as NotImplemented for now, because task_level=edge needs some more work:

  1. We don't have an example to demonstrate task_level=edge on some real-world data. We could add an example that explains this edge prediction task.
  2. GNNExplainer needs to be updated to support edge-level tasks; one line that needs to be updated is `if self.model_config.task_level == ModelTaskLevel.node:`.
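
To make the semantics discussed above concrete, here is a small hedged illustration; the link-prediction reading of edge_label_index (each column is one source/target pair whose prediction the model returns) is an assumption matching dufourc1's example, not something fixed by this PR:

import torch

# Three node pairs to predict: (0, 3), (1, 4) and (2, 5).
edge_label_index = torch.tensor([[0, 1, 2],   # source nodes
                                 [3, 4, 5]])  # target nodes

# Under this reading, the model output for these pairs has one row per column
# of edge_label_index, so index=1 would refer to the prediction for pair (1, 4).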

elif task_level == 'node':
# get node prediction(s)
target = target[index]
explanation_prediction = explanation_prediction[index]
complement_prediction = complement_prediction[index]
else:
raise NotImplementedError

with torch.no_grad():
if explainer.model_config.mode == 'regression':
metric_dict['fidelity-'] = torch.mean(
torch.abs(target - explanation_prediction))
metric_dict['fidelity+'] = torch.mean(
torch.abs(target - complement_prediction))
elif explainer.model_config.mode == 'classification':
metric_dict['fidelity-'] = torch.mean(
    (target == explanation_prediction).float())
metric_dict['fidelity+'] = torch.mean(
    (target == complement_prediction).float())
else:
raise NotImplementedError

return metric_dict
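
As implemented above, fidelity- compares the target with the prediction obtained from the explanation subgraph (for a faithful explanation the two should stay close), while fidelity+ compares it with the prediction obtained from the complement subgraph (which should differ if the explanation captured what matters). A minimal usage sketch for a node-level task; explainer, x, edge_index and node_idx are assumed to come from an existing node-classification setup, and index is assumed to be accepted by the explainer call as in this PR:

explanation = explainer(x, edge_index, index=node_idx)
metrics = fidelity(explainer, explanation, index=node_idx)
print(metrics['fidelity+'], metrics['fidelity-'])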
33 changes: 33 additions & 0 deletions torch_geometric/explain/util.py
@@ -0,0 +1,33 @@
import torch


def hard_threshold(mask_dict, threshold):
"""Impose hard threshold on a dictionary of masks"""
mask_dict = {
key: (mask > threshold).float()
for key, mask in mask_dict.items()
}
return mask_dict


def topk_threshold(mask_dict, threshold, hard=False):
"""Impose topk threshold on a dictionary of masks"""
for key, mask in mask_dict.items():
if threshold >= mask.numel():
if hard:
mask_dict[key] = torch.ones_like(mask)
continue

value, index = torch.topk(
mask.flatten(),
k=threshold,
)

out = torch.zeros_like(mask.flatten())
if not hard:
out[index] = value
else:
out[index] = 1.0
mask_dict[key] = out.reshape(mask.size())

return mask_dict
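
A small usage sketch of the two helpers above (values chosen arbitrarily for illustration; dict(...) is used so the original dictionary is not mutated in place):

import torch

masks = {'edge_mask': torch.tensor([0.9, 0.1, 0.7, 0.3])}

hard_threshold(dict(masks), 0.5)           # -> {'edge_mask': tensor([1., 0., 1., 0.])}
topk_threshold(dict(masks), 2)             # -> keeps the top-2 values: [0.9, 0., 0.7, 0.]
topk_threshold(dict(masks), 2, hard=True)  # -> binarizes the top-2: [1., 0., 1., 0.]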