diff --git a/graphium/nn/pyg_layers/gps_pyg.py b/graphium/nn/pyg_layers/gps_pyg.py
index e706bd5b3..bc04c8288 100644
--- a/graphium/nn/pyg_layers/gps_pyg.py
+++ b/graphium/nn/pyg_layers/gps_pyg.py
@@ -47,7 +47,6 @@ def __init__(
         activation: Union[Callable, str] = "relu",
         dropout: float = 0.0,
         node_residual: Optional[bool] = True,
-        edge_residual: Optional[bool] = True,
         normalization: Union[str, Callable] = "none",
         mpnn_type: str = "pyg:gine",
         mpnn_kwargs: Optional[dict] = None,
@@ -101,9 +100,6 @@ def __init__(
             node_residual:
                 If node residual is used after on the gnn layer output

-            edge_residual:
-                If edge residual is used after on the gnn layer output
-
             normalization:
                 Normalization to use. Choices:

@@ -175,7 +171,6 @@ def __init__(

         # Residual connections
         self.node_residual = node_residual
-        self.edge_residual = edge_residual

         self.precision = precision

@@ -229,7 +224,7 @@ def scale_activations(self, feature: Tensor, scale_factor: Tensor) -> Tensor:
             Tensor: The scaled features
         """
         scale_factor = torch.tensor(scale_factor).to(feature.device)
-        feature *= scale_factor.to(dtype=feature.dtype)
+        feature = feature / scale_factor.to(dtype=feature.dtype)
         return feature

     def forward(self, batch: Batch) -> Batch:
@@ -252,25 +247,16 @@ def forward(self, batch: Batch) -> Batch:
         if self.mpnn is not None:
             batch_out = self.mpnn(batch_out)
         h_local = batch_out.feat
-        e_local = batch_out.edge_feat
         if self.dropout_local is not None:
             h_local = self.dropout_local(h_local)
         # Apply the residual connection for the node features and scale the activations by some value to help reduce activation growth
         if self.node_residual:
             if self.layer_depth < 1:
                 h_local = self.residual_add(h_local, feat_in)
-                h_local *= 1 / self.scale_activations(h_local, self.output_scale)
+                h_local = self.scale_activations(h_local, self.output_scale)
             else:
-                h_local *= 1 / self.scale_activations(h_local, self.output_scale)
+                h_local = self.scale_activations(h_local, self.output_scale)
                 h_local = self.residual_add(h_local, feat_in)
-        # Apply the residual connection for the edge features and scale the activations by some value to help reduce activation growth
-        if self.edge_residual and self.use_edges:
-            if self.layer_depth < 1:
-                e_local = self.residual_add(e_local, edges_feat_in)
-                e_local *= 1 / self.scale_activations(e_local, self.output_scale)
-            else:
-                e_local *= 1 / self.scale_activations(e_local, self.output_scale)
-                e_local = self.residual_add(e_local, edges_feat_in)

         if self.norm_layer_local is not None:
             h_local = self.norm_layer_local(h_local)
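
The net effect of the `scale_activations` change, together with the simplified call sites in `forward`, is that node features are now divided by `output_scale` exactly once, rather than being multiplied in place inside the helper and then multiplied by a reciprocal at the call site. A minimal standalone sketch of the post-diff behavior (a free function instead of the layer method, a plain float for `output_scale`, and plain tensor addition standing in for the layer's `residual_add`):

```python
import torch
from torch import Tensor

def scale_activations(feature: Tensor, scale_factor: float) -> Tensor:
    # Post-diff behavior: divide the features by the scale factor to help
    # reduce activation growth (previously an in-place multiply combined
    # with a reciprocal at the call site).
    scale = torch.tensor(scale_factor, device=feature.device)
    return feature / scale.to(dtype=feature.dtype)

# First layer (layer_depth < 1): residual add first, then scale.
# Deeper layers: scale first, then residual add.
h, feat_in = torch.randn(4, 8), torch.randn(4, 8)
output_scale, layer_depth = 2.0, 0
if layer_depth < 1:
    h = scale_activations(h + feat_in, output_scale)
else:
    h = scale_activations(h, output_scale) + feat_in
```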