Skip to content

Commit

Permalink
Update gps_pyg.py
Browse files Browse the repository at this point in the history
  • Loading branch information
DomInvivo authored Aug 11, 2023
1 parent 10fe04b commit 71182f4
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion graphium/nn/pyg_layers/gps_pyg.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,9 @@ def forward(self, batch: Batch) -> Batch:
feat_in = feat # for first residual connection

# Local MPNN with edge attributes.
batch_out = self.mpnn(batch.clone())
batch_out = batch.clone()
if self.mpnn is not None:
batch_out = self.mpnn(batch_out)
h_local = batch_out.feat
if self.dropout_local is not None:
h_local = self.dropout_local(h_local)
Expand Down Expand Up @@ -238,6 +240,9 @@ def forward(self, batch: Batch) -> Batch:
def _parse_mpnn_layer(self, mpnn_type, mpnn_kwargs: Dict[str, Any]) -> Optional[Module]:
"""Parse the MPNN layer."""

if mpnn_type is None:
return

mpnn_kwargs = deepcopy(mpnn_kwargs)
if mpnn_kwargs is None:
mpnn_kwargs = {}
Expand Down

0 comments on commit 71182f4

Please sign in to comment.