I cannot use ShapleyValueSampling as the attribution method for CaptumExplainer. It always fails with the error below, while the other CaptumExplainer attribution methods work without any problems.
There are no examples showing whether ShapleyValueSampling requires special preprocessing of the data or how it should be used with the library. The error is: RuntimeError: shape '[-1, 8, 1, 1]' is invalid for input of size 15
{
"name": "RuntimeError",
"message": "shape '[-1, 8, 1, 1]' is invalid for input of size 15",
"stack": "---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[8], line 75
59 print(f'Ground truth: {data.y}')
61 explainer = Explainer(
62 model=model,
63 algorithm=CaptumExplainer('ShapleyValueSampling'),
(...)
71 edge_mask_type=None,
72 )
---> 75 explanation = explainer(
76 data.x, data.edge_index)
File ~/anaconda3/envs/LinTorchGNN/lib/python3.9/site-packages/torch_geometric/explain/explainer.py:205, in Explainer.__call__(self, x, edge_index, target, index, **kwargs)
202 training = self.model.training
203 self.model.eval()
--> 205 explanation = self.algorithm(
206 self.model,
207 x,
208 edge_index,
209 target=target,
210 index=index,
211 **kwargs,
212 )
214 self.model.train(training)
216 # Add explainer objectives to the `Explanation` object:
File ~/anaconda3/envs/LinTorchGNN/lib/python3.9/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
1509 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1510 else:
-> 1511 return self._call_impl(*args, **kwargs)
File ~/anaconda3/envs/LinTorchGNN/lib/python3.9/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
1515 # If we don't have any hooks, we want to skip the rest of the logic in
1516 # this function, and just call forward.
1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1518 or _global_backward_pre_hooks or _global_backward_hooks
1519 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520 return forward_call(*args, **kwargs)
1522 try:
1523 result = None
File ~/anaconda3/envs/LinTorchGNN/lib/python3.9/site-packages/torch_geometric/explain/algorithm/captum_explainer.py:170, in CaptumExplainer.forward(self, model, x, edge_index, target, index, **kwargs)
167 elif index is not None:
168 target = target[index]
--> 170 attributions = self.attribution_method_instance.attribute(
171 inputs=inputs,
172 target=target,
173 additional_forward_args=add_forward_args,
174 **self.kwargs,
175 )
177 node_mask, edge_mask = convert_captum_output(
178 attributions,
179 mask_type,
180 metadata,
181 )
183 if not isinstance(x, dict):
File ~/anaconda3/envs/LinTorchGNN/lib/python3.9/site-packages/captum/log/__init__.py:42, in log_usage.<locals>._log_usage.<locals>.wrapper(*args, **kwargs)
40 @wraps(func)
41 def wrapper(*args, **kwargs):
---> 42 return func(*args, **kwargs)
File ~/anaconda3/envs/LinTorchGNN/lib/python3.9/site-packages/captum/attr/_core/shapley_value.py:411, in ShapleyValueSampling.attribute(self, inputs, baselines, target, additional_forward_args, feature_mask, n_samples, perturbations_per_eval, show_progress)
403 prev_results = all_eval[-num_examples:]
405 for j in range(len(total_attrib)):
406 # format eval_diff to shape
407 # (n_perturb, *output_shape, 1,.. 1)
408 # where n_perturb may not be perturb_per_eval
409 # Append n_input_feature dim of 1 to make the tensor
410 # have the same dim as the mask tensor.
--> 411 formatted_eval_diff = eval_diff.reshape(
412 (-1,) + output_shape + (len(inputs[j].shape) - 1) * (1,)
413 )
415 # mask in shape (n_perturb, *mask_shape_broadcastable_to_input)
416 # reshape to
417 # (
(...)
420 # *broadcastable_to_input_feature_shape
421 # )
422 cur_mask = current_masks[j]
RuntimeError: shape '[-1, 8, 1, 1]' is invalid for input of size 15"
}
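For readers following the last frame of the traceback: the failing call is a reshape of `eval_diff` to `(-1,) + output_shape + (1, 1)`. Judging from the error message, `output_shape` appears to be `(8,)` and `eval_diff` holds 15 elements. A reshape with a single `-1` only works when the element count is a multiple of the product of the other dimensions, and 15 is not a multiple of 8. The sketch below only illustrates that arithmetic in plain Python. The helper `infer_reshape` is hypothetical and is not part of Captum or PyTorch, and it says nothing about why `eval_diff` ends up with 15 elements.

```python
from math import prod


def infer_reshape(numel, shape):
    """Resolve a single -1 in `shape` for a buffer of `numel` elements.

    Hypothetical helper for illustration only (not a Captum/PyTorch API).
    Returns the concrete shape, or None when the elements cannot fill it.
    """
    # Product of all explicitly given dimensions.
    known = prod(d for d in shape if d != -1)
    if shape.count(-1) != 1 or known == 0 or numel % known != 0:
        return None
    return tuple(numel // known if d == -1 else d for d in shape)


# Values taken from the error message: 15 elements, target shape [-1, 8, 1, 1].
print(infer_reshape(15, (-1, 8, 1, 1)))  # None: 15 is not a multiple of 8
print(infer_reshape(16, (-1, 8, 1, 1)))  # (2, 8, 1, 1)
```

So the evaluated differences and the output shape assumed by the reshape disagree on how many output values there are per perturbation, which is the mismatch this report is about.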
Versions
Collecting environment information...
PyTorch version: 2.2.0+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A
OS: Ubuntu 22.04.2 LTS (x86_64)
GCC version: Could not collect
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.35
Python version: 3.9.7 | packaged by conda-forge | (default, Sep 29 2021, 19:23:11) [GCC 9.4.0] (64-bit runtime)
Python platform: Linux-5.15.167.4-microsoft-standard-WSL2-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 4090
Nvidia driver version: 560.81
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 48 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 16
On-line CPU(s) list: 0-15
Vendor ID: AuthenticAMD
Model name: AMD Ryzen 9 7900X 12-Core Processor
CPU family: 25
Model: 97
Thread(s) per core: 2
Core(s) per socket: 8
Socket(s): 1
Stepping: 2
BogoMIPS: 9399.72
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl tsc_reliable nonstop_tsc cpuid extd_apicid pni pclmulqdq ssse3 fma cx16 sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw topoext ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 erms avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves avx512_bf16 clzero xsaveerptr arat avx512vbmi umip avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq rdpid fsrm
Hypervisor vendor: Microsoft
Virtualization type: full
L1d cache: 256 KiB (8 instances)
L1i cache: 256 KiB (8 instances)
L2 cache: 8 MiB (8 instances)
L3 cache: 32 MiB (1 instance)
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Mitigation; safe RET, no microcode
Vulnerability Spec store bypass: Vulnerable
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines; IBPB conditional; IBRS_FW; STIBP conditional; RSB filling; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Versions of relevant libraries:
[pip3] focal-loss-torch==0.1.2
[pip3] numpy==1.23.5
[pip3] nvidia-cublas-cu12==12.1.3.1
[pip3] nvidia-cuda-cupti-cu12==12.1.105
[pip3] nvidia-cuda-nvrtc-cu12==12.1.105
[pip3] nvidia-cuda-runtime-cu12==12.1.105
[pip3] nvidia-cudnn-cu12==8.9.2.26
[pip3] nvidia-cufft-cu12==11.0.2.54
[pip3] nvidia-curand-cu12==10.3.2.106
[pip3] nvidia-cusolver-cu12==11.4.5.107
[pip3] nvidia-cusparse-cu12==12.1.0.106
[pip3] nvidia-nccl-cu12==2.19.3
[pip3] nvidia-nvjitlink-cu12==12.4.127
[pip3] nvidia-nvtx-cu12==12.1.105
[pip3] nvtx==0.2.8
[pip3] torch==2.2.0+cu121
[pip3] torch_cluster==1.6.3+pt22cu121
[pip3] torch_geometric==2.5.3
[pip3] torch_scatter==2.1.2+pt22cu121
[pip3] torch_sparse==0.6.18+pt22cu121
[pip3] torch_spline_conv==1.2.2+pt22cu121
[pip3] torchaudio==2.2.0+cu121
[pip3] torchinfo==1.8.0
[pip3] torchvision==0.17.0+cu121
[pip3] triton==2.2.0
[conda] cuda-cudart 11.8.89 0 nvidia
[conda] cuda-cupti 11.8.87 0 nvidia
[conda] cuda-libraries 11.8.0 0 nvidia
[conda] cuda-nvrtc 11.8.89 0 nvidia
[conda] cuda-nvtx 11.8.86 0 nvidia
[conda] cuda-runtime 11.8.0 0 nvidia
[conda] cudatoolkit 11.8.0 h6a678d5_0
[conda] cudnn 8.9.2.26 cuda11_0
[conda] focal-loss-torch 0.1.2 pypi_0 pypi
[conda] libcublas 11.11.3.6 0 nvidia
[conda] libcufft 10.9.0.58 0 nvidia
[conda] libcurand 10.3.3.141 0 nvidia
[conda] libcusolver 11.4.1.48 0 nvidia
[conda] libcusparse 11.7.5.86 0 nvidia
[conda] mkl 2021.4.0 h06a4308_640
[conda] mkl-service 2.4.0 py39h7f8727e_0
[conda] mkl_fft 1.3.1 py39h0c7bc48_1 conda-forge
[conda] mkl_random 1.2.2 py39hde0f152_0 conda-forge
[conda] nccl 2.19.4.1 h6103f9b_0 conda-forge
[conda] nomkl 3.0 0
[conda] numpy 1.23.5 py39h3d75532_0 conda-forge
[conda] nvidia-cublas-cu12 12.1.3.1 pypi_0 pypi
[conda] nvidia-cuda-cupti-cu12 12.1.105 pypi_0 pypi
[conda] nvidia-cuda-nvrtc-cu12 12.1.105 pypi_0 pypi
[conda] nvidia-cuda-runtime-cu12 12.1.105 pypi_0 pypi
[conda] nvidia-cudnn-cu12 8.9.2.26 pypi_0 pypi
[conda] nvidia-cufft-cu12 11.0.2.54 pypi_0 pypi
[conda] nvidia-curand-cu12 10.3.2.106 pypi_0 pypi
[conda] nvidia-cusolver-cu12 11.4.5.107 pypi_0 pypi
[conda] nvidia-cusparse-cu12 12.1.0.106 pypi_0 pypi
[conda] nvidia-nccl-cu12 2.19.3 pypi_0 pypi
[conda] nvidia-nvjitlink-cu12 12.4.127 pypi_0 pypi
[conda] nvidia-nvtx-cu12 12.1.105 pypi_0 pypi
[conda] nvtx 0.2.8 pypi_0 pypi
[conda] torch 2.2.0+cu121 pypi_0 pypi
[conda] torch-cluster 1.6.3+pt22cu121 pypi_0 pypi
[conda] torch-geometric 2.5.3 pypi_0 pypi
[conda] torch-scatter 2.1.2+pt22cu121 pypi_0 pypi
[conda] torch-sparse 0.6.18+pt22cu121 pypi_0 pypi
[conda] torch-spline-conv 1.2.2+pt22cu121 pypi_0 pypi
[conda] torchaudio 2.2.0+cu121 pypi_0 pypi
[conda] torchinfo 1.8.0 pypi_0 pypi
[conda] torchvision 0.17.0+cu121 pypi_0 pypi
[conda] triton 2.2.0 pypi_0 pypi