Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fidelity metrics #6116

Merged
merged 14 commits into from
Dec 19, 2022
Prev Previous commit
Next Next commit
Merge branch 'master' into explanation_metrics
  • Loading branch information
BlazStojanovic authored Dec 14, 2022
commit 36796e449aa69b40ac298d2db68fbc993238f4f7
8 changes: 6 additions & 2 deletions torch_geometric/explain/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@
from .metrics import ExplanationMetric

__all__ = [
'ExplainerConfig', 'ModelConfig', 'ThresholdConfig', 'Explanation',
'Explainer', 'ExplanationMetric'
'ExplainerConfig',
'ModelConfig',
'ThresholdConfig',
'Explanation',
'HeteroExplanation',
'Explainer',
]
5 changes: 2 additions & 3 deletions torch_geometric/explain/algorithm/gnn_explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,9 +136,8 @@ def forward(
}:
node_feat_mask, node_mask = node_mask, None

return Explanation(x=x, edge_index=edge_index, edge_mask=edge_mask,
node_mask=node_mask, node_feat_mask=node_feat_mask,
index=index)
return Explanation(node_mask=node_mask, node_feat_mask=node_feat_mask,
edge_mask=edge_mask)

def _train(
self,
Expand Down
81 changes: 37 additions & 44 deletions torch_geometric/explain/explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
ModelMode,
ThresholdConfig,
)
from torch_geometric.explain.util import hard_threshold, topk_threshold

from torch_geometric.typing import EdgeType, NodeType



