Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MessagePassing.explain_message functionality. #4278

Merged
merged 29 commits into from
Apr 8, 2022
Merged

MessagePassing.explain_message functionality. #4278

merged 29 commits into from
Apr 8, 2022

Conversation

fork123aniket
Copy link
Contributor

This PR contains implementation of how to compute layer-wise weights for each edge in order to produce explanations for both node-level and graph-level tasks. Furthermore, this implementation is different from authors' original implementation and is fast and more memory efficient than theirs.

@fork123aniket
Copy link
Contributor Author

Please take a look at functions (message, get_latest_source_embeddings, get_latest_target_embeddings, and get_latest_messages) which must be copied to other layers' implementation in order to make this GraphMask implementation model-agnostic. The significance of each of these functions can be seen on lines 241-243 (updated_gcn.py) and 306-308 (updated_rgcn.py).

@codecov
Copy link

codecov bot commented Mar 16, 2022

Codecov Report

Merging #4278 (9747c8c) into master (6500858) will increase coverage by 0.01%.
The diff coverage is 100.00%.

@@            Coverage Diff             @@
##           master    #4278      +/-   ##
==========================================
+ Coverage   82.58%   82.60%   +0.01%     
==========================================
  Files         312      312              
  Lines       16135    16152      +17     
==========================================
+ Hits        13325    13342      +17     
  Misses       2810     2810              
Impacted Files Coverage Δ
torch_geometric/nn/conv/message_passing.py 99.43% <100.00%> (+0.02%) ⬆️
torch_geometric/nn/models/explainer.py 95.86% <100.00%> (ø)

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 6500858...9747c8c. Read the comment docs.

@fork123aniket
Copy link
Contributor Author

@rusty1s Am just following up with you that did you get a chance to look into the code files pinned with this PR? Am still awaiting any review comments so we can make further progress on how to make this GraphMask Explainer implementation model-agnostic. In the previous comment, I've provided details of which functions to put more focus on.

@rusty1s
Copy link
Member

rusty1s commented Mar 23, 2022

Hi @fork123aniket. Yes, sorry for the delay. I will try to have a review ready by tomorrow. I sadly only have two hands :(

torch_geometric/nn/conv/updated_gcn.py Outdated Show resolved Hide resolved
torch_geometric/nn/conv/updated_gcn.py Outdated Show resolved Hide resolved
torch_geometric/nn/conv/updated_gcn.py Outdated Show resolved Hide resolved
@rusty1s
Copy link
Member

rusty1s commented Mar 25, 2022

Sorry for the delay. I left some ideas on how to make this GNN-agnostic :) Looking forward to your reply.

@fork123aniket
Copy link
Contributor Author

I've just replied to all the comments above. Please look into each of them and let me know your thoughts. Looking forward to hearing from you very soon....... :)

@Padarn
Copy link
Contributor

Padarn commented Apr 1, 2022 via email

@rusty1s
Copy link
Member

rusty1s commented Apr 4, 2022

@fork123aniket Thanks for the update :) Can you apply the modifications directly inside message_passing.py? This makes it easier to track the changes (and we need to do that anyway before merging).

What do message_scale and message_replace represent? Do we really need to pass them as arguments to forward or can we simply set them as instance attributes during explainability? This way, we could simply access them via self.message_replace since explain_message rather than adding them as input arguments to every GNN layer (which again obviously does not scale).

@fork123aniket
Copy link
Contributor Author

@rusty1s Bingo!!! Successfully set those two required parameters (message_scale and message_replacement) as instance attributes only at the time of generating explanations. Now, we need not to worry about adding them to each layer type's forward() function. Have tested the whole proposed explainability technique on datasets:- Cora (Node Classification - GAT, GCN), Enzymes (Graph Classification - GAT, GCN), and AIFB (Node Classification - RGCN) without touching any of the layer types' original implementation which signals the successful achievement towards model-agnostic graph explainability, WDYT??

Furthermore, as mentioned above, since making changes to any GNN layer types' original implementation is not required, hence have taken all the three previously added files (updated_gcn.py, updated_gat.py, and updated_rgcn.py) down from this PR and added all the modifications, made for GraphMask Explainer to work, inside MessagePassing interface of the recently committed message_passing.py file. Now, you'll be easily able to figure out the differences between this committed change and the original PyG message_passing.py file.

Please review the changes and let me know if any further change is still required somewhere.

Copy link
Member

@rusty1s rusty1s left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the updates :)

torch_geometric/nn/conv/message_passing.py Show resolved Hide resolved
@@ -130,6 +129,9 @@ def __init__(self, aggr: Optional[Union[str, List[str]]] = "add",

# Support for GNNExplainer.
self._explain = False
self.message_scale = None
self.message_replacement = None
self._gmask_explain = False
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure why we need these arguments here. IMO, the GraphMask module should set these attributes (and we solely rely on _explain to call self.explain_message or not).

torch_geometric/nn/conv/message_passing.py Outdated Show resolved Hide resolved
torch_geometric/nn/conv/message_passing.py Outdated Show resolved Hide resolved
torch_geometric/nn/conv/message_passing.py Outdated Show resolved Hide resolved
@fork123aniket
Copy link
Contributor Author

fork123aniket commented Apr 5, 2022

@rusty1s All the above mentioned 5 issues have been addressed in the recently added message_passing.py. Please review the changes.

@rusty1s rusty1s changed the title Edge Masking to generate Explanations MessagePassing.explain_message functionality. Apr 8, 2022
Copy link
Member

@rusty1s rusty1s left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you! Cleaned up a bit and added necessary tests :)

@rusty1s rusty1s merged commit a9824b6 into pyg-team:master Apr 8, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants