
Explainer Model is incompatible with FiLMConv Layer #5658

Open · fratajcz opened this issue Oct 12, 2022 · 5 comments
fratajcz commented Oct 12, 2022

🐛 Describe the bug

Hi!

I use the Explainer that integrates Captum as described in the example:

edge_mask = torch.ones(data.num_edges, requires_grad=True, device=device)

captum_model = to_captum(model, mask_type='node_and_edge',
                         output_idx=output_idx)

ig = IntegratedGradients(captum_model)

ig_attr_node, ig_attr_edge = ig.attribute(
    (data.x.float().unsqueeze(0), edge_mask.unsqueeze(0)),
    additional_forward_args=(data.edge_index), internal_batch_size=1)

edge_index is a SparseTensor that also carries the edge types in its values (since FiLMConv operates on multigraphs).
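
For concreteness, a minimal sketch of how such an edge-type-carrying SparseTensor can be built (toy numbers, not my actual graph; the row/col/value keywords follow the torch_sparse API):

import torch
from torch_sparse import SparseTensor

# Toy multigraph with 4 nodes and 5 edges; `value` stores the id of the
# adjacency (edge type) that each edge belongs to:
row = torch.tensor([0, 0, 1, 2, 3])
col = torch.tensor([1, 2, 3, 0, 1])
edge_type = torch.tensor([0., 0., 1., 1., 2.])  # one type id per edge

edge_index = SparseTensor(row=row, col=col, value=edge_type,
                          sparse_sizes=(4, 4))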

However, this raises an error because the edge mask is 2-dimensional:

Traceback (most recent call last):
  File "speos/explanation_dummy_film.py", line 108, in <module>
    additional_forward_args=(data.edge_index), internal_batch_size=1)
  File "/app/miniconda/envs/pyg/lib/python3.7/site-packages/captum/log/__init__.py", line 35, in wrapper
    return func(*args, **kwargs)
  File "/app/miniconda/envs/pyg/lib/python3.7/site-packages/captum/attr/_core/integrated_gradients.py", line 282, in attribute
    method=method,
  File "/app/miniconda/envs/pyg/lib/python3.7/site-packages/captum/attr/_utils/batching.py", line 79, in _batch_attribution
    **kwargs, n_steps=batch_steps, step_sizes_and_alphas=(step_sizes, alphas)
  File "/app/miniconda/envs/pyg/lib/python3.7/site-packages/captum/attr/_core/integrated_gradients.py", line 354, in _attribute
    additional_forward_args=input_additional_args,
  File "/app/miniconda/envs/pyg/lib/python3.7/site-packages/captum/_utils/gradient.py", line 112, in compute_gradients
    outputs = _run_forward(forward_fn, inputs, target_ind, additional_forward_args)
  File "/app/miniconda/envs/pyg/lib/python3.7/site-packages/captum/_utils/common.py", line 459, in _run_forward
    else inputs
  File "/app/miniconda/envs/pyg/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/app/miniconda/envs/pyg/lib/python3.7/site-packages/torch_geometric/nn/models/explainer.py", line 78, in forward
    x = self.model(mask[0], *args)
  File "/app/miniconda/envs/pyg/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/tmp/florin.ratajczak_pyg/tmph6keus6p.py", line 24, in forward
    x = self.module_6(x, edge_index)
  File "/app/miniconda/envs/pyg/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/app/miniconda/envs/pyg/lib/python3.7/site-packages/torch_geometric/nn/conv/film_conv.py", line 138, in forward
    x=lin(x[0]), beta=beta, gamma=gamma, size=None)
  File "/app/miniconda/envs/pyg/lib/python3.7/site-packages/torch_geometric/nn/conv/message_passing.py", line 335, in propagate
    edge_mask = torch.cat([edge_mask, loop], dim=0)
RuntimeError: Tensors must have same number of dimensions: got 2 and 1 

If I remove the .unsqueeze(0) from the edge mask to get the requested dimensionality, I get an error from the explainer class:

Traceback (most recent call last):
  File "speos/explanation_dummy_film.py", line 108, in <module>
    additional_forward_args=(data.edge_index), internal_batch_size=1)
  File "/app/miniconda/envs/pyg/lib/python3.7/site-packages/captum/log/__init__.py", line 35, in wrapper
    return func(*args, **kwargs)
  File "/app/miniconda/envs/pyg/lib/python3.7/site-packages/captum/attr/_core/integrated_gradients.py", line 282, in attribute
    method=method,
  File "/app/miniconda/envs/pyg/lib/python3.7/site-packages/captum/attr/_utils/batching.py", line 79, in _batch_attribution
    **kwargs, n_steps=batch_steps, step_sizes_and_alphas=(step_sizes, alphas)
  File "/app/miniconda/envs/pyg/lib/python3.7/site-packages/captum/attr/_core/integrated_gradients.py", line 354, in _attribute
    additional_forward_args=input_additional_args,
  File "/app/miniconda/envs/pyg/lib/python3.7/site-packages/captum/_utils/gradient.py", line 112, in compute_gradients
    outputs = _run_forward(forward_fn, inputs, target_ind, additional_forward_args)
  File "/app/miniconda/envs/pyg/lib/python3.7/site-packages/captum/_utils/common.py", line 459, in _run_forward
    else inputs
  File "/app/miniconda/envs/pyg/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/app/miniconda/envs/pyg/lib/python3.7/site-packages/torch_geometric/nn/models/explainer.py", line 59, in forward
    assert args[0].shape[0] == 1, "Dimension 0 of input should be 1"
AssertionError: Dimension 0 of input should be 1

The whole thing works with TAGConv and GCNConv layers, so I suspect the culprit is the FiLMConv layer. I will try hacking the layer implementation to see if I can conditionally squeeze the mask.

Environment

  • PyG version: 2.0.4
  • PyTorch version: 1.12.1
  • OS: Ubuntu 20.04
  • Python version: 3.7
  • CUDA/cuDNN version: CPU
  • How you installed PyTorch and PyG (conda, pip, source): pip
  • Any other relevant information (e.g., version of torch-scatter):
fratajcz added the bug label Oct 12, 2022

fratajcz (Author) commented:

This also happens in the most up-to-date version on pip, 2.1.0.post1, but there the identical line is at line 494, in explain_message.

fratajcz (Author) commented Oct 12, 2022

So, the lines in question are:

if inputs.size(self.node_dim) != edge_mask.size(0):
    edge_mask = edge_mask[self._loop_mask]
    loop = edge_mask.new_ones(size_i)
    edge_mask = torch.cat([edge_mask, loop], dim=0)
    assert inputs.size(self.node_dim) == edge_mask.size(0)

When I check this, inputs.size(self.node_dim) evaluates to 158962, which is the number of edges of one of the adjacencies I feed into the network. I don't know why it does that. data.num_edges and edge_mask.shape (before unsqueezing) give the correct result of 4268876, and data.num_nodes gives 16852. So where does the 158962 come from?

fratajcz (Author) commented:

I think I see where this error is coming from.

My edge_index is built as follows:

SparseTensor(row=tensor([    0,     0,     0,  ..., 16851, 16851, 16851]),
             col=tensor([   54,   721,  5041,  ..., 16561, 16573, 16676]),
             val=tensor([31., 31.,  0.,  ..., 31., 31.,  0.]),
             size=(16852, 16852), nnz=4268876, density=1.50%)

where val encodes which adjacency each edge comes from. As it happens, the adjacency with value 0 has exactly 158962 edges, so the explainer tries to apply the edge_mask for all edges (4268876) to only the edges of the first adjacency (158962) and fails. How is edge_mask supposed to be formatted when there are multiple types of edges?
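
The mismatch fits how FiLMConv processes a multigraph. Roughly (a simplified sketch, paraphrased from torch_geometric/nn/conv/film_conv.py, not the exact source), it splits the SparseTensor by edge type and calls self.propagate() once per type, so explain_message() only ever sees the edges of the current type while the registered _edge_mask covers all of them:

# Simplified sketch of FiLMConv.forward() for a SparseTensor input,
# for illustration only:
for i, (lin, film) in enumerate(zip(self.lins, self.films)):
    beta, gamma = film(x[1]).split(self.out_channels, dim=-1)
    # Keep only the edges whose type id equals i (158962 edges for type 0):
    mask = edge_index.storage.value() == i
    adj_t = torch_sparse.masked_select_nnz(edge_index, mask, layout='coo')
    # propagate() (and hence explain_message()) now sees only these edges,
    # while self._edge_mask still has one entry per edge (4268876):
    out = out + self.propagate(adj_t, x=lin(x[0]), beta=beta, gamma=gamma)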

fratajcz (Author) commented:

I have tried overriding MessagePassing.explain_message() to account for the fact that the edge_mask is passed just once while the edge types are processed iteratively. The edge_mask I pass in is now 2-dimensional, with the second dimension holding the edge type, similar to val in the SparseTensor shown in the comment above. I added a small if clause that checks whether the edge_mask is 2-dimensional and then plucks out the mask entries for the edge type that is currently being processed. Luckily, MessagePassing handles edge types in ascending order (from 0 to x, where x is the last edge type), so I can just increment the type I am looking for with each iteration.

import torch
from torch import Tensor


def explain_message(self, inputs: Tensor, size_i: int) -> Tensor:
    # NOTE: Replace this method in custom explainers per message-passing
    # layer to customize how messages shall be explained, e.g., via:
    # conv.explain_message = explain_message.__get__(conv, MessagePassing)
    # see stackoverflow.com: 394770/override-a-method-at-instance-level

    edge_mask = self._edge_mask

    if edge_mask is None:
        raise ValueError(f"Could not find a pre-defined 'edge_mask' as "
                         f"part of {self.__class__.__name__}.")

    # BEGIN ADDED CODE
    if edge_mask.dim() > 1:
        if not hasattr(self, "current_type"):
            self.current_type = 0

        # Separate edge_mask and edge_type again:
        values, types = torch.tensor_split(edge_mask, 2, dim=1)
        unique_types = torch.unique(types)
        actual_type = unique_types[self.current_type]
        # Keep only the mask entries of the edge type that is currently
        # being propagated; types are processed in ascending order:
        edge_mask = values[types == actual_type]
        self.current_type += 1
        if actual_type == types.max():
            self.current_type = 0
    # END ADDED CODE

    if self._apply_sigmoid:
        edge_mask = edge_mask.sigmoid()

    # Some ops add self-loops to `edge_index`. We need to do the same for
    # `edge_mask` (but do not train these entries).
    if inputs.size(self.node_dim) != edge_mask.size(0):
        edge_mask = edge_mask[self._loop_mask]
        loop = edge_mask.new_ones(size_i)
        edge_mask = torch.cat([edge_mask, loop], dim=0)
        assert inputs.size(self.node_dim) == edge_mask.size(0)

    size = [1] * inputs.dim()
    size[self.node_dim] = -1
    return inputs * edge_mask.view(size)

This runs fine, but the resulting edge attributions are nonsense. When visualizing them with explainer.visualize_subgraph(), the most influential edges are actually not part of the query node's subgraph.

rusty1s (Member) commented Oct 20, 2022

Really sorry for the late reply, but I appreciate your detailed description. You are right that FiLMConv currently fails due to the iterative application of self.propagate. We have plans to revisit this soon via #5630.

I think the cleanest fix, though, would be to get rid of the for-loop altogether. This should be doable by using HeteroLinear. Let me know if you want to work on this.
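
For illustration, a minimal sketch of the idea (hypothetical sizes; HeteroLinear from torch_geometric.nn selects a per-type weight via a type vector, so a single call can replace the per-type loop):

import torch
from torch_geometric.nn import HeteroLinear

# Hypothetical sizes, for illustration only:
num_edge_types, in_channels, out_channels = 3, 16, 32

# One HeteroLinear holds a weight matrix per edge type and dispatches each
# row of `x` to the weight selected by `type_vec`:
lin = HeteroLinear(in_channels, out_channels, num_types=num_edge_types)

x = torch.randn(5, in_channels)           # e.g., transformed source features
type_vec = torch.tensor([0, 0, 1, 2, 2])  # one edge type id per row
out = lin(x, type_vec)                    # shape: [5, out_channels]

# With the per-type transform folded into a single call, propagate() would
# run once over all edges, and the full-length edge_mask would line up.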