class Explainer:
Expand Down Expand Up @@ -154,7 +156,20 @@ def __call__(
**kwargs: additional arguments to pass to the GNN.
"""
# Choose the `target` depending on the explanation type:
target = self.get_target(target, x, edge_index, **kwargs)

prediction: Optional[Tensor] = None
if self.explanation_type == ExplanationType.phenomenon:
if target is None:
raise ValueError(
f"The 'target' has to be provided for the explanation "
f"type '{self.explanation_type.value}'")
elif self.explanation_type == ExplanationType.model:
if target is not None:
warnings.warn(
f"The 'target' should not be provided for the explanation "
f"type '{self.explanation_type.value}'")
prediction = self.get_prediction(x, edge_index, **kwargs)
target = self._get_target(prediction)

training = self.model.training
self.model.eval()
Expand All @@ -171,31 +186,17 @@ def __call__(

self.model.train(training)

# store unprocessed explanation
self.base_explanation = explanation

return self._post_process(explanation)

def get_target(self, target, x, edge_index, **kwargs) -> Tensor:
explanation_type = self.explainer_config.explanation_type

if explanation_type == ExplanationType.phenomenon:
if target is None:
raise ValueError(
f"The 'target' has to be provided for the explanation "
f"type '{explanation_type.value}'")
elif explanation_type == ExplanationType.model:
if target is not None:
warnings.warn(
f"The 'target' should not be provided for the explanation "
f"type '{explanation_type.value}'")
target = self.get_prediction(x=x, edge_index=edge_index, **kwargs)
# Add explainer objectives to the `Explanation` object:
explanation._model_config = self.model_config
explanation.prediction = prediction
explanation.target = target
explanation.index = index
explanation.target_index = target_index

return target

def _post_process(self, explanation: Explanation) -> Explanation:
R"""Post-processes the explanation mask according to the thresholding
method and the user configuration.
# Add model inputs to the `Explanation` object:
if isinstance(explanation, Explanation):
explanation.x = x
explanation.edge_index = edge_index

for key, arg in kwargs.items(): # Add remaining `kwargs`:
explanation[key] = arg
Expand All @@ -209,24 +210,16 @@ def _post_process(self, explanation: Explanation) -> Explanation:
for edge_type, value in edge_index.items():
explanation[edge_type].edge_index = value

# Avoid modification of the original explanation:
explanation = copy.copy(explanation)

mask_dict = { # Get the available masks:
key: explanation[key]
for key in explanation.available_explanations
}

if self.threshold_config.type == ThresholdType.hard:
mask_dict = hard_threshold(mask_dict, self.threshold_config.value)

elif self.threshold_config.type == ThresholdType.topk:
mask_dict = topk_threshold(mask_dict, self.threshold_config.value,
hard=False)

elif self.threshold_config.type == ThresholdType.topk_hard:
mask_dict = topk_threshold(mask_dict, self.threshold_config.value,
hard=True)
for key, arg in kwargs.items(): # Add remaining `kwargs`:
if isinstance(arg, dict):
# Keyword arguments are likely named `{attr_name}_dict`
# while we only want to assign the `{attr_name}` to the
# `HeteroExplanation` object:
key = key[:-5] if key.endswith('_dict') else key
for type_name, value in arg.items():
explanation[type_name][key] = value
else:
explanation[key] = arg

return explanation.threshold(self.threshold_config)

Expand Down
82 changes: 7 additions & 75 deletions torch_geometric/explain/explanation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import copy
from typing import List, Optional, Union

from typing import Dict, List, Optional, Union

import torch
from torch import Tensor
Expand All @@ -16,13 +17,15 @@ def available_explanations(self) -> List[str]:
"""Returns the available explanation masks."""
return [key for key in self.keys if key.endswith('_mask')]


def validate_masks(self, raise_on_error: bool = True) -> bool:
r"""Validates the correctness of the :class:`Explanation` masks."""
status = True

for store in self.node_stores:
mask = store.get('node_mask')
if mask is not None and store.num_nodes != mask.size(0):

status = False
warn_or_raise(
f"Expected a 'node_mask' with {store.num_nodes} nodes "
Expand Down Expand Up @@ -142,9 +145,6 @@ class Explanation(Data, ExplanationMixin):
It can also hold the original graph if needed.

Args:
index (Union[int, Tensor], optional): The index of the model
output that the explanation explains. Can be a single
index or a tensor of indices. (default: :obj:`None`)
node_mask (Tensor, optional): Node-level mask with shape
:obj:`[num_nodes]`. (default: :obj:`None`)
edge_mask (Tensor, optional): Edge-level mask with shape
Expand All @@ -155,78 +155,10 @@ class Explanation(Data, ExplanationMixin):
:obj:`[num_edges, num_edge_features]`. (default: :obj:`None`)
**kwargs (optional): Additional attributes.
"""
def __init__(
self,
index: Optional[Union[int, Tensor]] = None,
node_mask: Optional[Tensor] = None,
edge_mask: Optional[Tensor] = None,
node_feat_mask: Optional[Tensor] = None,
edge_feat_mask: Optional[Tensor] = None,
**kwargs,
):
super().__init__(
index=index,
node_mask=node_mask,
edge_mask=edge_mask,
node_feat_mask=node_feat_mask,
edge_feat_mask=edge_feat_mask,
**kwargs,
)

@property
def available_explanations(self) -> List[str]:
"""Returns the available explanation masks."""
return [
key for key in self.keys
if key.endswith('_mask') and self[key] is not None
]

def validate(self, raise_on_error: bool = True) -> bool:
r"""Validates the correctness of the explanation"""
status = super().validate()

# TODO check that index is in node_mask
if 'node_mask' in self and self.num_nodes != self.node_mask.size(0):
status = False
warn_or_raise(
f"Expected a 'node_mask' with {self.num_nodes} nodes "
f"(got {self.node_mask.size(0)} nodes)", raise_on_error)

if 'edge_mask' in self and self.num_edges != self.edge_mask.size(0):
status = False
warn_or_raise(
f"Expected an 'edge_mask' with {self.num_edges} edges "
f"(got {self.edge_mask.size(0)} edges)", raise_on_error)

if 'node_feat_mask' in self:
if 'x' in self and self.x.size() != self.node_feat_mask.size():
status = False
warn_or_raise(
f"Expected a 'node_feat_mask' of shape "
f"{list(self.x.size())} (got shape "
f"{list(self.node_feat_mask.size())})", raise_on_error)
elif self.num_nodes != self.node_feat_mask.size(0):
status = False
warn_or_raise(
f"Expected a 'node_feat_mask' with {self.num_nodes} nodes "
f"(got {self.node_feat_mask.size(0)} nodes)",
raise_on_error)

if 'edge_feat_mask' in self:
if ('edge_attr' in self
and self.edge_attr.size() != self.edge_feat_mask.size()):
status = False
warn_or_raise(
f"Expected an 'edge_feat_mask' of shape "
f"{list(self.edge_attr.size())} (got shape "
f"{list(self.edge_feat_mask.size())})", raise_on_error)
elif self.num_edges != self.edge_feat_mask.size(0):
status = False
warn_or_raise(
f"Expected an 'edge_feat_mask' with {self.num_edges} "
f"edges (got {self.edge_feat_mask.size(0)} edges)",
raise_on_error)

r"""Validates the correctness of the :class:`Explanation` object."""
status = super().validate(raise_on_error)
status &= self.validate_masks(raise_on_error)
return status

def get_explanation_subgraph(self) -> 'Explanation':
Expand Down
You are viewing a condensed version of this merge commit. You can view the full changes here.