-
Notifications
You must be signed in to change notification settings - Fork 3.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
MessagePassing.explain_message
functionality.
#4278
Conversation
Please take a look at functions (message, get_latest_source_embeddings, get_latest_target_embeddings, and get_latest_messages) which must be copied to other layers' implementation in order to make this GraphMask implementation model-agnostic. The significance of each of these functions can be seen on lines 241-243 (updated_gcn.py) and 306-308 (updated_rgcn.py). |
for more information, see https://pre-commit.ci
Codecov Report
@@ Coverage Diff @@
## master #4278 +/- ##
==========================================
+ Coverage 82.58% 82.60% +0.01%
==========================================
Files 312 312
Lines 16135 16152 +17
==========================================
+ Hits 13325 13342 +17
Misses 2810 2810
Continue to review full report at Codecov.
|
@rusty1s Am just following up with you that did you get a chance to look into the code files pinned with this PR? Am still awaiting any review comments so we can make further progress on how to make this GraphMask Explainer implementation model-agnostic. In the previous comment, I've provided details of which functions to put more focus on. |
Hi @fork123aniket. Yes, sorry for the delay. I will try to have a review ready by tomorrow. I sadly only have two hands :( |
Sorry for the delay. I left some ideas on how to make this GNN-agnostic :) Looking forward to your reply. |
I've just replied to all the comments above. Please look into each of them and let me know your thoughts. Looking forward to hearing from you very soon....... :) |
The 'explain message' interface just seemed a little specific (and
difficult to understand what to do with), and all models using it will need
to implement it anyway.
But just a gut reaction. Not very familiar with the whole proposal.
…On Fri, 1 Apr 2022, 7:30 pm Matthias Fey, ***@***.***> wrote:
***@***.**** commented on this pull request.
------------------------------
In torch_geometric/nn/conv/updated_gcn.py
<#4278 (comment)>
:
> + message_scale=message_scale,
+ message_replacement=message_replacement,
+ size=None)
+
+ if self.bias is not None and not False and not self.explain:
+ out += self.bias
+
+ return out
+
+ def message(self, x_i: Tensor, x_j: Tensor, edge_index: Tensor,
+ edge_weight: OptTensor, message_scale,
+ message_replacement) -> Tensor:
+ if self.explain:
+ basis_messages = x_j if edge_weight is None else edge_weight.view(
+ -1, 1) * x_j
+ basis_messages = self.convert(basis_messages)
I think the goal here is to make this procedure GNN agnostic. As such,
when we want to implement this inside a model's message function, we
basically have to do this for every PyG GNN layer. Let me know if I might
be misunderstanding something.
—
Reply to this email directly, view it on GitHub
<#4278 (comment)>,
or unsubscribe
<https://github.com/notifications/unsubscribe-auth/AAGRPNYSZ6F2YXDS6DIIM4DVC3M6XANCNFSM5Q4PMM7A>
.
You are receiving this because you commented.Message ID:
***@***.***>
--
By communicating with Grab Inc and/or its subsidiaries, associate
companies and jointly controlled entities (“Grab Group”), you are deemed to
have consented to the processing of your personal data as set out in the
Privacy Notice which can be viewed at https://grab.com/privacy/
<https://grab.com/privacy/>
This email contains confidential information
and is only for the intended recipient(s). If you are not the intended
recipient(s), please do not disseminate, distribute or copy this email
Please notify Grab Group immediately if you have received this by mistake
and delete this email from your system. Email transmission cannot be
guaranteed to be secure or error-free as any information therein could be
intercepted, corrupted, lost, destroyed, delayed or incomplete, or contain
viruses. Grab Group do not accept liability for any errors or omissions in
the contents of this email arises as a result of email transmission. All
intellectual property rights in this email and attachments therein shall
remain vested in Grab Group, unless otherwise provided by law.
|
for more information, see https://pre-commit.ci
for more information, see https://pre-commit.ci
for more information, see https://pre-commit.ci
@fork123aniket Thanks for the update :) Can you apply the modifications directly inside What do |
@rusty1s Bingo!!! Successfully set those two required parameters ( Furthermore, as mentioned above, since making changes to any GNN layer types' original implementation is not required, hence have taken all the three previously added files ( Please review the changes and let me know if any further change is still required somewhere. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the updates :)
@@ -130,6 +129,9 @@ def __init__(self, aggr: Optional[Union[str, List[str]]] = "add", | |||
|
|||
# Support for GNNExplainer. | |||
self._explain = False | |||
self.message_scale = None | |||
self.message_replacement = None | |||
self._gmask_explain = False |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure why we need these arguments here. IMO, the GraphMask
module should set these attributes (and we solely rely on _explain
to call self.explain_message
or not).
@rusty1s All the above mentioned 5 issues have been addressed in the recently added |
MessagePassing.explain_message
functionality.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you! Cleaned up a bit and added necessary tests :)
This PR contains implementation of how to compute layer-wise weights for each edge in order to produce explanations for both node-level and graph-level tasks. Furthermore, this implementation is different from authors' original implementation and is fast and more memory efficient than theirs.