Merge pull request #482 from datamol-io/residual_edge_fix
Fixed the output scaling bug and removed the duplicated edge residual connection (edge residuals were already implemented in the global architecture)
DomInvivo committed Oct 25, 2023
2 parents 2883d88 + ee56497 commit 983bf6c
Showing 1 changed file with 3 additions and 17 deletions.
20 changes: 3 additions & 17 deletions graphium/nn/pyg_layers/gps_pyg.py
@@ -47,7 +47,6 @@ def __init__(
         activation: Union[Callable, str] = "relu",
         dropout: float = 0.0,
         node_residual: Optional[bool] = True,
-        edge_residual: Optional[bool] = True,
         normalization: Union[str, Callable] = "none",
         mpnn_type: str = "pyg:gine",
         mpnn_kwargs: Optional[dict] = None,
@@ -101,9 +100,6 @@ def __init__(
            node_residual:
                If node residual is used after on the gnn layer output
-           edge_residual:
-               If edge residual is used after on the gnn layer output
            normalization:
                Normalization to use. Choices:
@@ -175,7 +171,6 @@ def __init__(

        # Residual connections
        self.node_residual = node_residual
-       self.edge_residual = edge_residual

        self.precision = precision

@@ -229,7 +224,7 @@ def scale_activations(self, feature: Tensor, scale_factor: Tensor) -> Tensor:
            Tensor: The scaled features
        """
        scale_factor = torch.tensor(scale_factor).to(feature.device)
-       feature *= scale_factor.to(dtype=feature.dtype)
+       feature = feature / scale_factor.to(dtype=feature.dtype)
        return feature

    def forward(self, batch: Batch) -> Batch:
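
Read together with the old call sites in forward() below, the pre-patch scale_activations both mutated feature in place (feature *= scale_factor) and returned the mutated tensor, which the caller then multiplied by the reciprocal of. A minimal standalone reproduction of the two call patterns, assuming plain tensors with no autograd involvement (helper names are illustrative, not part of the library):

    import torch

    def scale_activations_old(feature, scale_factor):
        # Pre-patch behavior: in-place multiply, returning the SAME tensor object.
        scale_factor = torch.tensor(scale_factor).to(feature.device)
        feature *= scale_factor.to(dtype=feature.dtype)
        return feature

    def scale_activations_new(feature, scale_factor):
        # Post-patch behavior: out-of-place division by the scale factor.
        scale_factor = torch.tensor(scale_factor).to(feature.device)
        feature = feature / scale_factor.to(dtype=feature.dtype)
        return feature

    h = torch.tensor([2.0, 4.0])
    h *= 1 / scale_activations_old(h, 2.0)  # pre-patch call pattern from forward()
    # The in-place multiply runs first, so h is already h * 2 when the reciprocal
    # is taken; (h * 2) * 1 / (h * 2) collapses every entry to 1.
    print(h)  # tensor([1., 1.])

    h = torch.tensor([2.0, 4.0])
    h = scale_activations_new(h, 2.0)  # post-patch call pattern
    print(h)  # tensor([1., 2.])  -- features simply divided by the scale
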
@@ -252,25 +247,16 @@ def forward(self, batch: Batch) -> Batch:
        if self.mpnn is not None:
            batch_out = self.mpnn(batch_out)
        h_local = batch_out.feat
-       e_local = batch_out.edge_feat
        if self.dropout_local is not None:
            h_local = self.dropout_local(h_local)
        # Apply the residual connection for the node features and scale the activations by some value to help reduce activation growth
        if self.node_residual:
            if self.layer_depth < 1:
                h_local = self.residual_add(h_local, feat_in)
-               h_local *= 1 / self.scale_activations(h_local, self.output_scale)
+               h_local = self.scale_activations(h_local, self.output_scale)
            else:
-               h_local *= 1 / self.scale_activations(h_local, self.output_scale)
+               h_local = self.scale_activations(h_local, self.output_scale)
                h_local = self.residual_add(h_local, feat_in)
-       # Apply the residual connection for the edge features and scale the activations by some value to help reduce activation growth
-       if self.edge_residual and self.use_edges:
-           if self.layer_depth < 1:
-               e_local = self.residual_add(e_local, edges_feat_in)
-               e_local *= 1 / self.scale_activations(e_local, self.output_scale)
-           else:
-               e_local *= 1 / self.scale_activations(e_local, self.output_scale)
-               e_local = self.residual_add(e_local, edges_feat_in)

        if self.norm_layer_local is not None:
            h_local = self.norm_layer_local(h_local)
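
After the patch, the node path is the only residual applied inside this layer, and the ordering still differs by depth: the first layer adds the residual before scaling, while deeper layers scale the MPNN output first and then add the unscaled residual. A paraphrased sketch, assuming residual_add reduces to a plain addition and scale_activations to the division shown above:

    def node_residual_step(h, feat_in, scale, layer_depth):
        # Paraphrase of the post-patch node-residual logic (illustrative only;
        # residual_add is assumed to be a plain elementwise addition here).
        if layer_depth < 1:
            h = h + feat_in  # first layer: add the residual first...
            h = h / scale    # ...then divide by the output scale
        else:
            h = h / scale    # deeper layers: scale the MPNN output first...
            h = h + feat_in  # ...then add the unscaled residual
        return h
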
