
RuntimeError: shape '[-1, 8, 1, 1]' is invalid for input of size 15 specifically when using ShapleyValueSampling as an explainer #9901

Open
MohSamNaf opened this issue Dec 28, 2024 · 0 comments

🐛 Describe the bug

Hi,

I cannot use ShapleyValueSampling as a CaptumExplainer. It always fails with this error, while other types of CaptumExplainer work without any problems.

There are no examples showing whether ShapleyValueSampling requires special preprocessing of the data or how it should be used with the library. The call always fails with RuntimeError: shape '[-1, 8, 1, 1]' is invalid for input of size 15.
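
For context, the failing reshape can be reproduced in isolation: 15 elements cannot be viewed as shape (-1, 8, 1, 1) because 15 is not a multiple of 8. A minimal sketch of just that arithmetic (not of the explainer internals):

import torch

# 15 elements cannot be grouped into blocks of 8, so this raises the same
# RuntimeError: shape '[-1, 8, 1, 1]' is invalid for input of size 15
torch.arange(15.0).reshape(-1, 8, 1, 1)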

Toy Example:

from torch_geometric.data import Data
import torch

import torch.nn as nn
from torch_geometric.nn import GCNConv

import torch.optim as optim
import torch.nn.functional as F

from torch_geometric.explain import CaptumExplainer, Explainer

x = torch.tensor([[1, 2], [2, 3], [3, 4], [4, 5]], dtype=torch.float)  # 4 nodes

edge_index = torch.tensor([
    [0, 1, 2, 3, 0, 3],  # Source nodes
    [1, 2, 3, 0, 2, 1]   # Target nodes
], dtype=torch.long)

y = torch.tensor([0, 1, 0, 1], dtype=torch.long)

data = Data(x=x, edge_index=edge_index, y=y)


class ToyGCN(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(ToyGCN, self).__init__()
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, output_dim)
        self.relu = nn.ReLU()

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = self.relu(x)
        x = self.conv2(x, edge_index)
        return x

input_dim = data.x.shape[1]
hidden_dim = 4
output_dim = 2  # Number of classes

model = ToyGCN(input_dim, hidden_dim, output_dim)

optimizer = optim.Adam(model.parameters(), lr=0.01)

epochs = 20
for epoch in range(epochs):
    model.train()
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    loss = F.cross_entropy(out, data.y)
    loss.backward()
    optimizer.step()
    
    print(f'Epoch {epoch+1}/{epochs}, Loss: {loss.item()}')

model.eval()
_, pred = model(data.x, data.edge_index).max(dim=1)
print(f'Predictions: {pred}')
print(f'Ground truth: {data.y}')

explainer = Explainer(
    model=model,
    algorithm=CaptumExplainer('ShapleyValueSampling'),
    explanation_type='model',
    model_config=dict(
        mode='binary_classification',
        task_level='graph',
        return_type='probs',
    ),
    node_mask_type='attributes',
    edge_mask_type=None,
)


explanation = explainer(data.x, data.edge_index)
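
Judging from the traceback below, CaptumExplainer forwards its extra keyword arguments (**self.kwargs) to ShapleyValueSampling.attribute, so Captum options such as n_samples could in principle be set through the constructor. A sketch of that variant (the option name comes from Captum's ShapleyValueSampling.attribute signature, not from torch_geometric):

explainer = Explainer(
    model=model,
    algorithm=CaptumExplainer('ShapleyValueSampling', n_samples=10),  # forwarded to Captum's attribute()
    explanation_type='model',
    model_config=dict(
        mode='binary_classification',
        task_level='graph',
        return_type='probs',
    ),
    node_mask_type='attributes',
    edge_mask_type=None,
)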

Traceback:

{
	"name": "RuntimeError",
	"message": "shape '[-1, 8, 1, 1]' is invalid for input of size 15",
	"stack": "---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[8], line 75
     59 print(f'Ground truth: {data.y}')
     61 explainer = Explainer(
     62     model=model,
     63     algorithm=CaptumExplainer('ShapleyValueSampling'),
   (...)
     71     edge_mask_type=None,
     72 )
---> 75 explanation = explainer(
     76     data.x, data.edge_index)

File ~/anaconda3/envs/LinTorchGNN/lib/python3.9/site-packages/torch_geometric/explain/explainer.py:205, in Explainer.__call__(self, x, edge_index, target, index, **kwargs)
    202 training = self.model.training
    203 self.model.eval()
--> 205 explanation = self.algorithm(
    206     self.model,
    207     x,
    208     edge_index,
    209     target=target,
    210     index=index,
    211     **kwargs,
    212 )
    214 self.model.train(training)
    216 # Add explainer objectives to the `Explanation` object:

File ~/anaconda3/envs/LinTorchGNN/lib/python3.9/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File ~/anaconda3/envs/LinTorchGNN/lib/python3.9/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File ~/anaconda3/envs/LinTorchGNN/lib/python3.9/site-packages/torch_geometric/explain/algorithm/captum_explainer.py:170, in CaptumExplainer.forward(self, model, x, edge_index, target, index, **kwargs)
    167 elif index is not None:
    168     target = target[index]
--> 170 attributions = self.attribution_method_instance.attribute(
    171     inputs=inputs,
    172     target=target,
    173     additional_forward_args=add_forward_args,
    174     **self.kwargs,
    175 )
    177 node_mask, edge_mask = convert_captum_output(
    178     attributions,
    179     mask_type,
    180     metadata,
    181 )
    183 if not isinstance(x, dict):

File ~/anaconda3/envs/LinTorchGNN/lib/python3.9/site-packages/captum/log/__init__.py:42, in log_usage.<locals>._log_usage.<locals>.wrapper(*args, **kwargs)
     40 @wraps(func)
     41 def wrapper(*args, **kwargs):
---> 42     return func(*args, **kwargs)

File ~/anaconda3/envs/LinTorchGNN/lib/python3.9/site-packages/captum/attr/_core/shapley_value.py:411, in ShapleyValueSampling.attribute(self, inputs, baselines, target, additional_forward_args, feature_mask, n_samples, perturbations_per_eval, show_progress)
    403     prev_results = all_eval[-num_examples:]
    405 for j in range(len(total_attrib)):
    406     # format eval_diff to shape
    407     # (n_perturb, *output_shape, 1,.. 1)
    408     # where n_perturb may not be perturb_per_eval
    409     # Append n_input_feature dim of 1 to make the tensor
    410     # have the same dim as the mask tensor.
--> 411     formatted_eval_diff = eval_diff.reshape(
    412         (-1,) + output_shape + (len(inputs[j].shape) - 1) * (1,)
    413     )
    415     # mask in shape (n_perturb, *mask_shape_broadcastable_to_input)
    416     # reshape to
    417     # (
   (...)
    420     #     *broadcastable_to_input_feature_shape
    421     # )
    422     cur_mask = current_masks[j]

RuntimeError: shape '[-1, 8, 1, 1]' is invalid for input of size 15"
}
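
For comparison, the same model, data, and Explainer configuration run without this error when only the algorithm name is changed, e.g. to IntegratedGradients (one of the "other types of CaptumExplainer" that work for me); a sketch:

explainer_ig = Explainer(
    model=model,
    algorithm=CaptumExplainer('IntegratedGradients'),
    explanation_type='model',
    model_config=dict(
        mode='binary_classification',
        task_level='graph',
        return_type='probs',
    ),
    node_mask_type='attributes',
    edge_mask_type=None,
)

explanation_ig = explainer_ig(data.x, data.edge_index)  # completes without the reshape error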

Versions

Collecting environment information...
PyTorch version: 2.2.0+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.2 LTS (x86_64)
GCC version: Could not collect
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.35

Python version: 3.9.7 | packaged by conda-forge | (default, Sep 29 2021, 19:23:11) [GCC 9.4.0] (64-bit runtime)
Python platform: Linux-5.15.167.4-microsoft-standard-WSL2-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 4090
Nvidia driver version: 560.81
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 48 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 16
On-line CPU(s) list: 0-15
Vendor ID: AuthenticAMD
Model name: AMD Ryzen 9 7900X 12-Core Processor
CPU family: 25
Model: 97
Thread(s) per core: 2
Core(s) per socket: 8
Socket(s): 1
Stepping: 2
BogoMIPS: 9399.72
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl tsc_reliable nonstop_tsc cpuid extd_apicid pni pclmulqdq ssse3 fma cx16 sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw topoext ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 erms avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves avx512_bf16 clzero xsaveerptr arat avx512vbmi umip avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq rdpid fsrm
Hypervisor vendor: Microsoft
Virtualization type: full
L1d cache: 256 KiB (8 instances)
L1i cache: 256 KiB (8 instances)
L2 cache: 8 MiB (8 instances)
L3 cache: 32 MiB (1 instance)
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Mitigation; safe RET, no microcode
Vulnerability Spec store bypass: Vulnerable
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines; IBPB conditional; IBRS_FW; STIBP conditional; RSB filling; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected

Versions of relevant libraries:
[pip3] focal-loss-torch==0.1.2
[pip3] numpy==1.23.5
[pip3] nvidia-cublas-cu12==12.1.3.1
[pip3] nvidia-cuda-cupti-cu12==12.1.105
[pip3] nvidia-cuda-nvrtc-cu12==12.1.105
[pip3] nvidia-cuda-runtime-cu12==12.1.105
[pip3] nvidia-cudnn-cu12==8.9.2.26
[pip3] nvidia-cufft-cu12==11.0.2.54
[pip3] nvidia-curand-cu12==10.3.2.106
[pip3] nvidia-cusolver-cu12==11.4.5.107
[pip3] nvidia-cusparse-cu12==12.1.0.106
[pip3] nvidia-nccl-cu12==2.19.3
[pip3] nvidia-nvjitlink-cu12==12.4.127
[pip3] nvidia-nvtx-cu12==12.1.105
[pip3] nvtx==0.2.8
[pip3] torch==2.2.0+cu121
[pip3] torch_cluster==1.6.3+pt22cu121
[pip3] torch_geometric==2.5.3
[pip3] torch_scatter==2.1.2+pt22cu121
[pip3] torch_sparse==0.6.18+pt22cu121
[pip3] torch_spline_conv==1.2.2+pt22cu121
[pip3] torchaudio==2.2.0+cu121
[pip3] torchinfo==1.8.0
[pip3] torchvision==0.17.0+cu121
[pip3] triton==2.2.0
[conda] cuda-cudart 11.8.89 0 nvidia
[conda] cuda-cupti 11.8.87 0 nvidia
[conda] cuda-libraries 11.8.0 0 nvidia
[conda] cuda-nvrtc 11.8.89 0 nvidia
[conda] cuda-nvtx 11.8.86 0 nvidia
[conda] cuda-runtime 11.8.0 0 nvidia
[conda] cudatoolkit 11.8.0 h6a678d5_0
[conda] cudnn 8.9.2.26 cuda11_0
[conda] focal-loss-torch 0.1.2 pypi_0 pypi
[conda] libcublas 11.11.3.6 0 nvidia
[conda] libcufft 10.9.0.58 0 nvidia
[conda] libcurand 10.3.3.141 0 nvidia
[conda] libcusolver 11.4.1.48 0 nvidia
[conda] libcusparse 11.7.5.86 0 nvidia
[conda] mkl 2021.4.0 h06a4308_640
[conda] mkl-service 2.4.0 py39h7f8727e_0
[conda] mkl_fft 1.3.1 py39h0c7bc48_1 conda-forge
[conda] mkl_random 1.2.2 py39hde0f152_0 conda-forge
[conda] nccl 2.19.4.1 h6103f9b_0 conda-forge
[conda] nomkl 3.0 0
[conda] numpy 1.23.5 py39h3d75532_0 conda-forge
[conda] nvidia-cublas-cu12 12.1.3.1 pypi_0 pypi
[conda] nvidia-cuda-cupti-cu12 12.1.105 pypi_0 pypi
[conda] nvidia-cuda-nvrtc-cu12 12.1.105 pypi_0 pypi
[conda] nvidia-cuda-runtime-cu12 12.1.105 pypi_0 pypi
[conda] nvidia-cudnn-cu12 8.9.2.26 pypi_0 pypi
[conda] nvidia-cufft-cu12 11.0.2.54 pypi_0 pypi
[conda] nvidia-curand-cu12 10.3.2.106 pypi_0 pypi
[conda] nvidia-cusolver-cu12 11.4.5.107 pypi_0 pypi
[conda] nvidia-cusparse-cu12 12.1.0.106 pypi_0 pypi
[conda] nvidia-nccl-cu12 2.19.3 pypi_0 pypi
[conda] nvidia-nvjitlink-cu12 12.4.127 pypi_0 pypi
[conda] nvidia-nvtx-cu12 12.1.105 pypi_0 pypi
[conda] nvtx 0.2.8 pypi_0 pypi
[conda] torch 2.2.0+cu121 pypi_0 pypi
[conda] torch-cluster 1.6.3+pt22cu121 pypi_0 pypi
[conda] torch-geometric 2.5.3 pypi_0 pypi
[conda] torch-scatter 2.1.2+pt22cu121 pypi_0 pypi
[conda] torch-sparse 0.6.18+pt22cu121 pypi_0 pypi
[conda] torch-spline-conv 1.2.2+pt22cu121 pypi_0 pypi
[conda] torchaudio 2.2.0+cu121 pypi_0 pypi
[conda] torchinfo 1.8.0 pypi_0 pypi
[conda] torchvision 0.17.0+cu121 pypi_0 pypi
[conda] triton 2.2.0 pypi_0 pypi
